Skip to content

Commit 9f1787f

Browse files
authored
Support multi-thread model weight loading (sgl-project#7277)
1 parent 8ecad0b commit 9f1787f

File tree

4 files changed

+143
-10
lines changed

4 files changed

+143
-10
lines changed

python/sglang/srt/model_executor/model_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ def load_model(self):
547547
self.load_config = LoadConfig(
548548
load_format=self.server_args.load_format,
549549
download_dir=self.server_args.download_dir,
550+
model_loader_extra_config=self.server_args.model_loader_extra_config,
550551
)
551552
if self.server_args.load_format == "gguf":
552553
monkey_patch_vllm_gguf_config()

python/sglang/srt/model_loader/loader.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# ruff: noqa: SIM117
44
import collections
5+
import concurrent
56
import dataclasses
67
import fnmatch
78
import glob
@@ -11,14 +12,17 @@
1112
import os
1213
import time
1314
from abc import ABC, abstractmethod
15+
from concurrent.futures import ThreadPoolExecutor
1416
from contextlib import contextmanager
1517
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
1618

1719
import huggingface_hub
1820
import numpy as np
21+
import safetensors.torch
1922
import torch
2023
from huggingface_hub import HfApi, hf_hub_download
2124
from torch import nn
25+
from tqdm.auto import tqdm
2226
from transformers import AutoModelForCausalLM
2327
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
2428

@@ -41,6 +45,7 @@
4145
set_default_torch_dtype,
4246
)
4347
from sglang.srt.model_loader.weight_utils import (
48+
_BAR_FORMAT,
4449
download_safetensors_index_file_from_hf,
4550
download_weights_from_hf,
4651
filter_duplicate_safetensors_files,
@@ -49,6 +54,8 @@
4954
get_quant_config,
5055
gguf_quant_weights_iterator,
5156
initialize_dummy_weights,
57+
multi_thread_pt_weights_iterator,
58+
multi_thread_safetensors_weights_iterator,
5259
np_cache_weights_iterator,
5360
pt_weights_iterator,
5461
safetensors_weights_iterator,
@@ -181,6 +188,9 @@ def load_model(
181188
class DefaultModelLoader(BaseModelLoader):
182189
"""Model loader that can load different file types from disk."""
183190

191+
# Default number of threads used when multi-threaded weight loading is enabled.
192+
DEFAULT_NUM_THREADS = 8
193+
184194
@dataclasses.dataclass
185195
class Source:
186196
"""A source for weights."""
@@ -208,10 +218,15 @@ def init_new(cls, model_config: ModelConfig, model):
208218

209219
def __init__(self, load_config: LoadConfig):
210220
super().__init__(load_config)
211-
if load_config.model_loader_extra_config:
221+
extra_config = load_config.model_loader_extra_config
222+
allowed_keys = {"enable_multithread_load", "num_threads"}
223+
unexpected_keys = set(extra_config.keys()) - allowed_keys
224+
225+
if unexpected_keys:
212226
raise ValueError(
213-
f"Model loader extra config is not supported for "
214-
f"load format {load_config.load_format}"
227+
f"Unexpected extra config keys for load format "
228+
f"{load_config.load_format}: "
229+
f"{unexpected_keys}"
215230
)
216231

217232
def _maybe_download_from_modelscope(
@@ -324,6 +339,7 @@ def _get_weights_iterator(
324339
self, source: "Source"
325340
) -> Generator[Tuple[str, torch.Tensor], None, None]:
326341
"""Get an iterator for the model weights based on the load format."""
342+
extra_config = self.load_config.model_loader_extra_config
327343
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
328344
source.model_or_path, source.revision, source.fall_back_to_pt
329345
)
@@ -342,11 +358,30 @@ def _get_weights_iterator(
342358
weight_loader_disable_mmap = global_server_args_dict.get(
343359
"weight_loader_disable_mmap"
344360
)
345-
weights_iterator = safetensors_weights_iterator(
346-
hf_weights_files, disable_mmap=weight_loader_disable_mmap
347-
)
361+
362+
if extra_config.get("enable_multithread_load"):
363+
weights_iterator = multi_thread_safetensors_weights_iterator(
364+
hf_weights_files,
365+
max_workers=extra_config.get(
366+
"num_threads", self.DEFAULT_NUM_THREADS
367+
),
368+
disable_mmap=weight_loader_disable_mmap,
369+
)
370+
else:
371+
weights_iterator = safetensors_weights_iterator(
372+
hf_weights_files, disable_mmap=weight_loader_disable_mmap
373+
)
374+
348375
else:
349-
weights_iterator = pt_weights_iterator(hf_weights_files)
376+
if extra_config.get("enable_multithread_load"):
377+
weights_iterator = multi_thread_pt_weights_iterator(
378+
hf_weights_files,
379+
max_workers=extra_config.get(
380+
"num_threads", self.DEFAULT_NUM_THREADS
381+
),
382+
)
383+
else:
384+
weights_iterator = pt_weights_iterator(hf_weights_files)
350385

351386
# Apply the prefix.
352387
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
@@ -385,9 +420,9 @@ def load_model(
385420
self.load_config,
386421
)
387422

388-
self.load_weights_and_postprocess(
389-
model, self._get_all_weights(model_config, model), target_device
390-
)
423+
self.load_weights_and_postprocess(
424+
model, self._get_all_weights(model_config, model), target_device
425+
)
391426

392427
return model.eval()
393428

python/sglang/srt/model_loader/weight_utils.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
22

33
"""Utilities for downloading and initializing model weights."""
4+
import concurrent.futures
45
import fnmatch
56
import glob
67
import hashlib
78
import json
89
import logging
910
import os
11+
import queue
1012
import tempfile
1113
from collections import defaultdict
1214
from typing import (
@@ -453,6 +455,60 @@ def safetensors_weights_iterator(
453455
yield name, param
454456

455457

458+
def multi_thread_safetensors_weights_iterator(
    hf_weights_files: List[str],
    is_all_weights_sharded: bool = False,
    decryption_key: Optional[str] = None,
    max_workers: int = 4,
    disable_mmap: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in safetensors files, loading files in parallel.

    All files in ``hf_weights_files`` are submitted to a thread pool up front,
    and the tensors of a file are yielded as soon as that file has finished
    loading. Shards are therefore produced in completion order, which is not
    necessarily the order of ``hf_weights_files``.

    Args:
        hf_weights_files: Paths of the safetensors files to load.
        is_all_weights_sharded: Only forwarded to
            ``safetensors_encrypted_weights_iterator``; the multi-threaded
            path itself does not use it.
        decryption_key: If set, a warning is logged and loading is delegated
            to ``safetensors_encrypted_weights_iterator`` (no thread pool).
        max_workers: Maximum number of loader threads.
        disable_mmap: If True, read each file fully with ``open(...).read()``
            and deserialize the bytes with ``safetensors.torch.load``;
            otherwise use ``safetensors.torch.load_file(..., device="cpu")``.

    Yields:
        ``(name, tensor)`` pairs for every entry of every loaded file.

    Raises:
        Any exception raised while loading a file is re-raised by
        ``Future.result()`` when that file's shard is reached.
    """
    if decryption_key:
        logger.warning(
            "Multi-Thread loading is not working for encrypted safetensor weights."
        )
        yield from safetensors_encrypted_weights_iterator(
            hf_weights_files, is_all_weights_sharded, decryption_key
        )
        return

    # Only show a progress bar on rank 0 (or when not running distributed).
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(st_file: str):
        if disable_mmap:
            with open(st_file, "rb") as f:
                return safetensors.torch.load(f.read())
        return safetensors.torch.load_file(st_file, device="cpu")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Hand the futures straight to as_completed() instead of keeping them in
        # a local list. A retained list would keep every finished future (and
        # the state dict it holds) alive until the generator is exhausted, so
        # peak host memory would grow to the size of the whole checkpoint.
        completed = concurrent.futures.as_completed(
            [executor.submit(_load_file, st_file) for st_file in hf_weights_files]
        )
        # `disable` already turns tqdm into a plain pass-through iterator, so no
        # separate non-tqdm branch is needed.
        futures_iter = tqdm(
            completed,
            total=len(hf_weights_files),
            desc="Multi-thread loading shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
        )

        for future in futures_iter:
            state_dict = future.result()
            yield from state_dict.items()
            # Drop our reference before waiting for the next shard so that the
            # consumed tensors can be reclaimed.
            del state_dict
511+
456512
def pt_weights_iterator(
457513
hf_weights_files: List[str],
458514
) -> Generator[Tuple[str, torch.Tensor], None, None]:
@@ -471,6 +527,39 @@ def pt_weights_iterator(
471527
del state
472528

473529

530+
def multi_thread_pt_weights_iterator(
    hf_weights_files: List[str],
    max_workers: int = 4,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in bin/pt files, loading files in parallel.

    All files in ``hf_weights_files`` are submitted to a thread pool up front,
    and the entries of a file are yielded as soon as that file has finished
    loading. Shards are therefore produced in completion order, which is not
    necessarily the order of ``hf_weights_files``.

    Args:
        hf_weights_files: Paths of the checkpoint files to load with
            ``torch.load(..., map_location="cpu", weights_only=True)``.
        max_workers: Maximum number of loader threads.

    Yields:
        ``(name, tensor)`` pairs for every entry of every loaded file.

    Raises:
        Any exception raised while loading a file is re-raised by
        ``Future.result()`` when that file's shard is reached.
    """
    # Only show a progress bar on rank 0 (or when not running distributed).
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(bin_file: str):
        return torch.load(bin_file, map_location="cpu", weights_only=True)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Hand the futures straight to as_completed() instead of keeping them in
        # a local list. A retained list would keep every finished future (and
        # the state dict it holds) alive until the generator is exhausted, so
        # peak host memory would grow to the size of the whole checkpoint.
        completed = concurrent.futures.as_completed(
            [executor.submit(_load_file, bin_file) for bin_file in hf_weights_files]
        )
        # `disable` already turns tqdm into a plain pass-through iterator, so no
        # separate non-tqdm branch is needed.
        futures_iter = tqdm(
            completed,
            total=len(hf_weights_files),
            desc="Multi-thread loading pt checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
        )

        for future in futures_iter:
            state = future.result()
            yield from state.items()
            # Same as the single-threaded pt_weights_iterator: release the
            # shard once it has been consumed.
            del state
561+
562+
474563
def get_gguf_extra_tensor_names(
475564
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
476565
) -> List[str]:

python/sglang/srt/server_args.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class ServerArgs:
4747
tokenizer_mode: str = "auto"
4848
skip_tokenizer_init: bool = False
4949
load_format: str = "auto"
50+
model_loader_extra_config: str = "{}"
5051
trust_remote_code: bool = False
5152
dtype: str = "auto"
5253
kv_cache_dtype: str = "auto"
@@ -632,6 +633,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
632633
"layer before loading another to make the peak memory envelope "
633634
"smaller.",
634635
)
636+
parser.add_argument(
637+
"--model-loader-extra-config",
638+
type=str,
639+
help="Extra config for model loader. "
640+
"This will be passed to the model loader corresponding to the chosen load_format.",
641+
default=ServerArgs.model_loader_extra_config,
642+
)
635643
parser.add_argument(
636644
"--trust-remote-code",
637645
action="store_true",

0 commit comments

Comments
 (0)