@@ -122,66 +122,17 @@ def __init__(
122
122
self .token_to_kv_pool_allocator = token_to_kv_pool_allocator
123
123
124
124
# Model-specific adjustment
125
- if (
126
- self .model_config .attention_arch == AttentionArch .MLA
127
- and not self .server_args .disable_mla
128
- ):
129
- # TODO: add MLA optimization on CPU
130
- if self .server_args .device != "cpu" :
131
- if server_args .enable_flashinfer_mla :
132
- logger .info (
133
- "MLA optimization is turned on. Use flashinfer mla backend."
134
- )
135
- self .server_args .attention_backend = "flashinfer_mla"
136
- else :
137
- logger .info ("MLA optimization is turned on. Use triton backend." )
138
- self .server_args .attention_backend = "triton"
139
-
140
- if self .server_args .enable_double_sparsity :
141
- logger .info (
142
- "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
143
- )
144
- self .server_args .attention_backend = "triton"
145
- self .server_args .disable_cuda_graph = True
146
- if self .server_args .ds_heavy_channel_type is None :
147
- raise ValueError (
148
- "Please specify the heavy channel type for double sparsity optimization."
149
- )
150
- self .init_double_sparsity_channel_config (
151
- self .server_args .ds_heavy_channel_type
152
- )
125
+ self .model_specific_adjustment ()
153
126
154
- if self .is_multimodal :
155
- self .mem_fraction_static *= 0.95
156
- logger .info (
157
- f"Automatically reduce --mem-fraction-static to { self .mem_fraction_static :.3f} "
158
- f"because this is a multimodal model."
159
- )
160
-
161
- if self .model_config .hf_config .architectures == [
162
- "MllamaForConditionalGeneration"
163
- ]:
164
- logger .info ("Automatically turn off --chunked-prefill-size for mllama." )
165
- server_args .chunked_prefill_size = - 1
166
-
167
- if self .model_config .hf_config .architectures == [
168
- "Qwen2VLForConditionalGeneration"
169
- ]:
170
- # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
171
- logger .info (
172
- "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
173
- )
174
- server_args .chunked_prefill_size = - 1
175
- server_args .disable_radix_cache = True
176
-
177
- # Global vars
178
127
if server_args .show_time_cost :
179
128
enable_show_time_cost ()
129
+
180
130
if server_args .disable_outlines_disk_cache :
181
131
from outlines .caching import disable_cache
182
132
183
133
disable_cache ()
184
134
135
+ # Global vars
185
136
global_server_args_dict .update (
186
137
{
187
138
"attention_backend" : server_args .attention_backend ,
@@ -203,6 +154,7 @@ def __init__(
203
154
}
204
155
)
205
156
157
+ # CPU offload
206
158
set_cpu_offload_max_bytes (int (server_args .cpu_offload_gb * 1024 ** 3 ))
207
159
208
160
# Get memory before model loading
@@ -216,18 +168,6 @@ def __init__(
216
168
self .sampler = Sampler ()
217
169
self .load_model ()
218
170
219
- # Handle the case where some of models don't finish loading.
220
- try :
221
- dist .monitored_barrier (
222
- group = get_tp_group ().cpu_group ,
223
- timeout = datetime .timedelta (seconds = UNBALANCED_MODEL_LOADING_TIMEOUT_S ),
224
- wait_all_ranks = True ,
225
- )
226
- except RuntimeError :
227
- raise ValueError (
228
- f"TP rank { self .tp_rank } could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
229
- ) from None
230
-
231
171
# Apply torchao quantization
232
172
torchao_applied = getattr (self .model , "torchao_applied" , False )
233
173
# In layered loading, torchao may have been applied
@@ -244,9 +184,11 @@ def __init__(
244
184
else :
245
185
self .torch_tp_applied = False
246
186
247
- # Init memory pool and attention backends
187
+ # Init lora
248
188
if server_args .lora_paths is not None :
249
189
self .init_lora_manager ()
190
+
191
+ # Init memory pool and attention backends
250
192
self .init_memory_pool (
251
193
min_per_gpu_memory ,
252
194
server_args .max_running_requests ,
@@ -260,10 +202,63 @@ def __init__(
260
202
self .cuda_graph_runner = None
261
203
self .init_attention_backend ()
262
204
205
+ def model_specific_adjustment (self ):
206
+ server_args = self .server_args
207
+
208
+ if (
209
+ self .model_config .attention_arch == AttentionArch .MLA
210
+ and not server_args .disable_mla
211
+ ):
212
+ # TODO: add MLA optimization on CPU
213
+ if server_args .device != "cpu" :
214
+ if server_args .enable_flashinfer_mla :
215
+ logger .info (
216
+ "MLA optimization is turned on. Use flashinfer mla backend."
217
+ )
218
+ server_args .attention_backend = "flashinfer_mla"
219
+ else :
220
+ logger .info ("MLA optimization is turned on. Use triton backend." )
221
+ server_args .attention_backend = "triton"
222
+
223
+ if server_args .enable_double_sparsity :
224
+ logger .info (
225
+ "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
226
+ )
227
+ server_args .attention_backend = "triton"
228
+ server_args .disable_cuda_graph = True
229
+ if server_args .ds_heavy_channel_type is None :
230
+ raise ValueError (
231
+ "Please specify the heavy channel type for double sparsity optimization."
232
+ )
233
+ self .init_double_sparsity_channel_config (server_args .ds_heavy_channel_type )
234
+
235
+ if self .is_multimodal :
236
+ self .mem_fraction_static *= 0.95
237
+ logger .info (
238
+ f"Automatically reduce --mem-fraction-static to { self .mem_fraction_static :.3f} "
239
+ f"because this is a multimodal model."
240
+ )
241
+
242
+ if self .model_config .hf_config .architectures == [
243
+ "MllamaForConditionalGeneration"
244
+ ]:
245
+ logger .info ("Automatically turn off --chunked-prefill-size for mllama." )
246
+ server_args .chunked_prefill_size = - 1
247
+
248
+ if self .model_config .hf_config .architectures == [
249
+ "Qwen2VLForConditionalGeneration"
250
+ ]:
251
+ # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
252
+ logger .info (
253
+ "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
254
+ )
255
+ server_args .chunked_prefill_size = - 1
256
+ server_args .disable_radix_cache = True
257
+
263
258
def init_torch_distributed (self ):
264
259
logger .info ("Init torch distributed begin." )
265
- torch .get_device_module (self .device ).set_device (self .gpu_id )
266
260
261
+ torch .get_device_module (self .device ).set_device (self .gpu_id )
267
262
if self .device == "cuda" :
268
263
backend = "nccl"
269
264
elif self .device == "xpu" :
@@ -400,6 +395,18 @@ def load_model(self):
400
395
f"mem usage={ (before_avail_memory - after_avail_memory ):.2f} GB."
401
396
)
402
397
398
+ # Handle the case where some ranks do not finish loading.
399
+ try :
400
+ dist .monitored_barrier (
401
+ group = get_tp_group ().cpu_group ,
402
+ timeout = datetime .timedelta (seconds = UNBALANCED_MODEL_LOADING_TIMEOUT_S ),
403
+ wait_all_ranks = True ,
404
+ )
405
+ except RuntimeError :
406
+ raise ValueError (
407
+ f"TP rank { self .tp_rank } could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
408
+ ) from None
409
+
403
410
def update_weights_from_disk (
404
411
self , model_path : str , load_format : str
405
412
) -> tuple [bool , str ]:
@@ -772,6 +779,10 @@ def init_cublas(self):
772
779
def init_attention_backend (self ):
773
780
"""Init attention kernel backend."""
774
781
if self .server_args .attention_backend == "flashinfer" :
782
+ # Init streams
783
+ if self .server_args .speculative_algorithm == "EAGLE" :
784
+ self .plan_stream_for_flashinfer = torch .cuda .Stream ()
785
+
775
786
self .attn_backend = FlashInferAttnBackend (self )
776
787
elif self .server_args .attention_backend == "triton" :
777
788
assert self .sliding_window_size is None , (
@@ -880,18 +891,24 @@ def forward_idle(self, forward_batch: ForwardBatch):
880
891
forward_batch .input_ids , forward_batch .positions , forward_batch
881
892
)
882
893
883
- def forward (self , forward_batch : ForwardBatch ) -> LogitsProcessorOutput :
894
+ def forward (
895
+ self , forward_batch : ForwardBatch , skip_attn_backend_init : bool = False
896
+ ) -> LogitsProcessorOutput :
884
897
if (
885
898
forward_batch .forward_mode .is_cuda_graph ()
886
899
and self .cuda_graph_runner
887
900
and self .cuda_graph_runner .can_run (forward_batch )
888
901
):
889
- return self .cuda_graph_runner .replay (forward_batch )
902
+ return self .cuda_graph_runner .replay (
903
+ forward_batch , skip_attn_backend_init = skip_attn_backend_init
904
+ )
890
905
891
906
if forward_batch .forward_mode .is_decode ():
892
907
return self .forward_decode (forward_batch )
893
908
elif forward_batch .forward_mode .is_extend ():
894
- return self .forward_extend (forward_batch )
909
+ return self .forward_extend (
910
+ forward_batch , skip_attn_backend_init = skip_attn_backend_init
911
+ )
895
912
elif forward_batch .forward_mode .is_idle ():
896
913
return self .forward_idle (forward_batch )
897
914
else :
0 commit comments