Skip to content

Commit 384f8ab

Browse files
ShangmingCai, root, Ying1123, ssssnow, and zhjc1124
authored
[PD] Support PD disaggregation with Prefill PP (sgl-project#8846)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: root <huzhiyuan@xiaohongshu.com>
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Francis <38564764+ssssnow@users.noreply.github.com>
Co-authored-by: zitto <zhjc1124@gmail.com>
1 parent 6a9d6ca commit 384f8ab

File tree

11 files changed

+632
-82
lines changed

11 files changed

+632
-82
lines changed

python/sglang/srt/disaggregation/base/conn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class KVArgs:
3030
# for pp prefill
3131
prefill_pp_size: int
3232
pp_rank: int
33+
prefill_start_layer: int
3334
# for system dp
3435
system_dp_rank: int
3536

python/sglang/srt/disaggregation/mooncake/conn.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
3636
from sglang.srt.disaggregation.utils import DisaggregationMode
37+
from sglang.srt.distributed import get_pp_group
3738
from sglang.srt.layers.dp_attention import (
3839
get_attention_dp_rank,
3940
get_attention_dp_size,
@@ -180,6 +181,7 @@ def __init__(
180181
self.session_failures = defaultdict(int)
181182
self.failed_sessions = set()
182183
self.session_lock = threading.Lock()
184+
self.pp_group = get_pp_group()
183185
# Determine the number of threads to use for kv sender
184186
cpu_count = os.cpu_count()
185187
transfer_thread_pool_size = get_int_env_var(
@@ -313,11 +315,11 @@ def send_kvcache(
313315
layers_params = None
314316

315317
# pp is not supported on the decode side yet
318+
start_layer = self.kv_args.prefill_start_layer
319+
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
316320
if self.is_mla_backend:
317321
src_kv_ptrs = self.kv_args.kv_data_ptrs
318322
layers_per_pp_stage = len(src_kv_ptrs)
319-
start_layer = self.pp_rank * layers_per_pp_stage
320-
end_layer = start_layer + layers_per_pp_stage
321323
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
322324
kv_item_len = self.kv_args.kv_item_lens[0]
323325
layers_params = [
@@ -330,17 +332,15 @@ def send_kvcache(
330332
]
331333
else:
332334
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
335+
dst_num_total_layers = num_kv_layers * self.pp_size
333336
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
334337
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
335338
layers_per_pp_stage = len(src_k_ptrs)
336-
start_layer = self.pp_rank * layers_per_pp_stage
337-
end_layer = start_layer + layers_per_pp_stage
338339
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
339340
dst_v_ptrs = dst_kv_ptrs[
340-
num_kv_layers + start_layer : num_kv_layers + end_layer
341+
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
341342
]
342343
kv_item_len = self.kv_args.kv_item_lens[0]
343-
344344
layers_params = [
345345
(
346346
src_k_ptrs[layer_id],
@@ -452,14 +452,15 @@ def send_kvcache_slice(
452452

453453
# pp is not supported on the decode side yet
454454
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
455+
dst_num_total_layers = num_kv_layers * self.pp_size
455456
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
456457
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
457458
layers_per_pp_stage = len(src_k_ptrs)
458459
start_layer = self.pp_rank * layers_per_pp_stage
459460
end_layer = start_layer + layers_per_pp_stage
460461
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
461462
dst_v_ptrs = dst_kv_ptrs[
462-
num_kv_layers + start_layer : num_kv_layers + end_layer
463+
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
463464
]
464465

465466
# Calculate precise byte offset and length for the sub-slice within the token
@@ -612,7 +613,7 @@ def transfer_worker(
612613
)
613614
polls = []
614615
dst_ranks_infos = []
615-
local_rank = self.kv_args.engine_rank
616+
local_rank = self.attn_tp_rank * self.pp_size + self.pp_rank
616617
for req in reqs_to_be_processed:
617618
if not req.is_dummy:
618619
# Early exit if the request has failed
@@ -695,13 +696,14 @@ def transfer_worker(
695696
break
696697

697698
if kv_chunk.is_last:
698-
# Only the last chunk we need to send the aux data
699-
ret = self.send_aux(
700-
req.mooncake_session_id,
701-
kv_chunk.prefill_aux_index,
702-
target_rank_registration_info.dst_aux_ptrs,
703-
req.dst_aux_index,
704-
)
699+
if self.pp_group.is_last_rank:
700+
# Only the last chunk we need to send the aux data
701+
ret = self.send_aux(
702+
req.mooncake_session_id,
703+
kv_chunk.prefill_aux_index,
704+
target_rank_registration_info.dst_aux_ptrs,
705+
req.dst_aux_index,
706+
)
705707
polls.append(True if ret == 0 else False)
706708
dst_ranks_infos.append(
707709
(req.endpoint, req.dst_port, req.room)
@@ -798,10 +800,7 @@ def decode_thread():
798800
arrived_response_num = len(
799801
self.prefill_response_tracker[bootstrap_room]
800802
)
801-
if (
802-
self.is_mla_backend
803-
or arrived_response_num == expected_response_num
804-
):
803+
if arrived_response_num == expected_response_num:
805804
self.update_status(bootstrap_room, KVPoll.Success)
806805
elif status == KVPoll.Failed:
807806
self.record_failure(
@@ -1183,7 +1182,9 @@ def __init__(
11831182
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
11841183
)
11851184
self.required_dst_info_num = 1
1186-
self.required_prefill_response_num = 1
1185+
self.required_prefill_response_num = 1 * (
1186+
self.prefill_pp_size // self.kv_mgr.pp_size
1187+
)
11871188
self.target_tp_ranks = [self.target_tp_rank]
11881189
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
11891190
if not self.kv_mgr.is_mla_backend:
@@ -1196,7 +1197,9 @@ def __init__(
11961197
self.required_dst_info_num = (
11971198
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
11981199
)
1199-
self.required_prefill_response_num = 1
1200+
self.required_prefill_response_num = 1 * (
1201+
self.prefill_pp_size // self.kv_mgr.pp_size
1202+
)
12001203
self.target_tp_ranks = [self.target_tp_rank]
12011204
else:
12021205
if not self.kv_mgr.is_mla_backend:
@@ -1219,9 +1222,14 @@ def __init__(
12191222
# or the KVPoll will never be set correctly
12201223
self.target_tp_rank = self.target_tp_ranks[0]
12211224
self.required_dst_info_num = 1
1222-
self.required_prefill_response_num = (
1223-
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
1224-
)
1225+
if self.kv_mgr.is_mla_backend:
1226+
self.required_prefill_response_num = (
1227+
self.prefill_pp_size // self.kv_mgr.pp_size
1228+
)
1229+
else:
1230+
self.required_prefill_response_num = (
1231+
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
1232+
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
12251233

12261234
if self.data_parallel_rank is not None:
12271235
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
@@ -1530,7 +1538,7 @@ async def _handle_route_put(self, request: web.Request):
15301538
"rank_port": rank_port,
15311539
}
15321540
logger.debug(
1533-
f"Register prefill bootstrap: DP {dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
1541+
f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
15341542
)
15351543

15361544
return web.Response(text="OK", status=200)

0 commit comments

Comments (0)