Skip to content

Commit 3774f07

Browse files
authored
Multi-Stage Awake: Support Resume and Pause KV Cache and Weights separately (sgl-project#7099)
1 parent 9179ea1 commit 3774f07

File tree

14 files changed: 308 additions and 119 deletions.

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ srt_npu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
9898
openai = ["openai>=1.0", "tiktoken"]
9999
anthropic = ["anthropic>=0.20.0"]
100100
litellm = ["litellm>=1.0.0"]
101-
torch_memory_saver = ["torch_memory_saver>=0.0.4"]
101+
torch_memory_saver = ["torch_memory_saver>=0.0.8"]
102102
decord = ["decord"]
103103
test = [
104104
"accelerate",

python/sglang/srt/constants.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,3 @@
1+
# GPU Memory Types
2+
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
3+
GPU_MEMORY_TYPE_WEIGHTS = "weights"

python/sglang/srt/disaggregation/decode.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
import torch
3232
from torch.distributed import ProcessGroup
3333

34+
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
3435
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
3536
from sglang.srt.disaggregation.utils import (
3637
FAKE_BOOTSTRAP_HOST,
@@ -90,7 +91,7 @@ def __init__(
9091
self.max_context_len = max_context_len
9192
self.device = device
9293
self.pre_alloc_size = pre_alloc_size
93-
with memory_saver_adapter.region():
94+
with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
9495
self.req_to_token = torch.zeros(
9596
(size + pre_alloc_size, max_context_len),
9697
dtype=torch.int32,

python/sglang/srt/entrypoints/engine.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -479,17 +479,15 @@ def get_weights_by_name(self, name: str, truncate_size: int = 100):
479479
self.tokenizer_manager.get_weights_by_name(obj, None)
480480
)
481481

482-
def release_memory_occupation(self):
483-
"""Release GPU occupation temporarily."""
484-
obj = ReleaseMemoryOccupationReqInput()
482+
def release_memory_occupation(self, tags: Optional[List[str]] = None):
483+
obj = ReleaseMemoryOccupationReqInput(tags=tags)
485484
loop = asyncio.get_event_loop()
486485
return loop.run_until_complete(
487486
self.tokenizer_manager.release_memory_occupation(obj, None)
488487
)
489488

490-
def resume_memory_occupation(self):
491-
"""Resume GPU occupation."""
492-
obj = ResumeMemoryOccupationReqInput()
489+
def resume_memory_occupation(self, tags: Optional[List[str]] = None):
490+
obj = ResumeMemoryOccupationReqInput(tags=tags)
493491
loop = asyncio.get_event_loop()
494492
return loop.run_until_complete(
495493
self.tokenizer_manager.resume_memory_occupation(obj, None)
@@ -670,11 +668,9 @@ def _launch_subprocesses(
670668

671669
scheduler_procs = []
672670
if server_args.dp_size == 1:
673-
# Launch tensor parallel scheduler processes
674671
memory_saver_adapter = TorchMemorySaverAdapter.create(
675672
enable=server_args.enable_memory_saver
676673
)
677-
678674
scheduler_pipe_readers = []
679675

680676
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
@@ -710,6 +706,7 @@ def _launch_subprocesses(
710706
writer,
711707
),
712708
)
709+
713710
with memory_saver_adapter.configure_subprocess():
714711
proc.start()
715712
scheduler_procs.append(proc)

python/sglang/srt/managers/io_struct.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -812,7 +812,9 @@ class GetWeightsByNameReqOutput:
812812

813813
@dataclass
814814
class ReleaseMemoryOccupationReqInput:
815-
pass
815+
# Optional tags to identify the memory region, which is primarily used for RL
816+
# Currently we only support `weights` and `kv_cache`
817+
tags: Optional[List[str]] = None
816818

817819

818820
@dataclass
@@ -822,7 +824,9 @@ class ReleaseMemoryOccupationReqOutput:
822824

823825
@dataclass
824826
class ResumeMemoryOccupationReqInput:
825-
pass
827+
# Optional tags to identify the memory region, which is primarily used for RL
828+
# Currently we only support `weights` and `kv_cache`
829+
tags: Optional[List[str]] = None
826830

827831

828832
@dataclass

python/sglang/srt/managers/scheduler.py

Lines changed: 32 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636

3737
from sglang.global_config import global_config
3838
from sglang.srt.configs.model_config import ModelConfig
39+
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
3940
from sglang.srt.constrained.base_grammar_backend import (
4041
INVALID_GRAMMAR_OBJ,
4142
create_grammar_backend,
@@ -450,8 +451,6 @@ def __init__(
450451
t = threading.Thread(target=self.watchdog_thread, daemon=True)
451452
t.start()
452453
self.parent_process = psutil.Process().parent()
453-
454-
# Init memory saver
455454
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
456455
enable=server_args.enable_memory_saver
457456
)
@@ -2227,23 +2226,40 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
22272226
return GetWeightsByNameReqOutput(parameter)
22282227

22292228
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
2230-
self.memory_saver_adapter.check_validity(
2231-
caller_name="release_memory_occupation"
2232-
)
2233-
self.stashed_model_static_state = _export_static_state(
2234-
self.tp_worker.worker.model_runner.model
2235-
)
2236-
self.memory_saver_adapter.pause()
2237-
self.flush_cache()
2229+
tags = recv_req.tags
2230+
import subprocess
2231+
2232+
if tags is None:
2233+
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
2234+
2235+
if GPU_MEMORY_TYPE_KV_CACHE in tags:
2236+
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
2237+
self.flush_cache()
2238+
2239+
if GPU_MEMORY_TYPE_WEIGHTS in tags:
2240+
self.stashed_model_static_state = _export_static_state(
2241+
self.tp_worker.worker.model_runner.model
2242+
)
2243+
self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
2244+
22382245
return ReleaseMemoryOccupationReqOutput()
22392246

22402247
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
2241-
self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
2242-
self.memory_saver_adapter.resume()
2243-
_import_static_state(
2244-
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
2245-
)
2246-
del self.stashed_model_static_state
2248+
tags = recv_req.tags
2249+
if tags is None or len(tags) == 0:
2250+
tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
2251+
2252+
if GPU_MEMORY_TYPE_WEIGHTS in tags:
2253+
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
2254+
_import_static_state(
2255+
self.tp_worker.worker.model_runner.model,
2256+
self.stashed_model_static_state,
2257+
)
2258+
del self.stashed_model_static_state
2259+
2260+
if GPU_MEMORY_TYPE_KV_CACHE in tags:
2261+
self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
2262+
22472263
return ResumeMemoryOccupationReqOutput()
22482264

22492265
def slow_down(self, recv_req: SlowDownReqInput):

python/sglang/srt/mem_cache/memory_pool.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
import triton
3636
import triton.language as tl
3737

38+
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
3839
from sglang.srt.layers.radix_attention import RadixAttention
3940
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
4041

@@ -54,14 +55,15 @@ def __init__(
5455
device: str,
5556
enable_memory_saver: bool,
5657
):
58+
5759
memory_saver_adapter = TorchMemorySaverAdapter.create(
5860
enable=enable_memory_saver
5961
)
6062

6163
self.size = size
6264
self.max_context_len = max_context_len
6365
self.device = device
64-
with memory_saver_adapter.region():
66+
with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
6567
self.req_to_token = torch.zeros(
6668
(size, max_context_len), dtype=torch.int32, device=device
6769
)
@@ -292,7 +294,7 @@ def __init__(
292294
)
293295

294296
def _create_buffers(self):
295-
with self.memory_saver_adapter.region():
297+
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
296298
with (
297299
torch.cuda.use_mem_pool(self.custom_mem_pool)
298300
if self.enable_custom_mem_pool
@@ -610,7 +612,7 @@ def __init__(
610612
else:
611613
self.custom_mem_pool = None
612614

613-
with self.memory_saver_adapter.region():
615+
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
614616
with (
615617
torch.cuda.use_mem_pool(self.custom_mem_pool)
616618
if self.custom_mem_pool
@@ -753,7 +755,7 @@ def __init__(
753755
end_layer,
754756
)
755757

756-
with self.memory_saver_adapter.region():
758+
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
757759
# [size, head_num, head_dim] for each layer
758760
self.k_buffer = [
759761
torch.zeros(

python/sglang/srt/model_executor/model_runner.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from sglang.srt.configs.device_config import DeviceConfig
3131
from sglang.srt.configs.load_config import LoadConfig
3232
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
33+
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
3334
from sglang.srt.distributed import (
3435
get_tp_group,
3536
get_world_group,
@@ -222,6 +223,7 @@ def __init__(
222223

223224
def initialize(self, min_per_gpu_memory: float):
224225
server_args = self.server_args
226+
225227
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
226228
enable=self.server_args.enable_memory_saver
227229
)
@@ -547,7 +549,7 @@ def load_model(self):
547549
monkey_patch_vllm_parallel_state()
548550
monkey_patch_isinstance_for_vllm_base_layer()
549551

550-
with self.memory_saver_adapter.region():
552+
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
551553
self.model = get_model(
552554
model_config=self.model_config,
553555
load_config=self.load_config,

python/sglang/srt/torch_memory_saver_adapter.py

Lines changed: 19 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
11
import logging
2+
import threading
3+
import time
24
from abc import ABC
3-
from contextlib import contextmanager
5+
from contextlib import contextmanager, nullcontext
46

57
try:
68
import torch_memory_saver
79

8-
_primary_memory_saver = torch_memory_saver.TorchMemorySaver()
10+
_memory_saver = torch_memory_saver.torch_memory_saver
911
import_error = None
1012
except ImportError as e:
1113
import_error = e
@@ -38,13 +40,13 @@ def check_validity(self, caller_name):
3840
def configure_subprocess(self):
3941
raise NotImplementedError
4042

41-
def region(self):
43+
def region(self, tag: str):
4244
raise NotImplementedError
4345

44-
def pause(self):
46+
def pause(self, tag: str):
4547
raise NotImplementedError
4648

47-
def resume(self):
49+
def resume(self, tag: str):
4850
raise NotImplementedError
4951

5052
@property
@@ -53,21 +55,23 @@ def enabled(self):
5355

5456

5557
class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
58+
"""Adapter for TorchMemorySaver with tag-based control"""
59+
5660
def configure_subprocess(self):
5761
return torch_memory_saver.configure_subprocess()
5862

59-
def region(self):
60-
return _primary_memory_saver.region()
63+
def region(self, tag: str):
64+
return _memory_saver.region(tag=tag)
6165

62-
def pause(self):
63-
return _primary_memory_saver.pause()
66+
def pause(self, tag: str):
67+
return _memory_saver.pause(tag=tag)
6468

65-
def resume(self):
66-
return _primary_memory_saver.resume()
69+
def resume(self, tag: str):
70+
return _memory_saver.resume(tag=tag)
6771

6872
@property
6973
def enabled(self):
70-
return _primary_memory_saver.enabled
74+
return _memory_saver is not None and _memory_saver.enabled
7175

7276

7377
class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
@@ -76,13 +80,13 @@ def configure_subprocess(self):
7680
yield
7781

7882
@contextmanager
79-
def region(self):
83+
def region(self, tag: str):
8084
yield
8185

82-
def pause(self):
86+
def pause(self, tag: str):
8387
pass
8488

85-
def resume(self):
89+
def resume(self, tag: str):
8690
pass
8791

8892
@property

python/sglang/test/test_utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
# General test models
3838
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
3939
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
40+
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B"
4041
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
4142
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
4243

0 commit comments

Comments
 (0)