Skip to content

Commit b7cd743

Browse files
authored
[Feat] QWen-1M context support[2/2]: Update block sparse attention backend (sgl-project#5949)
1 parent a69b637 commit b7cd743

15 files changed

+2121
-4
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""
2+
Usage:
3+
python3 offline_batch_inference.py
4+
"""
5+
6+
from urllib.request import urlopen
7+
8+
import sglang as sgl
9+
10+
11+
def load_prompt() -> str:
    """Fetch the 64k-token long-context test prompt and return it as text.

    The file is downloaded over HTTPS with a 5 second timeout and decoded
    as UTF-8.
    """
    # Test cases with various lengths can be found at:
    #
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
    prompt_url = (
        "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt"
    )
    with urlopen(prompt_url, timeout=5) as response:
        raw_bytes = response.read()
    return raw_bytes.decode("utf-8")
26+
27+
28+
# Processing the prompt.
29+
def process_requests(llm: sgl.Engine, prompts: list[str]) -> None:
    """Run generation for every prompt and print one summary line per result.

    Args:
        llm: Engine used to generate completions.
        prompts: Prompt strings passed to ``llm.generate`` in a single call.
    """
    # Sampling settings applied to every prompt in the batch.
    sampling_params = dict(
        temperature=0.7,
        top_p=0.8,
        top_k=20,
        repetition_penalty=1.05,
        max_new_tokens=256,
    )

    for result in llm.generate(prompts, sampling_params):
        # "prompt_tokens" is reported as a length below, not as a list of ids.
        num_prompt_tokens = result["meta_info"]["prompt_tokens"]
        completion = result["text"]
        print(f"Prompt length: {num_prompt_tokens}, Generated text: {completion!r}")
47+
48+
49+
# Create an LLM.
50+
def initialize_engine() -> sgl.Engine:
    """Build the engine configured for the Qwen2.5 1M-context example.

    Returns:
        An ``sgl.Engine`` for ``Qwen/Qwen2.5-7B-Instruct-1M`` using the
        ``dual_chunk_flash_attn`` attention backend with a context length
        of 1048576 tokens.
    """
    engine_kwargs = {
        "model_path": "Qwen/Qwen2.5-7B-Instruct-1M",
        "context_length": 1048576,
        "page_size": 256,
        "attention_backend": "dual_chunk_flash_attn",
        "tp_size": 4,
        "disable_radix_cache": True,
        "enable_mixed_chunk": False,
        "enable_torch_compile": False,
        "chunked_prefill_size": 131072,
        "mem_fraction_static": 0.6,
        "log_level": "DEBUG",
    }
    return sgl.Engine(**engine_kwargs)
65+
66+
67+
def main():
    """Start the engine, fetch the test prompt, and print the generation."""
    # The engine is created before the prompt is downloaded, as in the
    # original ordering.
    engine = initialize_engine()
    long_prompt = load_prompt()
    process_requests(engine, [long_prompt])


# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

python/sglang/srt/configs/model_config.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
get_context_length,
2828
get_generation_config,
2929
get_hf_text_config,
30+
get_sparse_attention_config,
3031
)
3132
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
3233
from sglang.srt.server_args import ServerArgs
@@ -270,6 +271,9 @@ def __init__(
270271
# Verify quantization
271272
self._verify_quantization()
272273

274+
# Verify dual-chunk attention config
275+
self._verify_dual_chunk_attention_config()
276+
273277
# Cache attributes
274278
self.hf_eos_token_id = self.get_hf_eos_token_id()
275279

@@ -297,6 +301,13 @@ def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
297301
**kwargs,
298302
)
299303

304+
def get_total_num_attention_heads(self) -> int:
    """Return the model's attention head count, before any parallel split."""
    total_heads = self.num_attention_heads
    return total_heads
306+
307+
def get_num_attention_heads(self, tensor_parallel_size) -> int:
    """Return the attention head count per tensor-parallel rank.

    The total head count is floor-divided by ``tensor_parallel_size``;
    the result is never lower than 1, even when there are more ranks
    than heads.
    """
    heads_per_rank = self.num_attention_heads // tensor_parallel_size
    if heads_per_rank < 1:
        return 1
    return heads_per_rank
310+
300311
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
301312
def get_total_num_kv_heads(self) -> int:
302313
"""Returns the total number of KV heads."""
@@ -484,6 +495,23 @@ def _verify_quantization(self) -> None:
484495
self.quantization,
485496
)
486497

498+
def _verify_dual_chunk_attention_config(self) -> None:
499+
if hasattr(self.hf_config, "dual_chunk_attention_config"):
500+
# Try loading the sparse attention config
501+
sparse_attn_config = get_sparse_attention_config(self.model_path)
502+
if not sparse_attn_config:
503+
return
504+
self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
505+
sparse_attn_config
506+
)
507+
if (
508+
"sparse_attention_enabled"
509+
not in self.hf_config.dual_chunk_attention_config
510+
):
511+
self.hf_config.dual_chunk_attention_config[
512+
"sparse_attention_enabled"
513+
] = True
514+
487515
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
488516
eos_ids = getattr(self.hf_config, "eos_token_id", None)
489517
if eos_ids is not None:

python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def prepare_for_prebuilt_extend(self: ScheduleBatch):
7676
req_pool_indices, dtype=torch.int64, device=self.device
7777
)
7878
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
79+
self.orig_seq_lens = torch.tensor(
80+
seq_lens, dtype=torch.int32, device=self.device
81+
)
7982
self.out_cache_loc = out_cache_loc
8083
self.seq_lens_sum = sum(seq_lens)
8184

python/sglang/srt/hf_transformers_utils.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
"""Utilities for Huggingface Transformers."""
1515

1616
import contextlib
17+
import json
1718
import os
1819
import warnings
1920
from pathlib import Path
20-
from typing import Dict, Optional, Type, Union
21+
from typing import Any, Dict, Optional, Type, Union
2122

2223
import torch
2324
from huggingface_hub import snapshot_download
@@ -62,11 +63,17 @@
6263
AutoConfig.register(name, cls)
6364

6465

65-
def download_from_hf(model_path: str):
66+
def download_from_hf(
    model_path: str,
    allow_patterns: Optional[Union[str, list]] = None,
):
    """Return a local path for ``model_path``, downloading it if needed.

    Args:
        model_path: A filesystem path or a repository identifier passed
            to ``snapshot_download``.
        allow_patterns: File patterns to download. When empty or ``None``,
            ``["*.json", "*.bin", "*.model"]`` is used.

    Returns:
        ``model_path`` unchanged if it exists on disk, otherwise the
        value returned by ``snapshot_download``.
    """
    # Anything already on disk is used as-is; no download is attempted.
    if os.path.exists(model_path):
        return model_path

    patterns = allow_patterns or ["*.json", "*.bin", "*.model"]
    return snapshot_download(model_path, allow_patterns=patterns)
7077

7178

7279
def get_hf_text_config(config: PretrainedConfig):
@@ -171,6 +178,26 @@ def get_generation_config(
171178
return None
172179

173180

181+
# Qwen-1M related
182+
def get_sparse_attention_config(
    model: str,
    sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
    """Load the sparse attention config that accompanies a model.

    Args:
        model: A local model directory, or an identifier that
            ``download_from_hf`` resolves to a local directory.
        sparse_attention_config_filename: Name of the JSON file to look
            for inside the model directory.

    Returns:
        The parsed JSON content, or an empty dict when the file does not
        exist in the model directory.
    """
    if not os.path.isdir(model):
        # Not a local directory: download the config files only.
        model = download_from_hf(model, allow_patterns=["*.json"])

    config_file = os.path.join(model, sparse_attention_config_filename)
    if not os.path.exists(config_file):
        return {}

    # JSON is UTF-8; decode explicitly so the result does not depend on
    # the platform's locale default encoding.
    with open(config_file, encoding="utf-8") as f:
        return json.load(f)
199+
200+
174201
# Models don't use the same configuration key for determining the maximum
175202
# context length. Store them here so we can sanely check them.
176203
# NOTE: The ordering here is important. Some models have two of these and we

0 commit comments

Comments
 (0)