Skip to content

Commit 5f52783

Browse files
authored
[PD] NIXL: Register kv args in advance and cleanup finished requests (sgl-project#6717)
1 parent 9f1787f commit 5f52783

File tree

1 file changed

+94
-46
lines changed
  • python/sglang/srt/disaggregation/nixl

1 file changed

+94
-46
lines changed

python/sglang/srt/disaggregation/nixl/conn.py

Lines changed: 94 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,19 @@
3131

3232
logger = logging.getLogger(__name__)
3333

34-
NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
35-
3634
GUARD = "NixlMsgGuard".encode("ascii")
3735

3836

3937
@dataclasses.dataclass
4038
class TransferInfo:
39+
"""Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""
40+
4141
room: int
4242
endpoint: str
4343
dst_port: int
44-
agent_metadata: bytes
4544
agent_name: str
46-
dst_kv_ptrs: list[int]
4745
dst_kv_indices: npt.NDArray[np.int32]
48-
dst_aux_ptrs: list[int]
4946
dst_aux_index: int
50-
dst_gpu_id: int
5147
required_dst_info_num: int
5248

5349
def is_dummy(self):
@@ -59,14 +55,37 @@ def from_zmq(cls, msg: List[bytes]):
5955
room=int(msg[0].decode("ascii")),
6056
endpoint=msg[1].decode("ascii"),
6157
dst_port=int(msg[2].decode("ascii")),
62-
agent_metadata=msg[3],
63-
agent_name=msg[4].decode("ascii"),
58+
agent_name=msg[3].decode("ascii"),
59+
dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),
60+
dst_aux_index=int(msg[5].decode("ascii")),
61+
required_dst_info_num=int(msg[6].decode("ascii")),
62+
)
63+
64+
65+
@dataclasses.dataclass
66+
class KVArgsRegisterInfo:
67+
"""Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""
68+
69+
room: str
70+
endpoint: str
71+
dst_port: int
72+
agent_name: str
73+
agent_metadata: bytes
74+
dst_kv_ptrs: list[int]
75+
dst_aux_ptrs: list[int]
76+
gpu_id: int
77+
78+
@classmethod
79+
def from_zmq(cls, msg: List[bytes]):
80+
return cls(
81+
room=str(msg[0].decode("ascii")),
82+
endpoint=msg[1].decode("ascii"),
83+
dst_port=int(msg[2].decode("ascii")),
84+
agent_name=msg[3].decode("ascii"),
85+
agent_metadata=msg[4],
6486
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
65-
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
66-
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
67-
dst_aux_index=int(msg[8].decode("ascii")),
68-
dst_gpu_id=int(msg[9].decode("ascii")),
69-
required_dst_info_num=int(msg[10].decode("ascii")),
87+
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
88+
gpu_id=int(msg[7].decode("ascii")),
7089
)
7190

7291

@@ -109,9 +128,9 @@ def __init__(
109128
self.register_buffer_to_engine()
110129

111130
if self.disaggregation_mode == DisaggregationMode.PREFILL:
112-
self.request_status = {}
113-
self.transfer_infos: Dict[int, TransferInfo] = {}
114-
self.peer_names: Dict[str, str] = {}
131+
self.request_status: Dict[int, KVPoll] = {}
132+
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
133+
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
115134
self._start_bootstrap_thread()
116135
elif self.disaggregation_mode == DisaggregationMode.DECODE:
117136
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
@@ -154,10 +173,13 @@ def register_buffer_to_engine(self):
154173
if not self.aux_descs:
155174
raise Exception("NIXL memory registration failed for aux tensors")
156175

157-
def _add_remote(self, agent_name: str, agent_metadata: bytes):
158-
if agent_name not in self.peer_names:
159-
self.peer_names[agent_name] = self.agent.add_remote_agent(agent_metadata)
160-
return self.peer_names[agent_name]
176+
def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
177+
agent_name = decode_kv_args.agent_name
178+
if agent_name in self.decode_kv_args_table:
179+
logger.info(f"Peer {agent_name} was already registered, ignoring.")
180+
return
181+
self.decode_kv_args_table[agent_name] = decode_kv_args
182+
self.agent.add_remote_agent(decode_kv_args.agent_metadata)
161183

162184
def send_kvcache(
163185
self,
@@ -262,31 +284,33 @@ def add_transfer_request(
262284
if req.is_dummy():
263285
continue
264286

265-
peer_name = self._add_remote(req.agent_name, req.agent_metadata)
266287
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
267288
assert len(chunked_dst_kv_indice) == len(kv_indices)
289+
assert req.agent_name in self.decode_kv_args_table
268290

269291
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
270292
kv_xfer_handle = self.send_kvcache(
271-
peer_name,
293+
req.agent_name,
272294
kv_indices,
273-
req.dst_kv_ptrs,
295+
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
274296
chunked_dst_kv_indice,
275-
req.dst_gpu_id,
297+
self.decode_kv_args_table[req.agent_name].gpu_id,
276298
notif,
277299
)
278300
handles.append(kv_xfer_handle)
279301
# Only the last chunk we need to send the aux data.
280302
if is_last:
281303
assert aux_index is not None
282304
aux_xfer_handle = self.send_aux(
283-
peer_name,
305+
req.agent_name,
284306
aux_index,
285-
req.dst_aux_ptrs,
307+
self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
286308
req.dst_aux_index,
287309
str(req.room) + "_aux",
288310
)
289311
handles.append(aux_xfer_handle)
312+
if is_last:
313+
del self.transfer_infos[bootstrap_room]
290314
return handles
291315

292316
def update_transfer_status(self):
@@ -328,16 +352,23 @@ def bootstrap_thread():
328352
), f"First message should be {GUARD}. Foreign traffic?"
329353
waiting_req_bytes = waiting_req_bytes[1:]
330354
room = waiting_req_bytes[0].decode("ascii")
331-
332-
required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
355+
agent_name = waiting_req_bytes[3].decode("ascii")
356+
if room == "None":
357+
# Register new peer and save KV base pointers.
358+
self._add_remote_peer(
359+
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
360+
)
361+
logger.debug(f"Register KVArgs from {agent_name} successfully")
362+
continue
333363
room = int(room)
334-
agent_name = waiting_req_bytes[4].decode("ascii")
335364
if room not in self.transfer_infos:
336365
self.transfer_infos[room] = {}
337366
self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
338367
waiting_req_bytes
339368
)
340-
369+
required_dst_info_num = self.transfer_infos[room][
370+
agent_name
371+
].required_dst_info_num
341372
logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
342373
if len(self.transfer_infos[room]) == required_dst_info_num:
343374
logger.debug(f"{room=} is bootstrapped")
@@ -391,6 +422,7 @@ def send(
391422
self.chunk_id += 1
392423
if is_last:
393424
self.has_sent = True
425+
del self.kv_mgr.request_status[self.bootstrap_room]
394426

395427
def poll(self) -> KVPoll:
396428
if not self.has_sent:
@@ -415,6 +447,7 @@ def __init__(
415447
data_parallel_rank: Optional[int] = None,
416448
):
417449
self.started_transfer = False
450+
self.conclude_state = None
418451
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
419452

420453
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
@@ -426,17 +459,8 @@ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = Non
426459
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
427460
)
428461
is_dummy = bootstrap_info["is_dummy"]
429-
430-
# TODO: send_kv_args earlier
431-
packed_kv_data_ptrs = b"".join(
432-
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
433-
)
434-
packed_aux_data_ptrs = b"".join(
435-
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
436-
)
437-
438462
logger.debug(
439-
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
463+
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
440464
)
441465
sock, lock = self._connect("tcp://" + self.prefill_server_url)
442466
with lock:
@@ -446,31 +470,55 @@ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = Non
446470
str(self.bootstrap_room).encode("ascii"),
447471
get_local_ip_by_remote().encode("ascii"),
448472
str(self.kv_mgr.rank_port).encode("ascii"),
449-
self.kv_mgr.agent.get_agent_metadata(),
450473
self.kv_mgr.agent.name.encode("ascii"),
451-
packed_kv_data_ptrs,
452474
kv_indices.tobytes() if not is_dummy else b"",
453-
packed_aux_data_ptrs,
454475
str(aux_index).encode("ascii"),
455-
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
456476
str(self.required_dst_info_num).encode("ascii"),
457477
]
458478
)
459479

460480
self.started_transfer = True
461481

462482
def poll(self) -> KVPoll:
483+
if self.conclude_state is not None:
484+
return self.conclude_state
463485
if not self.started_transfer:
464486
return KVPoll.WaitingForInput # type: ignore
465487

466488
self.kv_mgr.update_transfer_status()
467-
468489
if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
490+
self.conclude_state = KVPoll.Success
491+
del self.kv_mgr.transfer_statuses[self.bootstrap_room]
469492
return KVPoll.Success # type: ignore
470493
return KVPoll.WaitingForInput # type: ignore
471494

472495
def _register_kv_args(self):
473-
pass
496+
for bootstrap_info in self.bootstrap_infos:
497+
self.prefill_server_url = (
498+
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
499+
)
500+
packed_kv_data_ptrs = b"".join(
501+
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
502+
)
503+
packed_aux_data_ptrs = b"".join(
504+
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
505+
)
506+
507+
sock, lock = self._connect("tcp://" + self.prefill_server_url)
508+
with lock:
509+
sock.send_multipart(
510+
[
511+
GUARD,
512+
"None".encode("ascii"),
513+
get_local_ip_by_remote().encode("ascii"),
514+
str(self.kv_mgr.rank_port).encode("ascii"),
515+
self.kv_mgr.agent.name.encode("ascii"),
516+
self.kv_mgr.agent.get_agent_metadata(),
517+
packed_kv_data_ptrs,
518+
packed_aux_data_ptrs,
519+
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
520+
]
521+
)
474522

475523
def failure_exception(self):
476524
raise Exception("Fake KVReceiver Exception")

0 commit comments

Comments
 (0)