Skip to content

Commit cf5f385

Browse files
huangtingwei9988 authored and xiezhq-hermann committed
fix bug that gpu0 occupies more memory when hicache is turned on (sgl-project#5778)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
1 parent 1aa33ff commit cf5f385

File tree

1 file changed

+115
-119
lines changed

1 file changed

+115
-119
lines changed

python/sglang/srt/managers/cache_controller.py

Lines changed: 115 additions & 119 deletions
Original file line number | Diff line number | Diff line change
@@ -268,98 +268,97 @@ def write_thread_func_direct(self):
268268
"""
269269
Directly write through KV caches to host memory without buffering.
270270
"""
271-
with torch.cuda.stream(self.write_stream):
272-
while not self.stop_event.is_set():
273-
try:
274-
operation = self.write_queue.get(block=True, timeout=1)
275-
self.mem_pool_host.write_page_all_layers(
276-
operation.host_indices,
277-
operation.device_indices,
278-
self.mem_pool_device,
279-
)
280-
self.write_stream.synchronize()
281-
self.mem_pool_host.complete_io(operation.host_indices)
282-
for node_id in operation.node_ids:
283-
if node_id != 0:
284-
self.ack_write_queue.put(node_id)
285-
except Empty:
286-
continue
287-
except Exception as e:
288-
logger.error(e)
271+
torch.cuda.set_stream(self.write_stream)
272+
while not self.stop_event.is_set():
273+
try:
274+
operation = self.write_queue.get(block=True, timeout=1)
275+
self.mem_pool_host.write_page_all_layers(
276+
operation.host_indices,
277+
operation.device_indices,
278+
self.mem_pool_device,
279+
)
280+
self.write_stream.synchronize()
281+
self.mem_pool_host.complete_io(operation.host_indices)
282+
for node_id in operation.node_ids:
283+
if node_id != 0:
284+
self.ack_write_queue.put(node_id)
285+
except Empty:
286+
continue
287+
except Exception as e:
288+
logger.error(e)
289289

290290
def load_thread_func_direct(self):
291291
"""
292292
Directly load KV caches from host memory to device memory without buffering.
293293
"""
294-
with torch.cuda.stream(self.load_stream):
295-
while not self.stop_event.is_set():
296-
try:
297-
operation = self.load_queue.get(block=True, timeout=1)
298-
# time.sleep(18e-6 * len(operation.host_indices))
299-
operation.data = self.mem_pool_host.get_flat_data(
300-
operation.host_indices
301-
)
302-
self.mem_pool_device.transfer(
303-
operation.device_indices, operation.data
304-
)
305-
self.mem_pool_host.complete_io(operation.host_indices)
306-
for node_id in operation.node_ids:
307-
if node_id != 0:
308-
self.ack_load_queue.put(node_id)
309-
except Empty:
310-
continue
311-
except Exception as e:
312-
logger.error(e)
294+
torch.cuda.set_stream(self.load_stream)
295+
while not self.stop_event.is_set():
296+
try:
297+
operation = self.load_queue.get(block=True, timeout=1)
298+
# time.sleep(18e-6 * len(operation.host_indices))
299+
operation.data = self.mem_pool_host.get_flat_data(
300+
operation.host_indices
301+
)
302+
self.mem_pool_device.transfer(operation.device_indices, operation.data)
303+
self.mem_pool_host.complete_io(operation.host_indices)
304+
for node_id in operation.node_ids:
305+
if node_id != 0:
306+
self.ack_load_queue.put(node_id)
307+
except Empty:
308+
continue
309+
except Exception as e:
310+
logger.error(e)
313311

314312
def load_thread_func_layer_by_layer(self):
315313
"""
316314
Load KV caches from host memory to device memory layer by layer.
317315
"""
318-
with torch.cuda.stream(self.load_stream):
319-
while not self.stop_event.is_set():
320-
self.load_cache_event.wait(timeout=1)
321-
if not self.load_cache_event.is_set():
322-
continue
323-
self.load_cache_event.clear()
316+
torch.cuda.set_stream(self.load_stream)
317+
while not self.stop_event.is_set():
318+
self.load_cache_event.wait(timeout=1)
319+
if not self.load_cache_event.is_set():
320+
continue
321+
self.load_cache_event.clear()
324322

325-
batch_operation = None
326-
while self.load_queue.qsize() > 0:
327-
op = self.load_queue.get(block=True)
328-
if batch_operation is None:
329-
batch_operation = op
330-
else:
331-
batch_operation.merge(op)
323+
batch_operation = None
324+
while self.load_queue.qsize() > 0:
325+
op = self.load_queue.get(block=True)
332326
if batch_operation is None:
333-
continue
327+
batch_operation = op
328+
else:
329+
batch_operation.merge(op)
330+
if batch_operation is None:
331+
continue
334332

335-
self.layer_done_counter.reset()
336-
for i in range(self.mem_pool_host.layer_num):
337-
if self.page_size == 1:
338-
flat_data = self.mem_pool_host.get_flat_data_by_layer(
339-
batch_operation.host_indices, i
340-
)
341-
self.mem_pool_device.transfer_per_layer(
342-
batch_operation.device_indices, flat_data, i
343-
)
344-
else:
345-
self.mem_pool_host.load_page_per_layer(
346-
batch_operation.host_indices,
347-
batch_operation.device_indices,
348-
self.mem_pool_device,
349-
i,
350-
)
351-
self.load_stream.synchronize()
352-
self.layer_done_counter.increment()
353-
354-
self.mem_pool_host.complete_io(batch_operation.host_indices)
355-
for node_id in batch_operation.node_ids:
356-
if node_id != 0:
357-
self.ack_load_queue.put(node_id)
333+
self.layer_done_counter.reset()
334+
for i in range(self.mem_pool_host.layer_num):
335+
if self.page_size == 1:
336+
flat_data = self.mem_pool_host.get_flat_data_by_layer(
337+
batch_operation.host_indices, i
338+
)
339+
self.mem_pool_device.transfer_per_layer(
340+
batch_operation.device_indices, flat_data, i
341+
)
342+
else:
343+
self.mem_pool_host.load_page_per_layer(
344+
batch_operation.host_indices,
345+
batch_operation.device_indices,
346+
self.mem_pool_device,
347+
i,
348+
)
349+
self.load_stream.synchronize()
350+
self.layer_done_counter.increment()
351+
352+
self.mem_pool_host.complete_io(batch_operation.host_indices)
353+
for node_id in batch_operation.node_ids:
354+
if node_id != 0:
355+
self.ack_load_queue.put(node_id)
358356

359357
def write_aux_func(self, no_wait=False):
360358
"""
361359
Auxiliary function to prepare the buffer for write operations.
362360
"""
361+
torch.cuda.set_stream(self.write_stream)
363362

364363
def _to_op(op_):
365364
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
@@ -370,44 +369,42 @@ def _to_op(op_):
370369
return op_
371370

372371
buffer = None
373-
with torch.cuda.stream(self.write_stream):
374-
while not self.stop_event.is_set():
375-
try:
376-
operation = self.write_queue.get(block=True, timeout=1)
377-
factor = (
378-
len(operation.device_indices)
379-
// self.write_buffer.max_buffer_size
380-
)
372+
while not self.stop_event.is_set():
373+
try:
374+
operation = self.write_queue.get(block=True, timeout=1)
375+
factor = (
376+
len(operation.device_indices) // self.write_buffer.max_buffer_size
377+
)
381378

382-
if factor >= 1:
383-
if buffer is not None:
384-
_to_op(buffer)
385-
buffer = None
386-
387-
if factor < 2:
388-
_to_op(operation)
389-
else:
390-
split_ops = operation.split(factor)
391-
for op_ in split_ops:
392-
_to_op(op_)
393-
continue
394-
395-
if buffer is None:
396-
buffer = operation
397-
else:
398-
buffer.merge(operation)
399-
if (
400-
no_wait
401-
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
402-
or self.write_queue.empty()
403-
or self.write_buffer.empty()
404-
):
379+
if factor >= 1:
380+
if buffer is not None:
405381
_to_op(buffer)
406382
buffer = None
407-
except Empty:
383+
384+
if factor < 2:
385+
_to_op(operation)
386+
else:
387+
split_ops = operation.split(factor)
388+
for op_ in split_ops:
389+
_to_op(op_)
408390
continue
409-
except Exception as e:
410-
logger.error(e)
391+
392+
if buffer is None:
393+
buffer = operation
394+
else:
395+
buffer.merge(operation)
396+
if (
397+
no_wait
398+
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
399+
or self.write_queue.empty()
400+
or self.write_buffer.empty()
401+
):
402+
_to_op(buffer)
403+
buffer = None
404+
except Empty:
405+
continue
406+
except Exception as e:
407+
logger.error(e)
411408

412409
def load_aux_func(self):
413410
"""
@@ -484,19 +481,18 @@ def write_thread_func_buffer(self):
484481
aux_thread.join()
485482

486483
def load_thread_func_buffer(self):
484+
torch.cuda.set_stream(self.load_stream)
487485
aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
488486
aux_thread.start()
489-
490-
with torch.cuda.stream(self.load_stream):
491-
while not self.stop_event.is_set():
492-
operation = self.load_buffer.get()
493-
if operation is None:
494-
continue
495-
self.mem_pool_device.transfer(operation.device_indices, operation.data)
496-
self.mem_pool_host.complete_io(operation.host_indices)
497-
for node_id in operation.node_ids:
498-
if node_id != 0:
499-
self.ack_load_queue.put(node_id)
487+
while not self.stop_event.is_set():
488+
operation = self.load_buffer.get()
489+
if operation is None:
490+
continue
491+
self.mem_pool_device.transfer(operation.device_indices, operation.data)
492+
self.mem_pool_host.complete_io(operation.host_indices)
493+
for node_id in operation.node_ids:
494+
if node_id != 0:
495+
self.ack_load_queue.put(node_id)
500496
aux_thread.join()
501497

502498
def evict_device(

0 commit comments

Comments (0)