@@ -268,98 +268,97 @@ def write_thread_func_direct(self):
268
268
"""
269
269
Directly write through KV caches to host memory without buffering.
270
270
"""
271
- with torch .cuda .stream (self .write_stream ):
272
- while not self .stop_event .is_set ():
273
- try :
274
- operation = self .write_queue .get (block = True , timeout = 1 )
275
- self .mem_pool_host .write_page_all_layers (
276
- operation .host_indices ,
277
- operation .device_indices ,
278
- self .mem_pool_device ,
279
- )
280
- self .write_stream .synchronize ()
281
- self .mem_pool_host .complete_io (operation .host_indices )
282
- for node_id in operation .node_ids :
283
- if node_id != 0 :
284
- self .ack_write_queue .put (node_id )
285
- except Empty :
286
- continue
287
- except Exception as e :
288
- logger .error (e )
271
+ torch .cuda .set_stream (self .write_stream )
272
+ while not self .stop_event .is_set ():
273
+ try :
274
+ operation = self .write_queue .get (block = True , timeout = 1 )
275
+ self .mem_pool_host .write_page_all_layers (
276
+ operation .host_indices ,
277
+ operation .device_indices ,
278
+ self .mem_pool_device ,
279
+ )
280
+ self .write_stream .synchronize ()
281
+ self .mem_pool_host .complete_io (operation .host_indices )
282
+ for node_id in operation .node_ids :
283
+ if node_id != 0 :
284
+ self .ack_write_queue .put (node_id )
285
+ except Empty :
286
+ continue
287
+ except Exception as e :
288
+ logger .error (e )
289
289
290
290
def load_thread_func_direct (self ):
291
291
"""
292
292
Directly load KV caches from host memory to device memory without buffering.
293
293
"""
294
- with torch .cuda .stream (self .load_stream ):
295
- while not self .stop_event .is_set ():
296
- try :
297
- operation = self .load_queue .get (block = True , timeout = 1 )
298
- # time.sleep(18e-6 * len(operation.host_indices))
299
- operation .data = self .mem_pool_host .get_flat_data (
300
- operation .host_indices
301
- )
302
- self .mem_pool_device .transfer (
303
- operation .device_indices , operation .data
304
- )
305
- self .mem_pool_host .complete_io (operation .host_indices )
306
- for node_id in operation .node_ids :
307
- if node_id != 0 :
308
- self .ack_load_queue .put (node_id )
309
- except Empty :
310
- continue
311
- except Exception as e :
312
- logger .error (e )
294
+ torch .cuda .set_stream (self .load_stream )
295
+ while not self .stop_event .is_set ():
296
+ try :
297
+ operation = self .load_queue .get (block = True , timeout = 1 )
298
+ # time.sleep(18e-6 * len(operation.host_indices))
299
+ operation .data = self .mem_pool_host .get_flat_data (
300
+ operation .host_indices
301
+ )
302
+ self .mem_pool_device .transfer (operation .device_indices , operation .data )
303
+ self .mem_pool_host .complete_io (operation .host_indices )
304
+ for node_id in operation .node_ids :
305
+ if node_id != 0 :
306
+ self .ack_load_queue .put (node_id )
307
+ except Empty :
308
+ continue
309
+ except Exception as e :
310
+ logger .error (e )
313
311
314
312
def load_thread_func_layer_by_layer (self ):
315
313
"""
316
314
Load KV caches from host memory to device memory layer by layer.
317
315
"""
318
- with torch .cuda .stream (self .load_stream ):
319
- while not self .stop_event .is_set ():
320
- self .load_cache_event .wait (timeout = 1 )
321
- if not self .load_cache_event .is_set ():
322
- continue
323
- self .load_cache_event .clear ()
316
+ torch .cuda .set_stream (self .load_stream )
317
+ while not self .stop_event .is_set ():
318
+ self .load_cache_event .wait (timeout = 1 )
319
+ if not self .load_cache_event .is_set ():
320
+ continue
321
+ self .load_cache_event .clear ()
324
322
325
- batch_operation = None
326
- while self .load_queue .qsize () > 0 :
327
- op = self .load_queue .get (block = True )
328
- if batch_operation is None :
329
- batch_operation = op
330
- else :
331
- batch_operation .merge (op )
323
+ batch_operation = None
324
+ while self .load_queue .qsize () > 0 :
325
+ op = self .load_queue .get (block = True )
332
326
if batch_operation is None :
333
- continue
327
+ batch_operation = op
328
+ else :
329
+ batch_operation .merge (op )
330
+ if batch_operation is None :
331
+ continue
334
332
335
- self .layer_done_counter .reset ()
336
- for i in range (self .mem_pool_host .layer_num ):
337
- if self .page_size == 1 :
338
- flat_data = self .mem_pool_host .get_flat_data_by_layer (
339
- batch_operation .host_indices , i
340
- )
341
- self .mem_pool_device .transfer_per_layer (
342
- batch_operation .device_indices , flat_data , i
343
- )
344
- else :
345
- self .mem_pool_host .load_page_per_layer (
346
- batch_operation .host_indices ,
347
- batch_operation .device_indices ,
348
- self .mem_pool_device ,
349
- i ,
350
- )
351
- self .load_stream .synchronize ()
352
- self .layer_done_counter .increment ()
353
-
354
- self .mem_pool_host .complete_io (batch_operation .host_indices )
355
- for node_id in batch_operation .node_ids :
356
- if node_id != 0 :
357
- self .ack_load_queue .put (node_id )
333
+ self .layer_done_counter .reset ()
334
+ for i in range (self .mem_pool_host .layer_num ):
335
+ if self .page_size == 1 :
336
+ flat_data = self .mem_pool_host .get_flat_data_by_layer (
337
+ batch_operation .host_indices , i
338
+ )
339
+ self .mem_pool_device .transfer_per_layer (
340
+ batch_operation .device_indices , flat_data , i
341
+ )
342
+ else :
343
+ self .mem_pool_host .load_page_per_layer (
344
+ batch_operation .host_indices ,
345
+ batch_operation .device_indices ,
346
+ self .mem_pool_device ,
347
+ i ,
348
+ )
349
+ self .load_stream .synchronize ()
350
+ self .layer_done_counter .increment ()
351
+
352
+ self .mem_pool_host .complete_io (batch_operation .host_indices )
353
+ for node_id in batch_operation .node_ids :
354
+ if node_id != 0 :
355
+ self .ack_load_queue .put (node_id )
358
356
359
357
def write_aux_func (self , no_wait = False ):
360
358
"""
361
359
Auxiliary function to prepare the buffer for write operations.
362
360
"""
361
+ torch .cuda .set_stream (self .write_stream )
363
362
364
363
def _to_op (op_ ):
365
364
assert op_ .device_indices .is_cuda , "Device indices should be on GPU"
@@ -370,44 +369,42 @@ def _to_op(op_):
370
369
return op_
371
370
372
371
buffer = None
373
- with torch .cuda .stream (self .write_stream ):
374
- while not self .stop_event .is_set ():
375
- try :
376
- operation = self .write_queue .get (block = True , timeout = 1 )
377
- factor = (
378
- len (operation .device_indices )
379
- // self .write_buffer .max_buffer_size
380
- )
372
+ while not self .stop_event .is_set ():
373
+ try :
374
+ operation = self .write_queue .get (block = True , timeout = 1 )
375
+ factor = (
376
+ len (operation .device_indices ) // self .write_buffer .max_buffer_size
377
+ )
381
378
382
- if factor >= 1 :
383
- if buffer is not None :
384
- _to_op (buffer )
385
- buffer = None
386
-
387
- if factor < 2 :
388
- _to_op (operation )
389
- else :
390
- split_ops = operation .split (factor )
391
- for op_ in split_ops :
392
- _to_op (op_ )
393
- continue
394
-
395
- if buffer is None :
396
- buffer = operation
397
- else :
398
- buffer .merge (operation )
399
- if (
400
- no_wait
401
- or len (buffer .host_indices ) >= self .write_buffer .max_buffer_size
402
- or self .write_queue .empty ()
403
- or self .write_buffer .empty ()
404
- ):
379
+ if factor >= 1 :
380
+ if buffer is not None :
405
381
_to_op (buffer )
406
382
buffer = None
407
- except Empty :
383
+
384
+ if factor < 2 :
385
+ _to_op (operation )
386
+ else :
387
+ split_ops = operation .split (factor )
388
+ for op_ in split_ops :
389
+ _to_op (op_ )
408
390
continue
409
- except Exception as e :
410
- logger .error (e )
391
+
392
+ if buffer is None :
393
+ buffer = operation
394
+ else :
395
+ buffer .merge (operation )
396
+ if (
397
+ no_wait
398
+ or len (buffer .host_indices ) >= self .write_buffer .max_buffer_size
399
+ or self .write_queue .empty ()
400
+ or self .write_buffer .empty ()
401
+ ):
402
+ _to_op (buffer )
403
+ buffer = None
404
+ except Empty :
405
+ continue
406
+ except Exception as e :
407
+ logger .error (e )
411
408
412
409
def load_aux_func (self ):
413
410
"""
@@ -484,19 +481,18 @@ def write_thread_func_buffer(self):
484
481
aux_thread .join ()
485
482
486
483
def load_thread_func_buffer (self ):
484
+ torch .cuda .set_stream (self .load_stream )
487
485
aux_thread = threading .Thread (target = self .load_aux_func , daemon = True )
488
486
aux_thread .start ()
489
-
490
- with torch .cuda .stream (self .load_stream ):
491
- while not self .stop_event .is_set ():
492
- operation = self .load_buffer .get ()
493
- if operation is None :
494
- continue
495
- self .mem_pool_device .transfer (operation .device_indices , operation .data )
496
- self .mem_pool_host .complete_io (operation .host_indices )
497
- for node_id in operation .node_ids :
498
- if node_id != 0 :
499
- self .ack_load_queue .put (node_id )
487
+ while not self .stop_event .is_set ():
488
+ operation = self .load_buffer .get ()
489
+ if operation is None :
490
+ continue
491
+ self .mem_pool_device .transfer (operation .device_indices , operation .data )
492
+ self .mem_pool_host .complete_io (operation .host_indices )
493
+ for node_id in operation .node_ids :
494
+ if node_id != 0 :
495
+ self .ack_load_queue .put (node_id )
500
496
aux_thread .join ()
501
497
502
498
def evict_device (
0 commit comments