31
31
32
32
logger = logging .getLogger (__name__ )
33
33
34
- NixlEngineInfo : TypeAlias = Dict [str , Union [str , int ]]
35
-
36
34
GUARD = "NixlMsgGuard" .encode ("ascii" )
37
35
38
36
39
37
@dataclasses .dataclass
40
38
class TransferInfo :
39
+ """Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""
40
+
41
41
room : int
42
42
endpoint : str
43
43
dst_port : int
44
- agent_metadata : bytes
45
44
agent_name : str
46
- dst_kv_ptrs : list [int ]
47
45
dst_kv_indices : npt .NDArray [np .int32 ]
48
- dst_aux_ptrs : list [int ]
49
46
dst_aux_index : int
50
- dst_gpu_id : int
51
47
required_dst_info_num : int
52
48
53
49
def is_dummy (self ):
@@ -59,14 +55,37 @@ def from_zmq(cls, msg: List[bytes]):
59
55
room = int (msg [0 ].decode ("ascii" )),
60
56
endpoint = msg [1 ].decode ("ascii" ),
61
57
dst_port = int (msg [2 ].decode ("ascii" )),
62
- agent_metadata = msg [3 ],
63
- agent_name = msg [4 ].decode ("ascii" ),
58
+ agent_name = msg [3 ].decode ("ascii" ),
59
+ dst_kv_indices = np .frombuffer (msg [4 ], dtype = np .int32 ),
60
+ dst_aux_index = int (msg [5 ].decode ("ascii" )),
61
+ required_dst_info_num = int (msg [6 ].decode ("ascii" )),
62
+ )
63
+
64
+
65
@dataclasses.dataclass
class KVArgsRegisterInfo:
    """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread.

    The frames are parsed positionally by ``from_zmq`` (the leading guard
    frame must already have been stripped by the caller):

    0. room            -- ASCII text, kept as a string
    1. endpoint        -- ASCII text
    2. dst_port        -- ASCII decimal integer
    3. agent_name      -- ASCII text
    4. agent_metadata  -- opaque bytes, stored untouched
    5. dst_kv_ptrs     -- packed native unsigned 64-bit integers
    6. dst_aux_ptrs    -- packed native unsigned 64-bit integers
    7. gpu_id          -- ASCII decimal integer
    """

    room: str
    endpoint: str
    dst_port: int
    agent_name: str
    agent_metadata: bytes
    dst_kv_ptrs: list[int]
    dst_aux_ptrs: list[int]
    gpu_id: int

    @staticmethod
    def _unpack_ptrs(payload: bytes) -> list[int]:
        """Decode a buffer of native unsigned 64-bit integers into a list.

        Raises:
            struct.error: if ``len(payload)`` is not a multiple of 8.
        """
        return list(struct.unpack(f"{len(payload) // 8}Q", payload))

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Build an instance from the frames of a multipart message.

        Args:
            msg: The message frames in the order listed in the class
                docstring.

        Returns:
            A populated ``KVArgsRegisterInfo``.

        Raises:
            IndexError: if fewer than eight frames are supplied.
            UnicodeDecodeError: if a text frame is not ASCII.
            ValueError: if a numeric frame is not a decimal integer.
            struct.error: if a pointer frame is not a whole number of
                8-byte values.
        """
        return cls(
            # ``decode`` already yields ``str``; no extra conversion needed.
            room=msg[0].decode("ascii"),
            endpoint=msg[1].decode("ascii"),
            dst_port=int(msg[2].decode("ascii")),
            agent_name=msg[3].decode("ascii"),
            agent_metadata=msg[4],
            dst_kv_ptrs=cls._unpack_ptrs(msg[5]),
            dst_aux_ptrs=cls._unpack_ptrs(msg[6]),
            gpu_id=int(msg[7].decode("ascii")),
        )
71
90
72
91
@@ -109,9 +128,9 @@ def __init__(
109
128
self .register_buffer_to_engine ()
110
129
111
130
if self .disaggregation_mode == DisaggregationMode .PREFILL :
112
- self .request_status = {}
113
- self .transfer_infos : Dict [int , TransferInfo ] = {}
114
- self .peer_names : Dict [str , str ] = {}
131
+ self .request_status : Dict [ int , KVPoll ] = {}
132
+ self .transfer_infos : Dict [int , Dict [ str , TransferInfo ] ] = {}
133
+ self .decode_kv_args_table : Dict [str , KVArgsRegisterInfo ] = {}
115
134
self ._start_bootstrap_thread ()
116
135
elif self .disaggregation_mode == DisaggregationMode .DECODE :
117
136
self .transfer_statuses : Dict [int , TransferStatus ] = defaultdict (
@@ -154,10 +173,13 @@ def register_buffer_to_engine(self):
154
173
if not self .aux_descs :
155
174
raise Exception ("NIXL memory registration failed for aux tensors" )
156
175
157
- def _add_remote (self , agent_name : str , agent_metadata : bytes ):
158
- if agent_name not in self .peer_names :
159
- self .peer_names [agent_name ] = self .agent .add_remote_agent (agent_metadata )
160
- return self .peer_names [agent_name ]
176
def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
    """Register a remote peer once and remember its KV registration info.

    A peer whose ``agent_name`` is already present in
    ``self.decode_kv_args_table`` is left untouched and the call only logs.

    Args:
        decode_kv_args: Registration info for the peer; its
            ``agent_metadata`` is handed to ``self.agent.add_remote_agent``.
    """
    agent_name = decode_kv_args.agent_name
    if agent_name in self.decode_kv_args_table:
        logger.info(f"Peer {agent_name} was already registered, ignoring.")
        return
    # Register with the agent first: if this raises, the table must not
    # claim the peer is known, otherwise every retry would be ignored.
    self.agent.add_remote_agent(decode_kv_args.agent_metadata)
    self.decode_kv_args_table[agent_name] = decode_kv_args
161
183
162
184
def send_kvcache (
163
185
self ,
@@ -262,31 +284,33 @@ def add_transfer_request(
262
284
if req .is_dummy ():
263
285
continue
264
286
265
- peer_name = self ._add_remote (req .agent_name , req .agent_metadata )
266
287
chunked_dst_kv_indice = req .dst_kv_indices [index_slice ]
267
288
assert len (chunked_dst_kv_indice ) == len (kv_indices )
289
+ assert req .agent_name in self .decode_kv_args_table
268
290
269
291
notif = "_" .join ([str (req .room ), "kv" , str (chunk_id ), str (int (is_last ))])
270
292
kv_xfer_handle = self .send_kvcache (
271
- peer_name ,
293
+ req . agent_name ,
272
294
kv_indices ,
273
- req .dst_kv_ptrs ,
295
+ self . decode_kv_args_table [ req . agent_name ] .dst_kv_ptrs ,
274
296
chunked_dst_kv_indice ,
275
- req .dst_gpu_id ,
297
+ self . decode_kv_args_table [ req .agent_name ]. gpu_id ,
276
298
notif ,
277
299
)
278
300
handles .append (kv_xfer_handle )
279
301
# Only the last chunk we need to send the aux data.
280
302
if is_last :
281
303
assert aux_index is not None
282
304
aux_xfer_handle = self .send_aux (
283
- peer_name ,
305
+ req . agent_name ,
284
306
aux_index ,
285
- req .dst_aux_ptrs ,
307
+ self . decode_kv_args_table [ req . agent_name ] .dst_aux_ptrs ,
286
308
req .dst_aux_index ,
287
309
str (req .room ) + "_aux" ,
288
310
)
289
311
handles .append (aux_xfer_handle )
312
+ if is_last :
313
+ del self .transfer_infos [bootstrap_room ]
290
314
return handles
291
315
292
316
def update_transfer_status (self ):
@@ -328,16 +352,23 @@ def bootstrap_thread():
328
352
), f"First message should be { GUARD } . Foreign traffic?"
329
353
waiting_req_bytes = waiting_req_bytes [1 :]
330
354
room = waiting_req_bytes [0 ].decode ("ascii" )
331
-
332
- required_dst_info_num = int (waiting_req_bytes [10 ].decode ("ascii" ))
355
+ agent_name = waiting_req_bytes [3 ].decode ("ascii" )
356
+ if room == "None" :
357
+ # Register new peer and save KV base pointers.
358
+ self ._add_remote_peer (
359
+ KVArgsRegisterInfo .from_zmq (waiting_req_bytes )
360
+ )
361
+ logger .debug (f"Register KVArgs from { agent_name } successfully" )
362
+ continue
333
363
room = int (room )
334
- agent_name = waiting_req_bytes [4 ].decode ("ascii" )
335
364
if room not in self .transfer_infos :
336
365
self .transfer_infos [room ] = {}
337
366
self .transfer_infos [room ][agent_name ] = TransferInfo .from_zmq (
338
367
waiting_req_bytes
339
368
)
340
-
369
+ required_dst_info_num = self .transfer_infos [room ][
370
+ agent_name
371
+ ].required_dst_info_num
341
372
logger .debug (f"got info { room = } { agent_name = } { required_dst_info_num = } " )
342
373
if len (self .transfer_infos [room ]) == required_dst_info_num :
343
374
logger .debug (f"{ room = } is bootstrapped" )
@@ -391,6 +422,7 @@ def send(
391
422
self .chunk_id += 1
392
423
if is_last :
393
424
self .has_sent = True
425
+ del self .kv_mgr .request_status [self .bootstrap_room ]
394
426
395
427
def poll (self ) -> KVPoll :
396
428
if not self .has_sent :
@@ -415,6 +447,7 @@ def __init__(
415
447
data_parallel_rank : Optional [int ] = None ,
416
448
):
417
449
self .started_transfer = False
450
+ self .conclude_state = None
418
451
super ().__init__ (mgr , bootstrap_addr , bootstrap_room , data_parallel_rank )
419
452
420
453
def init (self , kv_indices : npt .NDArray [np .int32 ], aux_index : Optional [int ] = None ):
@@ -426,17 +459,8 @@ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = Non
426
459
f"Fetched bootstrap info: { bootstrap_info } for engine rank: { self .kv_mgr .kv_args .engine_rank } "
427
460
)
428
461
is_dummy = bootstrap_info ["is_dummy" ]
429
-
430
- # TODO: send_kv_args earlier
431
- packed_kv_data_ptrs = b"" .join (
432
- struct .pack ("Q" , ptr ) for ptr in self .kv_mgr .kv_args .kv_data_ptrs
433
- )
434
- packed_aux_data_ptrs = b"" .join (
435
- struct .pack ("Q" , ptr ) for ptr in self .kv_mgr .kv_args .aux_data_ptrs
436
- )
437
-
438
462
logger .debug (
439
- f"Sending to { self .prefill_server_url } with bootstrap room { self .bootstrap_room } "
463
+ f"Sending to { self .prefill_server_url } with bootstrap room { self .bootstrap_room } { is_dummy = } "
440
464
)
441
465
sock , lock = self ._connect ("tcp://" + self .prefill_server_url )
442
466
with lock :
@@ -446,31 +470,55 @@ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = Non
446
470
str (self .bootstrap_room ).encode ("ascii" ),
447
471
get_local_ip_by_remote ().encode ("ascii" ),
448
472
str (self .kv_mgr .rank_port ).encode ("ascii" ),
449
- self .kv_mgr .agent .get_agent_metadata (),
450
473
self .kv_mgr .agent .name .encode ("ascii" ),
451
- packed_kv_data_ptrs ,
452
474
kv_indices .tobytes () if not is_dummy else b"" ,
453
- packed_aux_data_ptrs ,
454
475
str (aux_index ).encode ("ascii" ),
455
- str (self .kv_mgr .kv_args .gpu_id ).encode ("ascii" ),
456
476
str (self .required_dst_info_num ).encode ("ascii" ),
457
477
]
458
478
)
459
479
460
480
self .started_transfer = True
461
481
462
482
def poll(self) -> KVPoll:
    """Report the current state of this receiver's transfer.

    Once a transfer has been observed as done, the result is cached in
    ``self.conclude_state`` and returned on every later call without
    consulting the manager again.

    Returns:
        ``KVPoll.WaitingForInput`` while the transfer has not started or
        has not finished, ``KVPoll.Success`` once the manager reports it
        done.
    """
    if self.conclude_state is not None:
        return self.conclude_state
    if not self.started_transfer:
        return KVPoll.WaitingForInput  # type: ignore

    self.kv_mgr.update_transfer_status()
    if not self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
        return KVPoll.WaitingForInput  # type: ignore

    self.conclude_state = KVPoll.Success
    # Drop the manager's bookkeeping for this room. ``pop`` with a default
    # keeps a finished transfer from failing if the entry is already gone.
    self.kv_mgr.transfer_statuses.pop(self.bootstrap_room, None)
    return self.conclude_state  # type: ignore
471
494
472
495
def _register_kv_args(self):
    """Send this side's KV registration info to every bootstrap target.

    For each entry of ``self.bootstrap_infos`` a multipart message is sent
    to ``rank_ip:rank_port``. The frames are: ``GUARD``, the literal room
    ``"None"``, the local IP, the manager's rank port, the agent name, the
    agent metadata, the packed KV data pointers, the packed aux data
    pointers and the GPU id.

    Side effect: ``self.prefill_server_url`` is overwritten for each target
    and is left pointing at the last one.
    """
    kv_args = self.kv_mgr.kv_args
    # The pointer tables do not depend on the target, so pack them once.
    # One "<n>Q" pack yields the same bytes as joining n single "Q" packs.
    packed_kv_data_ptrs = struct.pack(
        f"{len(kv_args.kv_data_ptrs)}Q", *kv_args.kv_data_ptrs
    )
    packed_aux_data_ptrs = struct.pack(
        f"{len(kv_args.aux_data_ptrs)}Q", *kv_args.aux_data_ptrs
    )

    for bootstrap_info in self.bootstrap_infos:
        self.prefill_server_url = (
            f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
        )
        sock, lock = self._connect("tcp://" + self.prefill_server_url)
        with lock:
            sock.send_multipart(
                [
                    GUARD,
                    "None".encode("ascii"),
                    get_local_ip_by_remote().encode("ascii"),
                    str(self.kv_mgr.rank_port).encode("ascii"),
                    self.kv_mgr.agent.name.encode("ascii"),
                    # Still fetched per target, as in the original.
                    self.kv_mgr.agent.get_agent_metadata(),
                    packed_kv_data_ptrs,
                    packed_aux_data_ptrs,
                    str(kv_args.gpu_id).encode("ascii"),
                ]
            )
474
522
475
523
def failure_exception(self):
    """Raise a placeholder exception.

    Raises:
        Exception: always, with the message "Fake KVReceiver Exception".
    """
    raise Exception("Fake KVReceiver Exception")
0 commit comments