Commit 6e63a53

Support DP+TP+EP hybrid parallel deployment strategy
1 parent 9ff2dfb commit 6e63a53

18 files changed: +225 -196 lines

custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu

Lines changed: 10 additions & 5 deletions
@@ -43,11 +43,16 @@
       __VA_ARGS__                                   \
       break;                                        \
     }                                               \
-    case 48: {                                      \
-      constexpr size_t NUM_EXPERTS_PER_RANK = 48;   \
-      __VA_ARGS__                                   \
-      break;                                        \
-    }                                               \
+    case 32: {                                      \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 32;   \
+      __VA_ARGS__                                   \
+      break;                                        \
+    }                                               \
+    case 48: {                                      \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 48;   \
+      __VA_ARGS__                                   \
+      break;                                        \
+    }                                               \
     case 64: {                                      \
       constexpr size_t NUM_EXPERTS_PER_RANK = 64;   \
       __VA_ARGS__                                   \
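The dispatch macro specializes the kernel on the per-rank expert count, which shrinks as more ranks join the expert-parallel group, hence the new case 32 branch. A rough Python sketch of the arithmetic; the concrete expert counts below are hypothetical, not taken from this repository:

# Sketch only: why a per-rank count of 32 appears once TP ranks join EP.
def experts_per_rank(total_experts: int, expert_parallel_size: int) -> int:
    assert total_experts % expert_parallel_size == 0
    return total_experts // expert_parallel_size

print(experts_per_rank(64, 2))   # 32 -> needs the new case 32 branch
print(experts_per_rank(96, 2))   # 48 -> the pre-existing case 48 branch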

custom_ops/gpu_ops/save_with_output_msg.cc

Lines changed: 2 additions & 1 deletion
@@ -105,7 +105,8 @@ void SaveOutMmsg(const paddle::Tensor& x,
                  int64_t rank_id,
                  int msg_queue_id,
                  bool save_each_rank) {
-  if (!save_each_rank && rank_id > 0) {
+  // don't use save_each_rank now!
+  if (rank_id > 0) {
     return;
   }
   if (x.place() == paddle::CPUPlace()) {

fastdeploy/config.py

Lines changed: 21 additions & 1 deletion
@@ -22,6 +22,7 @@
 from enum import Enum
 from typing import Literal, Optional, Union
 
+import paddle.distributed as dist
 from paddleformers.transformers.configuration_utils import PretrainedConfig
 
 import fastdeploy
@@ -276,7 +277,10 @@ def __init__(
                 setattr(self, key, value)
 
         # currently, the expert parallel size is equal data parallel size
-        self.expert_parallel_size = self.data_parallel_size
+        if self.enable_expert_parallel:
+            self.expert_parallel_size = self.data_parallel_size * self.tensor_parallel_size
+        else:
+            self.expert_parallel_size = 1
         self.use_ep = self.expert_parallel_size > 1
         if self.splitwise_role == "mixed":
             self.moe_phase = MoEPhase(phase="prefill")
@@ -297,6 +301,22 @@ def __init__(
         else:
             self.pd_disaggregation_mode = "None"
 
+    def set_tp_group(self):
+        # different tp group id
+        # prevent different tp_groups using the same group_id
+        dist.collective._set_custom_gid(self.data_parallel_rank + 100)
+        self.tp_group = dist.new_group(
+            range(
+                self.data_parallel_rank * self.tensor_parallel_size,
+                (self.data_parallel_rank + 1) * self.tensor_parallel_size,
+            )
+        )
+        # same ep group id
+        dist.collective._set_custom_gid(self.data_parallel_size + 100)
+        logger.info(
+            f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
+        )
+
     def print(self):
         """
         print all config
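A minimal sketch of the rank layout this produces, assuming DP=2 replicas of TP=2 ranks (sizes are illustrative): each DP replica builds its own TP group under a distinct custom group id, while the EP group id is shared so expert parallelism spans every rank.

# Sketch, not the repository implementation: group layout for DP=2, TP=2.
data_parallel_size, tensor_parallel_size = 2, 2

for dp_rank in range(data_parallel_size):
    tp_ranks = list(range(dp_rank * tensor_parallel_size,
                          (dp_rank + 1) * tensor_parallel_size))
    # gid 100 + dp_rank keeps the two TP groups from colliding
    print(f"dp_rank={dp_rank} gid={100 + dp_rank} tp_group_ranks={tp_ranks}")

# expert parallel spans every rank when enable_expert_parallel is set
expert_parallel_size = data_parallel_size * tensor_parallel_size  # 4
print(f"ep gid={100 + data_parallel_size} ep_size={expert_parallel_size}")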

fastdeploy/distributed/communication.py

Lines changed: 8 additions & 3 deletions
@@ -47,15 +47,20 @@ def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
 @paddle.jit.marker.unified
 def tensor_model_parallel_all_reduce(
     input_: paddle.Tensor,
+    group_: paddle.distributed.communication.group.Group = None,
 ) -> paddle.Tensor:
     """All-reduce the input tensor across model parallel group."""
     global _TP_AR
     if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
+        # TODO: supports different_group custom allreduce
         _TP_AR.custom_all_reduce(input_)
     elif paddle.in_dynamic_mode():
-        hcg = fleet.get_hybrid_communicate_group()
-        mp_group = hcg.get_model_parallel_group()
-        dist.all_reduce(input_, group=mp_group)
+        if group_ is not None:
+            dist.all_reduce(input_, group=group_)
+        else:
+            hcg = fleet.get_hybrid_communicate_group()
+            mp_group = hcg.get_model_parallel_group()
+            dist.all_reduce(input_, group=mp_group)
     else:
         dist.all_reduce(input_)
 
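Callers can now pass an explicit communication group; a hedged usage sketch, assuming a multi-GPU job started via paddle.distributed.launch (the group construction here is illustrative):

# Usage sketch only.
import paddle
import paddle.distributed as dist

from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce

dist.init_parallel_env()
tp_group = dist.new_group(list(range(dist.get_world_size())))

partial_out = paddle.ones([4, 8], dtype="float32")
tensor_model_parallel_all_reduce(partial_out, tp_group)  # reduce across this TP group
# omitting the group falls back to fleet's hybrid model-parallel group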

fastdeploy/engine/config.py

Lines changed: 1 addition & 6 deletions
@@ -192,12 +192,7 @@ def __init__(
         if self.enable_mm:
             self.max_prefill_batch = 1  # TODO:当前多模prefill阶段只支持并行度为1,待优化
 
-        # TODO(@wufeisheng): TP and EP need to be supported simultaneously.
-        assert (self.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or (
-            self.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1
-        ), "TP and EP cannot be enabled at the same time"
-
-        num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
+        num_ranks = self.tensor_parallel_size * self.parallel_config.data_parallel_size
         self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
         if num_ranks > self.max_chips_per_node:
             self.worker_num_per_node = self.max_chips_per_node
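With the TP/EP exclusivity assert removed, the rank count is driven by TP x DP. A small sketch mirroring the visible node-splitting logic (8 chips per node assumed, 16 on iluvatar):

# Sketch of the rank arithmetic after the assert was removed.
def worker_num_per_node(tensor_parallel_size: int, data_parallel_size: int,
                        max_chips_per_node: int = 8) -> int:
    num_ranks = tensor_parallel_size * data_parallel_size
    return max_chips_per_node if num_ranks > max_chips_per_node else num_ranks

print(worker_num_per_node(4, 2))  # 8: TP=4 x DP=2 fills one 8-GPU node
print(worker_num_per_node(8, 2))  # 8: 16 ranks spread across multiple nodes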

fastdeploy/model_executor/layers/embeddings.py

Lines changed: 31 additions & 40 deletions
@@ -57,43 +57,37 @@ def __init__(
         hcg = fleet.get_hybrid_communicate_group()
         self.mp_rank: int = hcg.get_model_parallel_rank()
         self.column_cut = False
-        self.world_size: int = hcg.get_model_parallel_world_size()
-        self.ring_id: int = hcg.get_model_parallel_group().id
-        self.use_ep: bool = fd_config.parallel_config.use_ep
+        self.world_size: int = fd_config.parallel_config.tensor_parallel_size
+        self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
+        self.tp_group = fd_config.parallel_config.tp_group
         self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
         self.initializer_range: float = fd_config.model_config.initializer_range
         self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
         self.params_dtype: str = params_dtype
 
-        if self.use_ep:
-            self.embeddings = nn.Embedding(
+        if not self.column_cut:
+            self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
                 num_embeddings,
                 embedding_dim,
+                mp_group=self.tp_group,
+                weight_attr=paddle.ParamAttr(
+                    initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
+                ),
             )
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
         else:
-            if not self.column_cut:
-                self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
-                    num_embeddings,
-                    embedding_dim,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=paddle.ParamAttr(
-                        initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
-                    ),
-                )
-                if self.world_size > 1:
-                    set_weight_attrs(self.embeddings.weight, {"output_dim": False})
-            else:
-                # column cut embedding
-                self.embeddings = nn.Embedding(
-                    num_embeddings,
-                    embedding_dim // self.world_size,
-                )
-
-                self.embeddings.weight.is_distributed = True
-                self.embeddings.weight.split_axis = 1
-                if self.world_size > 1:
-                    set_weight_attrs(self.embeddings.weight, {"output_dim": True})
+            # column cut embedding
+            self.embeddings = nn.Embedding(
+                num_embeddings,
+                embedding_dim // self.world_size,
+            )
+
+            self.embeddings.weight.is_distributed = True
+            self.embeddings.weight.split_axis = 1
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": True})
 
         self.prefix = prefix
         self.dropout = nn.Dropout(self.hidden_dropout_prob)
@@ -125,20 +119,17 @@ def forward(self, ids_remove_padding=None) -> paddle.Tensor:
         Returns:
             Tensor: Embedded tensor representation of the input IDs.
         """
-        if self.use_ep:
+        if self.column_cut:
             input_embedings = self.embeddings(ids_remove_padding)
+            inputs_embeds_temp = []
+            paddle.distributed.all_gather(
+                inputs_embeds_temp,
+                input_embedings,
+                group=self.tp_group,
+                sync_op=True,
+            )
+            input_embedings = paddle.concat(inputs_embeds_temp, -1)
         else:
-            if self.column_cut:
-                input_embedings = self.embeddings(ids_remove_padding)
-                inputs_embeds_temp = []
-                paddle.distributed.all_gather(
-                    inputs_embeds_temp,
-                    input_embedings,
-                    group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    sync_op=True,
-                )
-                input_embedings = paddle.concat(inputs_embeds_temp, -1)
-            else:
-                input_embedings = self.embeddings(ids_remove_padding)
+            input_embedings = self.embeddings(ids_remove_padding)
 
         return input_embedings
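In the column-cut path each rank holds embedding_dim // world_size columns and the all_gather over tp_group plus concat on the last axis restores the full hidden size. A shape-only sketch with illustrative sizes, plain NumPy standing in for the collective:

# Shape sketch only (no communication).
import numpy as np

vocab, hidden, tp = 32000, 4096, 4
ids = np.array([3, 17, 256])

shards = [np.zeros((vocab, hidden // tp), dtype=np.float32) for _ in range(tp)]
per_rank = [shard[ids] for shard in shards]  # each: (3, hidden // tp)
full = np.concatenate(per_rank, axis=-1)     # all_gather + concat(-1) equivalent
assert full.shape == (3, hidden)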

fastdeploy/model_executor/layers/linear.py

Lines changed: 2 additions & 1 deletion
@@ -670,6 +670,7 @@ def __init__(
         self.fd_config = fd_config
         self.skip_quant = False
         self.nranks = fd_config.parallel_config.tensor_parallel_size
+        self.tp_group = fd_config.parallel_config.tp_group
         self.hidden_size = fd_config.model_config.hidden_size
         self.head_dim = fd_config.model_config.head_dim
         self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
@@ -719,7 +720,7 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
         out = paddle.matmul(x, self.weight)
 
         if self.reduce_results and self.nranks > 1:
-            tensor_model_parallel_all_reduce(out)
+            tensor_model_parallel_all_reduce(out, self.tp_group)
 
         return out
 
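A row-parallel matmul yields partial sums on each TP rank, which is why this all-reduce (now run over the layer's own tp_group rather than the global MP group) must be a sum. A single-process numeric sketch:

# Numeric sketch: sum of per-rank partial matmuls equals the full matmul.
import numpy as np

x = np.random.rand(2, 8).astype("float32")
w = np.random.rand(8, 4).astype("float32")
tp = 2

partials = [x[:, r * 4:(r + 1) * 4] @ w[r * 4:(r + 1) * 4, :] for r in range(tp)]
assert np.allclose(sum(partials), x @ w, atol=1e-5)  # all_reduce(sum) == full matmul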

fastdeploy/model_executor/layers/lm_head.py

Lines changed: 37 additions & 62 deletions
@@ -58,7 +58,7 @@ def __init__(
             self.bias_key: Optional[str] = prefix + ".bias"
         else:
             self.bias_key: Optional[str] = None
-        self.use_ep: bool = fd_config.parallel_config.use_ep
+        self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
         self.nranks = fd_config.parallel_config.tensor_parallel_size
 
@@ -67,45 +67,31 @@ def __init__(
 
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
 
-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
-                dtype=paddle.get_default_dtype(),
-                is_bias=False,
+        if self.column_cut:
+            need_gather = True
+            self.linear = ColumnParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                gather_output=need_gather,
+                fuse_matmul_bias=False,  # False diff更小
            )
-            if self.bias_key is not None:
-                self.bias = self.create_parameter(
-                    shape=[num_embeddings],
-                    dtype=paddle.get_default_dtype(),
-                    is_bias=True,
-                )
-
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
         else:
-            if self.column_cut:
-                need_gather = True
-                self.linear = ColumnParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
-                    fuse_matmul_bias=False,  # False diff更小
-                )
-                if self.nranks > 1:
-                    set_weight_attrs(self.linear.weight, {"output_dim": True})
-            else:
-                self.linear = RowParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,  # False diff更小
-                )
-                if self.nranks > 1:
-                    set_weight_attrs(self.linear.weight, {"output_dim": False})
+            self.linear = RowParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                input_is_parallel=False,
+                fuse_matmul_bias=False,  # False diff更小
+            )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": False})
 
     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """
@@ -115,24 +101,19 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
 
-        if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
-            if self.bias_key is not None:
-                self.bias.set_value(get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype()))
+        if self.tie_word_embeddings:
+            self.linear.weight.set_value(
+                get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
+            )
         else:
-            if self.tie_word_embeddings:
-                self.linear.weight.set_value(
-                    get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
-                )
-            else:
-                weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
-                if self.linear.weight.shape != weight_tensor.shape:
-                    weight_tensor = weight_tensor.transpose([1, 0])
-                self.linear.weight.set_value(weight_tensor)
-
-            if self.bias_key is not None:
-                bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
-                self.linear.bias.set_value(bias)
+            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+            if self.linear.weight.shape != weight_tensor.shape:
+                weight_tensor = weight_tensor.transpose([1, 0])
+            self.linear.weight.set_value(weight_tensor)
+
+        if self.bias_key is not None:
+            bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
+            self.linear.bias.set_value(bias)
 
     def forward(self, input: paddle.Tensor) -> paddle.Tensor:
         """
@@ -145,11 +126,5 @@ def forward(self, input: paddle.Tensor) -> paddle.Tensor:
             Tensor: The output tensor after processing through the layer.
         """
         logits = input
-        if self.use_ep:
-            if self.bias_key is None:
-                logits = paddle.matmul(logits, self.weight)
-            else:
-                logits = paddle.incubate.nn.functional.fused_linear(logits, self.weight, self.bias)
-        else:
-            logits = self.linear(logits)
+        logits = self.linear(logits)
        return logits
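With the EP-only branch gone, the lm_head always loads into the parallel linear; for tied embeddings the checkpoint stores the weight embedding-style and it is transposed on load. A shape sketch with illustrative sizes:

# Sketch: tied checkpoints store [num_embeddings, embedding_dim]; the linear
# expects [embedding_dim, num_embeddings], hence transpose([1, 0]) on load.
import numpy as np

vocab, hidden = 32000, 4096
ckpt_weight = np.zeros((vocab, hidden), dtype=np.float32)  # as saved with the embedding
lm_head_weight = ckpt_weight.T
assert lm_head_weight.shape == (hidden, vocab)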

fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py

Lines changed: 1 addition & 1 deletion
@@ -309,7 +309,7 @@ def apply_tp(
         )
 
         if layer.reduce_results and layer.tp_size > 1:
-            tensor_model_parallel_all_reduce(fused_moe_out)
+            tensor_model_parallel_all_reduce(fused_moe_out, self.tp_group)
 
         return fused_moe_out
 

fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py

Lines changed: 1 addition & 1 deletion
@@ -465,6 +465,6 @@ def apply_tp(
             1.0,
         )[0]
         if layer.tp_size > 1:
-            tensor_model_parallel_all_reduce(tmp_ffn_out)
+            tensor_model_parallel_all_reduce(tmp_ffn_out, self.tp_group)
 
         return tmp_ffn_out
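In the hybrid layout the expert dispatch spans all DP x TP ranks, while this post-MoE reduce stays inside one replica's tp_group. A NumPy toy with DP=2, TP=2 (sizes illustrative, no real communication):

# Toy sketch: per-rank MoE partial outputs, reduced only within each TP pair.
import numpy as np

dp, tp = 2, 2
partial = np.array([[1.0, 2.0], [10.0, 20.0]])  # rows = DP replicas, cols = TP ranks

reduced = partial.sum(axis=1, keepdims=True).repeat(tp, axis=1)
print(reduced)  # [[3, 3], [30, 30]] -- replicas stay independent, EP still spans all 4 ranks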
