Skip to content

Commit 2ed68d7

Browse files
authored
[PD Disaggregation] replace transfer with batch transfer for better performance (sgl-project#7236)
1 parent e984d50 commit 2ed68d7

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

python/sglang/srt/disaggregation/mooncake/conn.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -251,17 +251,19 @@ def send_kvcache(
251251

252252
# Worker function for processing a single layer
253253
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
    """Send all KV blocks of a single layer with one batched engine call.

    Args:
        src_ptr: Base address of this layer's buffer on the sending side.
        dst_ptr: Base address of the matching buffer on the receiving side.
        item_len: Size of one item in address units (presumably bytes --
            confirm against the caller); block indices are scaled by it.

    Returns:
        The status returned by ``self.engine.batch_transfer_sync``, or 0
        when there are no blocks to send.
    """
    src_addr_list = []
    dst_addr_list = []
    length_list = []
    for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
        # Only the first index and the group size are used, so each group is
        # sent as one span of len(prefill_index) items starting at index [0]
        # (assumes the indices inside a group are contiguous -- TODO confirm).
        src_addr_list.append(src_ptr + int(prefill_index[0]) * item_len)
        dst_addr_list.append(dst_ptr + int(decode_index[0]) * item_len)
        length_list.append(item_len * len(prefill_index))

    if not src_addr_list:
        # Nothing to send: report success without handing the engine an
        # empty batch, as a one-transfer-per-block loop would have done.
        return 0

    return self.engine.batch_transfer_sync(
        mooncake_session_id, src_addr_list, dst_addr_list, length_list
    )
265267

266268
futures = [
267269
executor.submit(

python/sglang/srt/disaggregation/mooncake/transfer_engine.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import logging
33
from dataclasses import dataclass
4-
from typing import Optional
4+
from typing import List, Optional
55

66
logger = logging.getLogger(__name__)
77

@@ -90,5 +90,29 @@ def transfer_sync(
9090

9191
return ret
9292

93+
def batch_transfer_sync(
    self,
    session_id: str,
    buffers: List[int],
    peer_buffer_addresses: List[int],
    lengths: List[int],
) -> int:
    """Synchronously transfer a batch of buffers to the specified addresses.

    The three lists are passed through unchanged to
    ``self.engine.batch_transfer_sync_write``.

    Args:
        session_id: Session identifier handed to the engine.
        buffers: Local buffer addresses, one per transfer.
        peer_buffer_addresses: Peer addresses, one per transfer.
        lengths: Transfer lengths, one per transfer.

    Returns:
        The engine's return value, or -1 if the engine call raised. A
        negative value means failure; this method never raises itself.
    """
    failure = None
    try:
        ret = self.engine.batch_transfer_sync_write(
            session_id, buffers, peer_buffer_addresses, lengths
        )
    except Exception as exc:
        # Mark the request as failed instead of propagating, but keep the
        # exception so the log below shows why the call blew up.
        failure = exc
        ret = -1

    if ret < 0:
        # exc_info=None adds nothing; an exception instance adds its traceback.
        logger.debug(
            "Failed to batch transfer data. Buffers: %s, Session: %s, Peer addresses: %s",
            buffers,
            session_id,
            peer_buffer_addresses,
            exc_info=failure,
        )
    return ret
116+
93117
def get_session_id(self):
    """Return the ``session_id`` attribute stored on this instance."""
    return self.session_id

0 commit comments

Comments
 (0)