@@ -43,9 +43,10 @@ def __init__(
         num_max_dispatch_tokens_per_rank: int,
         hidden: int,
         num_experts: int,
-        moe_phase: MoEPhase,
         ep_size: int,
         ep_rank: int,
+        splitwise_role: str,
+        moe_phase: MoEPhase,
         async_finish: bool = False,
     ):
         """
@@ -65,26 +66,44 @@ def __init__(
         self.hidden = hidden
         self.num_experts = num_experts
         self.num_local_experts = num_experts // ep_size
-        self.moe_phase = moe_phase
         self.async_finish = async_finish

-        self.deepep_engine = None
+        self.prefill_deepep_engine = None
+        self.decode_deepep_engine = None
+
+        self.ep_config = Config(24, 6, 256)
+        self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank

-        if moe_phase == MoEPhase.DECODER:
+        # In mixed EP mode on a single node, we dynamically switch between
+        # high throughput and low latency modes.
+        if splitwise_role == "mixed":
+            # decode engine
             logger.info("Initializing Low Latency Buffer")
-            self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
             self.get_low_latency_buffer()
-        elif moe_phase == MoEPhase.PREFILL:
-            self.deepep_engine = deep_ep.Buffer(
+            # prefill engine
+            self.prefill_deepep_engine = deep_ep.Buffer(
                 self.group,
                 int(5e8),
                 0,
                 low_latency_mode=False,
                 num_qps_per_rank=1,
             )
-            self.ep_config = Config(24, 6, 256)
+        # In disaggregated mode on multiple nodes, we either use
+        # high throughput mode or low latency mode.
         else:
-            raise ValueError(f"Unknown generation phase {moe_phase}")
+            if moe_phase.phase == "decode":
+                logger.info("Initializing Low Latency Buffer")
+                self.get_low_latency_buffer()
+            elif moe_phase.phase == "prefill":
+                self.prefill_deepep_engine = deep_ep.Buffer(
+                    self.group,
+                    int(5e8),
+                    0,
+                    low_latency_mode=False,
+                    num_qps_per_rank=1,
+                )
+            else:
+                raise ValueError(f"Unknown generation phase {moe_phase}")

     def get_low_latency_buffer(self):
         """
@@ -105,14 +124,14 @@ def get_low_latency_buffer(self):
         )
         # Allocate a buffer if not existed or not enough buffer size
         if (
-            self.deepep_engine is None
-            or self.deepep_engine.group != self.group
-            or not self.deepep_engine.low_latency_mode
-            or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
+            self.decode_deepep_engine is None
+            or self.decode_deepep_engine.group != self.group
+            or not self.decode_deepep_engine.low_latency_mode
+            or self.decode_deepep_engine.num_rdma_bytes < num_rdma_bytes
         ):
             # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
             assert self.num_experts % self.ep_size == 0
-            self.deepep_engine = deep_ep.Buffer(
+            self.decode_deepep_engine = deep_ep.Buffer(
                 self.group,
                 0,
                 num_rdma_bytes,
@@ -149,7 +168,7 @@ def low_latency_dispatch(
             handle,
             _,
             dispatch_hook,
-        ) = self.deepep_engine.low_latency_dispatch(
+        ) = self.decode_deepep_engine.low_latency_dispatch(
             hidden_states,
             topk_idx,
             expertwise_scale,
@@ -174,8 +193,22 @@ def low_latency_combine(
         Return:
             combined_hidden_states: [num_tokens, hidden]
         """
+        # TODO(@wufeisheng): Delete them when deepep in PaddlePaddle is fixed
+        (
+            src_info,
+            layout_range,
+            num_max_dispatch_tokens_per_rank,
+            num_experts,
+        ) = handle
+        handle = (
+            src_info,
+            layout_range,
+            num_max_dispatch_tokens_per_rank,
+            None,
+            num_experts,
+        )

-        combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
+        combined_hidden_states, _, combine_hook = self.decode_deepep_engine.low_latency_combine(
             hidden_states,
             topk_idx,
             topk_weights,
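# Illustrative sketch (not part of the patch): the handle re-packing above takes the
# 4-tuple produced by low_latency_dispatch and widens it to the 5-tuple layout that
# low_latency_combine currently expects, inserting None for the unused slot. The
# element names mirror the diff; the concrete values below are placeholders.
dispatch_handle = ("src_info", "layout_range", 128, 64)
src_info, layout_range, tokens_per_rank, n_experts = dispatch_handle
combine_handle = (src_info, layout_range, tokens_per_rank, None, n_experts)
assert len(combine_handle) == 5 and combine_handle[3] is None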
@@ -189,15 +222,19 @@ def clean_low_latency_buffer(self):
         """
         clean_low_latency_buffer
         """
-        self.deepep_engine.clean_low_latency_buffer(
+        self.decode_deepep_engine.clean_low_latency_buffer(
             self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
         )

     def barrier_all(self):
         """
         barrier_all
         """
-        self.deepep_engine.barrier_all()
+        if self.prefill_deepep_engine is not None:
+            self.prefill_deepep_engine.barrier_all()
+
+        if self.decode_deepep_engine is not None:
+            self.decode_deepep_engine.barrier_all()


 class EPRunner:
@@ -210,6 +247,7 @@ def __init__(
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         moe_phase: MoEPhase,
         num_max_dispatch_tokens_per_rank: int = 1,
         ep_size: int = 1,
@@ -223,9 +261,10 @@ def __init__(
             num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
             hidden=hidden,
             num_experts=num_experts + redundant_experts_num,
-            moe_phase=moe_phase,
             ep_size=ep_size,
             ep_rank=ep_rank,
+            splitwise_role=splitwise_role,
+            moe_phase=moe_phase,
         )

     def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
@@ -286,15 +325,19 @@ def __init__(
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        moe_phase: MoEPhase = MoEPhase("prefill"),
     ):
         super().__init__(
             top_k,
             hidden,
             num_experts,
-            MoEPhase.PREFILL,
+            splitwise_role,
+            moe_phase,
+            num_max_dispatch_tokens_per_rank=256,
             ep_size=ep_size,
             ep_rank=ep_rank,
             redundant_experts_num=redundant_experts_num,
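# Illustrative sketch (not part of the patch): the diff moves from enum-style
# MoEPhase.PREFILL / MoEPhase.DECODER constants to instances built as
# MoEPhase("prefill") / MoEPhase("decode") and compared via `moe_phase.phase`.
# This stand-in class shows only the minimal interface those call sites rely on;
# the real MoEPhase in the codebase may differ.
class MoEPhaseSketch:
    def __init__(self, phase: str):
        if phase not in ("prefill", "decode"):
            raise ValueError(f"Unknown generation phase {phase}")
        self.phase = phase

assert MoEPhaseSketch("prefill").phase == "prefill"
assert MoEPhaseSketch("decode").phase == "decode"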
@@ -314,7 +357,7 @@ def dispatch(
             num_tokens_per_expert,
             is_token_in_rank,
             _,
-        ) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
+        ) = self.ep_engine.prefill_deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)

         x_scale_tensor = kwargs.get("x_scale_tensor", None)
         dispatch_args = {
@@ -327,7 +370,7 @@ def dispatch(
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
         }
-        return self.ep_engine.deepep_engine.dispatch(**dispatch_args)
+        return self.ep_engine.prefill_deepep_engine.dispatch(**dispatch_args)

     def combine(
         self,
@@ -342,7 +385,7 @@ def combine(
             "async_finish": self.ep_engine.async_finish,
             "topk_weights": recv_topk_weights,
         }
-        fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)
+        fused_moe_out, _, _ = self.ep_engine.prefill_deepep_engine.combine(**combine_args)

         return fused_moe_out

@@ -357,16 +400,19 @@ def __init__(
         top_k: int,
         hidden: int,
         num_experts: int,
+        splitwise_role: str,
         num_max_dispatch_tokens_per_rank: int,
         ep_size: int = 1,
         ep_rank: int = 0,
         redundant_experts_num: int = 0,
+        moe_phase: MoEPhase = MoEPhase("decode"),
     ):
         super().__init__(
             top_k,
             hidden,
             num_experts,
-            MoEPhase.DECODER,
+            splitwise_role,
+            moe_phase,
             num_max_dispatch_tokens_per_rank,
             ep_size=ep_size,
             ep_rank=ep_rank,