Skip to content

Commit 4ede677

Browse files
authored
Fix retract for page size > 1 (sgl-project#4914)
1 parent b26bc86 commit 4ede677

File tree

10 files changed

+68
-120
lines changed

10 files changed

+68
-120
lines changed

.github/workflows/pr-test.yml

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -87,53 +87,11 @@ jobs:
8787
run: |
8888
bash scripts/ci_install_dependency.sh
8989
90-
- name: Test data parallelism (DP=2)
91-
timeout-minutes: 10
92-
run: |
93-
cd test/srt
94-
python3 test_data_parallelism.py
95-
96-
- name: Test data parallelism attention (DP=2)
97-
timeout-minutes: 10
98-
run: |
99-
cd test/srt
100-
python3 test_dp_attention.py
101-
102-
- name: Test update weights from distributed
103-
timeout-minutes: 10
104-
run: |
105-
cd test/srt
106-
python3 test_update_weights_from_distributed.py
107-
108-
- name: Test VerlEngine
109-
timeout-minutes: 10
110-
run: |
111-
cd test/srt
112-
python3 test_verl_engine.py
113-
114-
- name: Test Patch Torch
115-
timeout-minutes: 10
116-
run: |
117-
cd test/srt
118-
python3 test_patch_torch.py
119-
120-
- name: Test expert parallelism (EP=2)
121-
timeout-minutes: 10
122-
run: |
123-
cd test/srt
124-
python3 test_moe_ep.py
125-
126-
- name: Test torch compile (TP=2)
90+
- name: Run test
12791
timeout-minutes: 10
12892
run: |
12993
cd test/srt
130-
python3 test_mla_tp.py
131-
132-
- name: Test lora tensor parallelism (TP=2)
133-
timeout-minutes: 10
134-
run: |
135-
cd test/srt/models/lora
136-
python3 test_lora_tp.py
94+
python3 run_suite.py --suite per-commit-2-gpu
13795
13896
performance-test-1-gpu-part-1:
13997
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&

python/sglang/srt/constrained/base_grammar_backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ def reset(self):
169169
self.cache.clear()
170170

171171

172-
def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
172+
def create_grammar_backend(
173+
server_args: ServerArgs, tokenizer, vocab_size: int
174+
) -> Optional[BaseGrammarBackend]:
173175
if server_args.grammar_backend == "outlines":
174176
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
175177

@@ -188,6 +190,8 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
188190
tokenizer=tokenizer,
189191
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
190192
)
193+
elif server_args.grammar_backend == "none":
194+
return None
191195
else:
192196
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
193197

python/sglang/srt/managers/schedule_batch.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,7 @@ def reset_for_retract(self):
599599
self.extend_logprob_start_len = 0
600600
self.is_chunked = 0
601601
self.req_pool_idx = None
602+
self.already_computed = 0
602603

603604
def __repr__(self):
604605
return (
@@ -960,8 +961,6 @@ def prepare_for_extend(self):
960961
# If req.input_embeds is already a list, append its content directly
961962
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
962963

963-
if req.is_retracted:
964-
req.already_computed = 0
965964
req.cached_tokens += pre_len - req.already_computed
966965
req.already_computed = seq_len
967966
req.is_retracted = False
@@ -1189,7 +1188,11 @@ def get_required_tokens(num_reqs: int):
11891188
self.req_to_token_pool.free(req.req_pool_idx)
11901189
else:
11911190
# TODO: apply more fine-grained retraction
1192-
last_uncached_pos = len(req.prefix_indices)
1191+
last_uncached_pos = (
1192+
(len(req.prefix_indices) + server_args.page_size - 1)
1193+
// server_args.page_size
1194+
* server_args.page_size
1195+
)
11931196
token_indices = self.req_to_token_pool.req_to_token[
11941197
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
11951198
]

python/sglang/srt/metrics/collector.py

Lines changed: 23 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class SchedulerMetricsCollector:
3333

3434
def __init__(self, labels: Dict[str, str]) -> None:
3535
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
36-
from prometheus_client import Gauge
36+
from prometheus_client import Gauge, Histogram
3737

3838
self.labels = labels
3939
self.last_log_time = time.time()
@@ -139,10 +139,10 @@ def __init__(self, labels: Dict[str, str]) -> None:
139139
labelnames=labels.keys(),
140140
buckets=[
141141
0.1,
142-
0.3,
143-
0.5,
144-
0.7,
145-
0.9,
142+
0.2,
143+
0.4,
144+
0.6,
145+
0.8,
146146
1,
147147
2,
148148
4,
@@ -153,36 +153,9 @@ def __init__(self, labels: Dict[str, str]) -> None:
153153
40,
154154
60,
155155
80,
156-
120,
157-
160,
158-
],
159-
)
160-
161-
self.histogram_time_per_output_token = Histogram(
162-
name="sglang:time_per_output_token_seconds",
163-
documentation="Histogram of time per output token in seconds.",
164-
labelnames=labels.keys(),
165-
buckets=[
166-
0.002,
167-
0.005,
168-
0.010,
169-
0.020,
170-
0.030,
171-
0.040,
172-
0.050,
173-
0.060,
174-
0.070,
175-
0.080,
176-
0.090,
177-
0.100,
178-
0.150,
179-
0.200,
180-
0.300,
181-
0.400,
182-
0.600,
183-
0.800,
184-
1.000,
185-
2.000,
156+
100,
157+
200,
158+
400,
186159
],
187160
)
188161

@@ -202,17 +175,18 @@ def __init__(self, labels: Dict[str, str]) -> None:
202175
0.030,
203176
0.035,
204177
0.040,
205-
0.050,
206-
0.075,
178+
0.060,
179+
0.080,
207180
0.100,
208-
0.150,
209181
0.200,
210-
0.300,
211182
0.400,
212-
0.500,
213-
0.750,
183+
0.600,
184+
0.800,
214185
1.000,
215186
2.000,
187+
4.000,
188+
6.000,
189+
8.000,
216190
],
217191
)
218192

@@ -224,23 +198,22 @@ def __init__(self, labels: Dict[str, str]) -> None:
224198
0.1,
225199
0.2,
226200
0.4,
201+
0.6,
227202
0.8,
228203
1,
229204
2,
230-
5,
205+
4,
206+
6,
207+
8,
231208
10,
232209
20,
233210
40,
234211
60,
235212
80,
236213
100,
237-
150,
238214
200,
239-
250,
240-
300,
241-
350,
242-
500,
243-
1000,
215+
400,
216+
800,
244217
],
245218
)
246219

@@ -256,13 +229,10 @@ def observe_one_finished_request(
256229
):
257230
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
258231
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
259-
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
232+
if cached_tokens > 0:
233+
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
260234
self.num_requests_total.labels(**self.labels).inc(1)
261235
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
262-
if generation_tokens >= 1:
263-
self.histogram_time_per_output_token.labels(**self.labels).observe(
264-
e2e_latency / generation_tokens
265-
)
266236

267237
def observe_time_to_first_token(self, value: float):
268238
self.histogram_time_to_first_token.labels(**self.labels).observe(value)

python/sglang/srt/server_args.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ class ServerArgs:
128128
# Kernel backend
129129
attention_backend: Optional[str] = None
130130
sampling_backend: Optional[str] = None
131-
grammar_backend: Optional[str] = "xgrammar"
131+
grammar_backend: Optional[str] = None
132132

133133
# Speculative decoding
134134
speculative_algorithm: Optional[str] = None
@@ -193,6 +193,13 @@ class ServerArgs:
193193
disaggregation_bootstrap_port: int = 8998
194194

195195
def __post_init__(self):
196+
# Expert parallelism
197+
if self.enable_ep_moe:
198+
self.ep_size = self.tp_size
199+
logger.info(
200+
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
201+
)
202+
196203
# Set missing default values
197204
if self.tokenizer_path is None:
198205
self.tokenizer_path = self.model_path
@@ -274,12 +281,9 @@ def __post_init__(self):
274281
)
275282
self.disable_cuda_graph = True
276283

277-
# Expert parallelism
278-
if self.enable_ep_moe:
279-
self.ep_size = self.tp_size
280-
logger.info(
281-
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
282-
)
284+
# Choose grammar backend
285+
if self.grammar_backend is None:
286+
self.grammar_backend = "xgrammar"
283287

284288
# Data parallelism attention
285289
if self.enable_dp_attention:
@@ -813,7 +817,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
813817
parser.add_argument(
814818
"--grammar-backend",
815819
type=str,
816-
choices=["xgrammar", "outlines", "llguidance"],
820+
choices=["xgrammar", "outlines", "llguidance", "none"],
817821
default=ServerArgs.grammar_backend,
818822
help="Choose the backend for grammar-guided decoding.",
819823
)

python/sglang/test/test_utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,9 +1012,6 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
10121012

10131013

10141014
class CustomTestCase(unittest.TestCase):
1015-
pass
1016-
1017-
"""
10181015
def _callTestMethod(self, method):
10191016
max_retry = int(
10201017
os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
@@ -1023,4 +1020,3 @@ def _callTestMethod(self, method):
10231020
lambda: super(CustomTestCase, self)._callTestMethod(method),
10241021
max_retry=max_retry,
10251022
)
1026-
"""

test/srt/models/lora/test_lora_tp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
],
3434
max_loras_per_batch=1,
3535
),
36+
]
37+
38+
ALL_OTHER_LORA_MODELS = [
3639
LoRAModelCase(
3740
base="meta-llama/Llama-3.1-8B-Instruct",
3841
adaptors=[
@@ -43,9 +46,6 @@
4346
],
4447
max_loras_per_batch=1,
4548
),
46-
]
47-
48-
ALL_OTHER_LORA_MODELS = [
4949
LoRAModelCase(
5050
base="meta-llama/Llama-2-7b-hf",
5151
adaptors=[LoRAAdaptor(name="winddude/wizardLM-LlaMA-LoRA-7B")],

test/srt/run_suite.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class TestFile:
1616
TestFile("models/lora/test_lora.py", 76),
1717
TestFile("models/lora/test_lora_backend.py", 420),
1818
TestFile("models/lora/test_multi_lora_backend.py", 144),
19-
TestFile("models/test_embedding_models.py", 119),
19+
TestFile("models/test_embedding_models.py", 35),
2020
TestFile("models/test_generation_models.py", 103),
2121
TestFile("models/test_grok_models.py", 60),
2222
TestFile("models/test_qwen_models.py", 82),
@@ -38,7 +38,7 @@ class TestFile:
3838
TestFile("test_metrics.py", 32),
3939
TestFile("test_mla.py", 92),
4040
TestFile("test_mla_deepseek_v3.py", 221),
41-
TestFile("test_mla_int8_deepseek_v3.py", 421),
41+
TestFile("test_mla_int8_deepseek_v3.py", 522),
4242
TestFile("test_mla_flashinfer.py", 395),
4343
TestFile("test_mla_fp8.py", 93),
4444
TestFile("test_no_chunked_prefill.py", 126),
@@ -59,7 +59,7 @@ class TestFile:
5959
TestFile("test_srt_endpoint.py", 94),
6060
TestFile("test_torch_compile.py", 76),
6161
TestFile("test_torch_compile_moe.py", 85),
62-
TestFile("test_torch_native_attention_backend.py", 149),
62+
TestFile("test_torch_native_attention_backend.py", 123),
6363
TestFile("test_torchao.py", 70),
6464
TestFile("test_triton_attention_kernels.py", 4),
6565
TestFile("test_triton_attention_backend.py", 134),
@@ -76,6 +76,16 @@ class TestFile:
7676
TestFile("test_hicache.py", 60),
7777
TestFile("test_hicache_mla.py", 90),
7878
],
79+
"per-commit-2-gpu": [
80+
TestFile("test_data_parallelism.py", 90),
81+
TestFile("test_dp_attention.py", 90),
82+
TestFile("test_update_weights_from_distributed.py", 100),
83+
TestFile("test_verl_engine.py", 100),
84+
TestFile("test_patch_torch.py", 30),
85+
TestFile("test_moe_ep.py", 220),
86+
TestFile("test_mla_tp.py", 420),
87+
TestFile("test_lora_tp.py", 300),
88+
],
7989
"nightly": [
8090
TestFile("test_nightly_gsm8k_eval.py"),
8191
],

test/srt/test_dp_attention.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,7 @@ def test_mgsm_en(self):
6060
metrics = run_eval(args)
6161
print(f"{metrics=}")
6262
self.assertGreater(metrics["score"], 0.8)
63+
64+
65+
if __name__ == "__main__":
66+
unittest.main()

test/srt/test_metrics.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def test_metrics_enabled(self):
6363
"sglang:cached_tokens_total",
6464
"sglang:num_requests_total",
6565
"sglang:time_to_first_token_seconds",
66-
"sglang:time_per_output_token_seconds",
6766
"sglang:inter_token_latency_seconds",
6867
"sglang:e2e_request_latency_seconds",
6968
]

0 commit comments

Comments
 (0)