34
34
)
35
35
from sglang .srt .disaggregation .mooncake .transfer_engine import MooncakeTransferEngine
36
36
from sglang .srt .disaggregation .utils import DisaggregationMode
37
+ from sglang .srt .distributed import get_pp_group
37
38
from sglang .srt .layers .dp_attention import (
38
39
get_attention_dp_rank ,
39
40
get_attention_dp_size ,
@@ -180,6 +181,7 @@ def __init__(
180
181
self .session_failures = defaultdict (int )
181
182
self .failed_sessions = set ()
182
183
self .session_lock = threading .Lock ()
184
+ self .pp_group = get_pp_group ()
183
185
# Determine the number of threads to use for kv sender
184
186
cpu_count = os .cpu_count ()
185
187
transfer_thread_pool_size = get_int_env_var (
@@ -313,11 +315,11 @@ def send_kvcache(
313
315
layers_params = None
314
316
315
317
# pp is not supported on the decode side yet
318
+ start_layer = self .kv_args .prefill_start_layer
319
+ end_layer = start_layer + len (self .kv_args .kv_data_ptrs )
316
320
if self .is_mla_backend :
317
321
src_kv_ptrs = self .kv_args .kv_data_ptrs
318
322
layers_per_pp_stage = len (src_kv_ptrs )
319
- start_layer = self .pp_rank * layers_per_pp_stage
320
- end_layer = start_layer + layers_per_pp_stage
321
323
dst_kv_ptrs = dst_kv_ptrs [start_layer :end_layer ]
322
324
kv_item_len = self .kv_args .kv_item_lens [0 ]
323
325
layers_params = [
@@ -330,17 +332,15 @@ def send_kvcache(
330
332
]
331
333
else :
332
334
num_kv_layers = len (self .kv_args .kv_data_ptrs ) // 2
335
+ dst_num_total_layers = num_kv_layers * self .pp_size
333
336
src_k_ptrs = self .kv_args .kv_data_ptrs [:num_kv_layers ]
334
337
src_v_ptrs = self .kv_args .kv_data_ptrs [num_kv_layers :]
335
338
layers_per_pp_stage = len (src_k_ptrs )
336
- start_layer = self .pp_rank * layers_per_pp_stage
337
- end_layer = start_layer + layers_per_pp_stage
338
339
dst_k_ptrs = dst_kv_ptrs [start_layer :end_layer ]
339
340
dst_v_ptrs = dst_kv_ptrs [
340
- num_kv_layers + start_layer : num_kv_layers + end_layer
341
+ dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
341
342
]
342
343
kv_item_len = self .kv_args .kv_item_lens [0 ]
343
-
344
344
layers_params = [
345
345
(
346
346
src_k_ptrs [layer_id ],
@@ -452,14 +452,15 @@ def send_kvcache_slice(
452
452
453
453
# pp is not supported on the decode side yet
454
454
num_kv_layers = len (self .kv_args .kv_data_ptrs ) // 2
455
+ dst_num_total_layers = num_kv_layers * self .pp_size
455
456
src_k_ptrs = self .kv_args .kv_data_ptrs [:num_kv_layers ]
456
457
src_v_ptrs = self .kv_args .kv_data_ptrs [num_kv_layers :]
457
458
layers_per_pp_stage = len (src_k_ptrs )
458
459
start_layer = self .pp_rank * layers_per_pp_stage
459
460
end_layer = start_layer + layers_per_pp_stage
460
461
dst_k_ptrs = dst_kv_ptrs [start_layer :end_layer ]
461
462
dst_v_ptrs = dst_kv_ptrs [
462
- num_kv_layers + start_layer : num_kv_layers + end_layer
463
+ dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
463
464
]
464
465
465
466
# Calculate precise byte offset and length for the sub-slice within the token
@@ -612,7 +613,7 @@ def transfer_worker(
612
613
)
613
614
polls = []
614
615
dst_ranks_infos = []
615
- local_rank = self .kv_args . engine_rank
616
+ local_rank = self .attn_tp_rank * self . pp_size + self . pp_rank
616
617
for req in reqs_to_be_processed :
617
618
if not req .is_dummy :
618
619
# Early exit if the request has failed
@@ -695,13 +696,14 @@ def transfer_worker(
695
696
break
696
697
697
698
if kv_chunk .is_last :
698
- # Only the last chunk we need to send the aux data
699
- ret = self .send_aux (
700
- req .mooncake_session_id ,
701
- kv_chunk .prefill_aux_index ,
702
- target_rank_registration_info .dst_aux_ptrs ,
703
- req .dst_aux_index ,
704
- )
699
+ if self .pp_group .is_last_rank :
700
+ # Only the last chunk we need to send the aux data
701
+ ret = self .send_aux (
702
+ req .mooncake_session_id ,
703
+ kv_chunk .prefill_aux_index ,
704
+ target_rank_registration_info .dst_aux_ptrs ,
705
+ req .dst_aux_index ,
706
+ )
705
707
polls .append (True if ret == 0 else False )
706
708
dst_ranks_infos .append (
707
709
(req .endpoint , req .dst_port , req .room )
@@ -798,10 +800,7 @@ def decode_thread():
798
800
arrived_response_num = len (
799
801
self .prefill_response_tracker [bootstrap_room ]
800
802
)
801
- if (
802
- self .is_mla_backend
803
- or arrived_response_num == expected_response_num
804
- ):
803
+ if arrived_response_num == expected_response_num :
805
804
self .update_status (bootstrap_room , KVPoll .Success )
806
805
elif status == KVPoll .Failed :
807
806
self .record_failure (
@@ -1183,7 +1182,9 @@ def __init__(
1183
1182
self .kv_mgr .kv_args .engine_rank % self .kv_mgr .attn_tp_size
1184
1183
)
1185
1184
self .required_dst_info_num = 1
1186
- self .required_prefill_response_num = 1
1185
+ self .required_prefill_response_num = 1 * (
1186
+ self .prefill_pp_size // self .kv_mgr .pp_size
1187
+ )
1187
1188
self .target_tp_ranks = [self .target_tp_rank ]
1188
1189
elif self .kv_mgr .attn_tp_size > self .prefill_attn_tp_size :
1189
1190
if not self .kv_mgr .is_mla_backend :
@@ -1196,7 +1197,9 @@ def __init__(
1196
1197
self .required_dst_info_num = (
1197
1198
self .kv_mgr .attn_tp_size // self .prefill_attn_tp_size
1198
1199
)
1199
- self .required_prefill_response_num = 1
1200
+ self .required_prefill_response_num = 1 * (
1201
+ self .prefill_pp_size // self .kv_mgr .pp_size
1202
+ )
1200
1203
self .target_tp_ranks = [self .target_tp_rank ]
1201
1204
else :
1202
1205
if not self .kv_mgr .is_mla_backend :
@@ -1219,9 +1222,14 @@ def __init__(
1219
1222
# or the KVPoll will never be set correctly
1220
1223
self .target_tp_rank = self .target_tp_ranks [0 ]
1221
1224
self .required_dst_info_num = 1
1222
- self .required_prefill_response_num = (
1223
- self .prefill_attn_tp_size // self .kv_mgr .attn_tp_size
1224
- )
1225
+ if self .kv_mgr .is_mla_backend :
1226
+ self .required_prefill_response_num = (
1227
+ self .prefill_pp_size // self .kv_mgr .pp_size
1228
+ )
1229
+ else :
1230
+ self .required_prefill_response_num = (
1231
+ self .prefill_attn_tp_size // self .kv_mgr .attn_tp_size
1232
+ ) * (self .prefill_pp_size // self .kv_mgr .pp_size )
1225
1233
1226
1234
if self .data_parallel_rank is not None :
1227
1235
logger .debug (f"Targeting DP rank: { self .data_parallel_rank } " )
@@ -1530,7 +1538,7 @@ async def _handle_route_put(self, request: web.Request):
1530
1538
"rank_port" : rank_port ,
1531
1539
}
1532
1540
logger .debug (
1533
- f"Register prefill bootstrap: DP { dp_group } TP{ attn_tp_rank } PP{ pp_rank } with rank_ip: { rank_ip } and rank_port: { rank_port } "
1541
+ f"Register prefill bootstrap: DP{ dp_group } TP{ attn_tp_rank } PP{ pp_rank } with rank_ip: { rank_ip } and rank_port: { rank_port } "
1534
1542
)
1535
1543
1536
1544
return web .Response (text = "OK" , status = 200 )
0 commit comments