Commit 98c73d7

[Minor] make the __init__ function of model_runner.py shorter (sgl-project#4132)
1 parent: fcc2e37

2 files changed: +87 −70 lines

python/sglang/srt/model_executor/cuda_graph_runner.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -427,7 +427,7 @@ def recapture_if_needed(self, forward_batch: ForwardBatch):
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()
 
-    def replay(self, forward_batch: ForwardBatch):
+    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
         self.recapture_if_needed(forward_batch)
 
         raw_bs = forward_batch.batch_size
```
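
Note on the signature change: because `skip_attn_backend_init` is a keyword with a `False` default, every existing caller of `replay` keeps working unchanged, and the new behavior is strictly opt-in. A minimal sketch of this pattern, with a hypothetical `Runner` class standing in for the real `CudaGraphRunner`:

```python
class Runner:
    def init_attn_backend(self, batch: str) -> None:
        print(f"initializing attention backend for {batch}")

    def replay(self, batch: str, skip_attn_backend_init: bool = False) -> str:
        # Defaulted keyword: old call sites hit the old path untouched.
        if not skip_attn_backend_init:
            self.init_attn_backend(batch)
        return f"replayed {batch}"


runner = Runner()
runner.replay("batch-0")                               # pre-existing call site
runner.replay("batch-1", skip_attn_backend_init=True)  # new opt-in path
```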

python/sglang/srt/model_executor/model_runner.py

Lines changed: 86 additions & 69 deletions
```diff
@@ -122,66 +122,17 @@ def __init__(
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
 
         # Model-specific adjustment
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
-            # TODO: add MLA optimization on CPU
-            if self.server_args.device != "cpu":
-                if server_args.enable_flashinfer_mla:
-                    logger.info(
-                        "MLA optimization is turned on. Use flashinfer mla backend."
-                    )
-                    self.server_args.attention_backend = "flashinfer_mla"
-                else:
-                    logger.info("MLA optimization is turned on. Use triton backend.")
-                    self.server_args.attention_backend = "triton"
-
-        if self.server_args.enable_double_sparsity:
-            logger.info(
-                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-            )
-            self.server_args.attention_backend = "triton"
-            self.server_args.disable_cuda_graph = True
-            if self.server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(
-                self.server_args.ds_heavy_channel_type
-            )
+        self.model_specific_adjustment()
 
-        if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
-
-        if self.model_config.hf_config.architectures == [
-            "MllamaForConditionalGeneration"
-        ]:
-            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-            server_args.chunked_prefill_size = -1
-
-        if self.model_config.hf_config.architectures == [
-            "Qwen2VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
-            )
-            server_args.chunked_prefill_size = -1
-            server_args.disable_radix_cache = True
-
-        # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
+
         if server_args.disable_outlines_disk_cache:
             from outlines.caching import disable_cache
 
             disable_cache()
 
+        # Global vars
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
```
```diff
@@ -203,6 +154,7 @@ def __init__(
             }
         )
 
+        # CPU offload
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
 
         # Get memory before model loading
```
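
`set_cpu_offload_max_bytes` takes a byte count, so the `--cpu-offload-gb` flag is scaled by `1024**3` (gibibytes, strictly speaking). A quick check of the conversion:

```python
def gb_to_bytes(gb: float) -> int:
    # Fractional GiB values are allowed; the setter wants whole bytes.
    return int(gb * 1024**3)

assert gb_to_bytes(1) == 1_073_741_824
assert gb_to_bytes(0.5) == 536_870_912
```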
```diff
@@ -216,18 +168,6 @@ def __init__(
         self.sampler = Sampler()
         self.load_model()
 
-        # Handle the case where some of models don't finish loading.
-        try:
-            dist.monitored_barrier(
-                group=get_tp_group().cpu_group,
-                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
-                wait_all_ranks=True,
-            )
-        except RuntimeError:
-            raise ValueError(
-                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
-            ) from None
-
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
```

```diff
@@ -244,9 +184,11 @@ def __init__(
         else:
             self.torch_tp_applied = False
 
-        # Init memory pool and attention backends
+        # Init lora
         if server_args.lora_paths is not None:
             self.init_lora_manager()
+
+        # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
             server_args.max_running_requests,
```

```diff
@@ -260,10 +202,63 @@ def __init__(
         self.cuda_graph_runner = None
         self.init_attention_backend()
 
+    def model_specific_adjustment(self):
+        server_args = self.server_args
+
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not server_args.disable_mla
+        ):
+            # TODO: add MLA optimization on CPU
+            if server_args.device != "cpu":
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "MLA optimization is turned on. Use flashinfer mla backend."
+                    )
+                    server_args.attention_backend = "flashinfer_mla"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    server_args.attention_backend = "triton"
+
+        if server_args.enable_double_sparsity:
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
+            server_args.attention_backend = "triton"
+            server_args.disable_cuda_graph = True
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
+        if self.is_multimodal:
+            self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+
+        if self.model_config.hf_config.architectures == [
+            "MllamaForConditionalGeneration"
+        ]:
+            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+            server_args.chunked_prefill_size = -1
+
+        if self.model_config.hf_config.architectures == [
+            "Qwen2VLForConditionalGeneration"
+        ]:
+            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-        torch.get_device_module(self.device).set_device(self.gpu_id)
 
+        torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
```
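
In `init_torch_distributed`, `torch.get_device_module(self.device)` resolves the device string to its backend module (`torch.cuda` for "cuda", `torch.xpu` for "xpu"), so the `set_device` call needs no per-backend branching. A small sketch, assuming a recent PyTorch where `torch.get_device_module` is available:

```python
import torch

if torch.cuda.is_available():
    # "cuda" resolves to the torch.cuda module; on other builds the same
    # call resolves "xpu", "hpu", etc., keeping the caller backend-agnostic.
    device_module = torch.get_device_module("cuda")
    device_module.set_device(0)
    print("current device:", device_module.current_device())
```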
```diff
@@ -400,6 +395,18 @@ def load_model(self):
             f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )
 
+        # Handle the case where some ranks do not finish loading.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
     def update_weights_from_disk(
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
```
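
Moving the barrier into `load_model` keeps it next to the operation it guards. `dist.monitored_barrier` needs a gloo process group (hence `cpu_group`), and `wait_all_ranks=True` makes it report every straggler instead of only the first. A self-contained single-rank sketch (the TCP address and port are arbitrary placeholders):

```python
import datetime

import torch.distributed as dist

# World size 1 just to exercise the API; monitored_barrier requires gloo.
dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29500",
    rank=0,
    world_size=1,
)
try:
    dist.monitored_barrier(
        timeout=datetime.timedelta(seconds=30),
        wait_all_ranks=True,  # report all missing ranks, not just the first
    )
except RuntimeError as err:
    # The real code re-raises this as a ValueError naming the TP rank.
    print("some ranks did not reach the barrier:", err)
finally:
    dist.destroy_process_group()
```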
```diff
@@ -772,6 +779,10 @@ def init_cublas(self):
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            # Init streams
+            if self.server_args.speculative_algorithm == "EAGLE":
+                self.plan_stream_for_flashinfer = torch.cuda.Stream()
+
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
```
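
The dedicated `torch.cuda.Stream` lets flashinfer's planning work run concurrently with kernels on the default stream when EAGLE speculative decoding is active. A minimal sketch of the create/launch/synchronize pattern (the matmul stands in for the real planning work):

```python
import torch

if torch.cuda.is_available():
    plan_stream = torch.cuda.Stream()  # side stream, as in the diff
    x = torch.randn(1024, 1024, device="cuda")

    with torch.cuda.stream(plan_stream):
        # Kernels queued here may overlap with work on the default stream.
        y = x @ x

    # The default stream must wait for the side stream before consuming y.
    torch.cuda.current_stream().wait_stream(plan_stream)
    print(y.sum().item())
```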
```diff
@@ -880,18 +891,24 @@ def forward_idle(self, forward_batch: ForwardBatch):
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
-    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+    def forward(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if (
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         ):
-            return self.cuda_graph_runner.replay(forward_batch)
+            return self.cuda_graph_runner.replay(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
 
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(forward_batch)
+            return self.forward_extend(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
```