
Commit 0700c90

[Feat] support mixed ep (#2969)
* Support mixed ep
* fix comment
* fix comment
* update mixep
* fix conflict
* fix typo
* update
* fix typo
* fix code style
* fix conflict
1 parent 332154f commit 0700c90

File tree

4 files changed (+137, -48 lines)


fastdeploy/config.py
Lines changed: 18 additions & 8 deletions

@@ -18,7 +18,6 @@
 
 import os
 from dataclasses import dataclass, field
-from enum import Enum
 from typing import Literal, Optional
 
 from paddleformers.transformers.configuration_utils import PretrainedConfig
@@ -30,13 +29,24 @@
 logger = get_logger("config", "config.log")
 
 
-class MoEPhase(Enum):
+class MoEPhase:
     """
     The generation phase of the moe.
     """
 
-    PREFILL = 1
-    DECODER = 2
+    def __init__(self, phase="prefill"):
+        self._phase = phase
+
+    @property
+    def phase(self):
+        return self._phase
+
+    @phase.setter
+    def phase(self, value):
+        if value not in ["prefill", "decode"]:
+            raise ValueError(f"The moe_phase is invalid, only support prefill and decode, but got {value}")
+        else:
+            self._phase = value
 
 
 class ErnieArchitectures:
@@ -146,7 +156,7 @@ def __init__(
     ):
         self.sequence_parallel = False  # Whether to enable sequence parallelism.
         self.use_ep = False  # Whether to enable Expert Parallelism
-        self.moe_phase = MoEPhase.PREFILL  # Generation phase
+        self.moe_phase = MoEPhase("prefill")  # Generation phase
         self.msg_queue_id = 1  # mesage queue id
 
         self.tensor_parallel_rank = 0  # TP rank ID
@@ -210,11 +220,11 @@ def __init__(
             setattr(self, key, value)
         self.use_ep = args["expert_parallel_size"] > 1
         if self.splitwise_role == "mixed":
-            self.moe_phase = MoEPhase.PREFILL
+            self.moe_phase = MoEPhase(phase="prefill")
         elif self.splitwise_role == "prefill":
-            self.moe_phase = MoEPhase.PREFILL
+            self.moe_phase = MoEPhase(phase="prefill")
         elif self.splitwise_role == "decode":
-            self.moe_phase = MoEPhase.DECODER
+            self.moe_phase = MoEPhase(phase="decode")
         else:
            raise NotImplementedError
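Note: MoEPhase changes from an Enum to a small class with a validated, settable phase, so a mixed deployment can flip between prefill and decode at runtime instead of fixing the phase at construction. A minimal standalone sketch mirroring the class above (redefined here rather than imported from fastdeploy) and its intended use:

# Standalone sketch mirroring the MoEPhase class added above (redefined here,
# not imported from fastdeploy): the phase is a validated, mutable string.
class MoEPhase:
    def __init__(self, phase="prefill"):
        self._phase = phase

    @property
    def phase(self):
        return self._phase

    @phase.setter
    def phase(self, value):
        if value not in ["prefill", "decode"]:
            raise ValueError(f"The moe_phase is invalid, only support prefill and decode, but got {value}")
        self._phase = value

moe_phase = MoEPhase("prefill")
moe_phase.phase = "decode"       # a mixed deployment can flip phases between steps
try:
    moe_phase.phase = "chunked"  # anything else is rejected by the setter
except ValueError as err:
    print(err)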
fastdeploy/model_executor/layers/moe/ep.py
Lines changed: 70 additions & 24 deletions

@@ -43,9 +43,10 @@ def __init__(
         num_max_dispatch_tokens_per_rank: int,
         hidden: int,
         num_experts: int,
-        moe_phase: MoEPhase,
         ep_size: int,
         ep_rank: int,
+        splitwise_role: str,
+        moe_phase: MoEPhase,
         async_finish: bool = False,
     ):
         """
@@ -65,26 +66,44 @@ def __init__(
         self.hidden = hidden
         self.num_experts = num_experts
         self.num_local_experts = num_experts // ep_size
-        self.moe_phase = moe_phase
         self.async_finish = async_finish
 
-        self.deepep_engine = None
+        self.prefill_deepep_engine = None
+        self.decode_deepep_engine = None
+
+        self.ep_config = Config(24, 6, 256)
+        self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
 
-        if moe_phase == MoEPhase.DECODER:
+        # In mixed EP mode on a single node, we dynamically switch between
+        # high throughput and low latency modes.
+        if splitwise_role == "mixed":
+            # decode engine
             logger.info("Initializing Low Latency Buffer")
-            self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
             self.get_low_latency_buffer()
-        elif moe_phase == MoEPhase.PREFILL:
-            self.deepep_engine = deep_ep.Buffer(
+            # prefill engine
+            self.prefill_deepep_engine = deep_ep.Buffer(
                 self.group,
                 int(5e8),
                 0,
                 low_latency_mode=False,
                 num_qps_per_rank=1,
             )
-            self.ep_config = Config(24, 6, 256)
+        # In disaggregated mode on mutiple nodes, we either use
+        # high throughput mode or low latency mode.
         else:
-            raise ValueError(f"Unknown generation phase {moe_phase}")
+            if moe_phase.phase == "decode":
+                logger.info("Initializing Low Latency Buffer")
+                self.get_low_latency_buffer()
+            elif moe_phase.phase == "prefill":
+                self.prefill_deepep_engine = deep_ep.Buffer(
+                    self.group,
+                    int(5e8),
+                    0,
+                    low_latency_mode=False,
+                    num_qps_per_rank=1,
+                )
+            else:
+                raise ValueError(f"Unknown generation phase {moe_phase}")
 
     def get_low_latency_buffer(self):
         """
@@ -105,14 +124,14 @@ def get_low_latency_buffer(self):
         )
         # Allocate a buffer if not existed or not enough buffer size
         if (
-            self.deepep_engine is None
-            or self.deepep_engine.group != self.group
-            or not self.deepep_engine.low_latency_mode
-            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
+            self.decode_deepep_engine is None
+            or self.decode_deepep_engine.group != self.group
+            or not self.decode_deepep_engine.low_latency_mode
+            or self.decode_deepep_engine.num_rdma_bytes < num_rdma_bytes
         ):
             # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
             assert self.num_experts % self.ep_size == 0
-            self.deepep_engine = deep_ep.Buffer(
+            self.decode_deepep_engine = deep_ep.Buffer(
                 self.group,
                 0,
                 num_rdma_bytes,
@@ -149,7 +168,7 @@ def low_latency_dispatch(
             handle,
             _,
             dispatch_hook,
-        ) = self.deepep_engine.low_latency_dispatch(
+        ) = self.decode_deepep_engine.low_latency_dispatch(
             hidden_states,
             topk_idx,
             expertwise_scale,
@@ -174,8 +193,22 @@ def low_latency_combine(
         Return:
             combined_hidden_states: [num_tokens, hidden]
         """
+        # TODO(@wufeisheng): Delete them when deepep in PaddlePaddle is fixed
+        (
+            src_info,
+            layout_range,
+            num_max_dispatch_tokens_per_rank,
+            num_experts,
+        ) = handle
+        handle = (
+            src_info,
+            layout_range,
+            num_max_dispatch_tokens_per_rank,
+            None,
+            num_experts,
+        )
 
-        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
+        combined_hidden_states, _, combine_hook = self.decode_deepep_engine.low_latency_combine(
             hidden_states,
             topk_idx,
             topk_weights,
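Note: the block added above is a temporary workaround (see the TODO): the 4-tuple handle returned by low_latency_dispatch is widened to the 5-tuple layout that low_latency_combine currently expects, with None inserted in the fourth slot. A dummy-value illustration:

# Dummy-value illustration of the handle repacking above (example values only):
# (src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts)
# becomes a 5-tuple with a None placeholder inserted before num_experts.
handle = ("src_info", "layout_range", 128, 64)
src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
handle = (src_info, layout_range, num_max_dispatch_tokens_per_rank, None, num_experts)
assert handle[3] is None and len(handle) == 5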
@@ -189,15 +222,19 @@ def clean_low_latency_buffer(self):
         """
         clean_low_latency_buffer
         """
-        self.deepep_engine.clean_low_latency_buffer(
+        self.decode_deepep_engine.clean_low_latency_buffer(
             self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
         )
 
     def barrier_all(self):
         """
         barrier_all
         """
-        self.deepep_engine.barrier_all()
+        if self.prefill_deepep_engine is not None:
+            self.prefill_deepep_engine.barrier_all()
+
+        if self.decode_deepep_engine is not None:
+            self.decode_deepep_engine.barrier_all()
 
 
 class EPRunner:
@@ -210,6 +247,7 @@ def __init__(
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         moe_phase: MoEPhase,
         num_max_dispatch_tokens_per_rank: int = 1,
         ep_size: int = 1,
@@ -223,9 +261,10 @@ def __init__(
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             hidden=hidden,
             num_experts=num_experts + redundant_experts_num,
-            moe_phase=moe_phase,
             ep_size=ep_size,
             ep_rank=ep_rank,
+            splitwise_role=splitwise_role,
+            moe_phase=moe_phase,
         )
 
     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
@@ -286,15 +325,19 @@ def __init__(
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        moe_phase: MoEPhase = MoEPhase("prefill"),
     ):
         super().__init__(
             top_k,
             hidden,
             num_experts,
-            MoEPhase.PREFILL,
+            splitwise_role,
+            moe_phase,
+            num_max_dispatch_tokens_per_rank=256,
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
@@ -314,7 +357,7 @@ def dispatch(
             num_tokens_per_expert,
             is_token_in_rank,
             _,
-        ) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
+        ) = self.ep_engine.prefill_deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
 
         x_scale_tensor = kwargs.get("x_scale_tensor", None)
         dispatch_args = {
@@ -327,7 +370,7 @@ def dispatch(
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
         }
-        return self.ep_engine.deepep_engine.dispatch(**dispatch_args)
+        return self.ep_engine.prefill_deepep_engine.dispatch(**dispatch_args)
 
     def combine(
         self,
@@ -342,7 +385,7 @@ def combine(
             "async_finish": self.ep_engine.async_finish,
             "topk_weights": recv_topk_weights,
         }
-        fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)
+        fused_moe_out, _, _ = self.ep_engine.prefill_deepep_engine.combine(**combine_args)
 
         return fused_moe_out
 
@@ -357,16 +400,19 @@ def __init__(
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         num_max_dispatch_tokens_per_rank: int,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        moe_phase: MoEPhase = MoEPhase("decode"),
     ):
         super().__init__(
             top_k,
             hidden,
             num_experts,
-            MoEPhase.DECODER,
+            splitwise_role,
+            moe_phase,
             num_max_dispatch_tokens_per_rank,
             ep_size=ep_size,
             ep_rank=ep_rank,
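Note: both runner classes now take splitwise_role plus an optional moe_phase (defaulting to MoEPhase("prefill") and MoEPhase("decode") respectively) and forward them to DeepEPEngine; the prefill runner additionally pins num_max_dispatch_tokens_per_rank=256. A simplified, standalone mock of that forwarding (stand-in classes and arbitrary example values, not fastdeploy code):

# Simplified stand-ins (not the fastdeploy classes) showing how the reworked runner
# constructors thread splitwise_role and moe_phase down to the shared engine.
class FakeEngine:
    def __init__(self, splitwise_role, moe_phase, num_max_dispatch_tokens_per_rank):
        self.splitwise_role = splitwise_role
        self.moe_phase = moe_phase
        self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank

class FakePrefillRunner:
    def __init__(self, splitwise_role, moe_phase="prefill"):
        # the real EPPrefillRunner now passes num_max_dispatch_tokens_per_rank=256
        self.ep_engine = FakeEngine(splitwise_role, moe_phase, 256)

class FakeDecoderRunner:
    def __init__(self, splitwise_role, num_max_dispatch_tokens_per_rank, moe_phase="decode"):
        self.ep_engine = FakeEngine(splitwise_role, moe_phase, num_max_dispatch_tokens_per_rank)

# in mixed mode a layer builds both runners against the same role
prefill = FakePrefillRunner("mixed")
decode = FakeDecoderRunner("mixed", num_max_dispatch_tokens_per_rank=128)  # 128 is an arbitrary example
print(prefill.ep_engine.moe_phase, decode.ep_engine.moe_phase)             # prefill decode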

fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
Lines changed: 34 additions & 11 deletions

@@ -19,8 +19,6 @@
 import paddle
 from paddle import nn
 
-from fastdeploy.config import MoEPhase
-
 from ..quantization.quant_base import QuantMethodBase
 
 
@@ -45,29 +43,54 @@ def init_ep(self, layer: nn.Layer) -> None:
         Init EP related module
         """
         if layer.ep_size > 1:
-            if layer.fd_config.parallel_config.moe_phase == MoEPhase.DECODER:
-                from .ep import EPDecoderRunner
+            if layer.fd_config.parallel_config.splitwise_role == "mixed":
+                from .ep import EPDecoderRunner, EPPrefillRunner
 
-                self.ep_decoder_runner = EPDecoderRunner(
+                self.ep_prefill_runner = EPPrefillRunner(
                     layer.top_k,
                     layer.hidden_size,
                     layer.num_experts,
-                    layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
+                    layer.fd_config.parallel_config.splitwise_role,
                     layer.ep_size,
                     layer.ep_rank,
                     layer.fd_config.model_config.redundant_experts_num,
                 )
-            else:
-                from .ep import EPPrefillRunner
-
-                self.ep_prefill_runner = EPPrefillRunner(
+                self.ep_decoder_runner = EPDecoderRunner(
                     layer.top_k,
                     layer.hidden_size,
                     layer.num_experts,
+                    layer.fd_config.parallel_config.splitwise_role,
+                    layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                     layer.ep_size,
                     layer.ep_rank,
                     layer.fd_config.model_config.redundant_experts_num,
                 )
+            else:
+                if layer.fd_config.parallel_config.moe_phase == "prefill":
+                    from .ep import EPPrefillRunner
+
+                    self.ep_prefill_runner = EPPrefillRunner(
+                        layer.top_k,
+                        layer.hidden_size,
+                        layer.num_experts,
+                        layer.fd_config.parallel_config.splitwise_role,
+                        layer.ep_size,
+                        layer.ep_rank,
+                        layer.fd_config.model_config.redundant_experts_num,
+                    )
+                else:
+                    from .ep import EPDecoderRunner
+
+                    self.ep_decoder_runner = EPDecoderRunner(
+                        layer.top_k,
+                        layer.hidden_size,
+                        layer.num_experts,
+                        layer.moe_config.num_max_dispatch_tokens_per_rank,
+                        layer.fd_config.parallel_config.splitwise_role,
+                        layer.ep_size,
+                        layer.ep_rank,
+                        layer.fd_config.model_config.redundant_experts_num,
+                    )
 
     def process_loaded_weights(self, layer, weights) -> None:
         """
@@ -141,7 +164,7 @@ def apply(
         Paddle Cutlass compute Fused MoE.
         """
         if layer.ep_size > 1:
-            if layer.fd_config.parallel_config.moe_phase == MoEPhase.PREFILL:
+            if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
                 return self.apply_ep_prefill(layer, x, gate_out)
             else:
                 return self.apply_ep_decode(layer, x, gate_out)
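Note: apply() now branches on the string phase (moe_phase.phase) rather than an enum member, which is what lets a mixed deployment route prefill and decode batches to the matching runner inside one process. A minimal standalone sketch of that routing, with stub runners and a scheduler that flips the phase between steps:

# Minimal standalone sketch of the phase-based routing above (stub runners, not fastdeploy).
class MoEPhase:
    def __init__(self, phase="prefill"):
        self.phase = phase  # unvalidated stub; the real class validates the setter

def apply_moe(moe_phase, prefill_runner, decode_runner, batch):
    # mirrors the apply() branch: prefill traffic goes to the prefill runner,
    # everything else to the decode runner
    if moe_phase.phase == "prefill":
        return prefill_runner(batch)
    return decode_runner(batch)

moe_phase = MoEPhase("prefill")
print(apply_moe(moe_phase, lambda b: f"prefill({b})", lambda b: f"decode({b})", "step0"))
moe_phase.phase = "decode"  # the scheduler flips the phase for decode steps in mixed EP
print(apply_moe(moe_phase, lambda b: f"prefill({b})", lambda b: f"decode({b})", "step1"))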
