Skip to content

Commit 4d23ba0

Browse files
authored
Simplify FA3 tests (sgl-project#5779)
1 parent 6e313c1 commit 4d23ba0

File tree

3 files changed

+14
-67
lines changed

3 files changed

+14
-67
lines changed

test/srt/run_suite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class TestFile:
3030
TestFile("test_chunked_prefill.py", 336),
3131
TestFile("test_eagle_infer.py", 500),
3232
TestFile("test_ebnf_constrained.py"),
33-
TestFile("test_fa3.py", 500),
33+
TestFile("test_fa3.py", 400),
3434
TestFile("test_fp8_kernel.py", 8),
3535
TestFile("test_embedding_openai_server.py", 36),
3636
TestFile("test_hidden_states.py", 55),
@@ -92,7 +92,7 @@ class TestFile:
9292
TestFile("test_verl_engine.py", 100),
9393
],
9494
"per-commit-8-gpu": [
95-
TestFile("test_local_attn.py", 100),
95+
TestFile("test_local_attn.py", 250),
9696
],
9797
"nightly": [
9898
TestFile("test_nightly_gsm8k_eval.py"),

test/srt/test_fa3.py

Lines changed: 8 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from types import SimpleNamespace
44

55
import requests
6-
import torch
76

87
from sglang.srt.utils import get_device_sm, kill_process_tree
98
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
@@ -14,6 +13,7 @@
1413
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
1514
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
1615
DEFAULT_URL_FOR_TEST,
16+
CustomTestCase,
1717
popen_launch_server,
1818
)
1919

@@ -47,9 +47,8 @@
4747
# Default server arguments shared across all tests
4848
DEFAULT_SERVER_ARGS = [
4949
"--trust-remote-code",
50-
"--enable-torch-compile",
5150
"--cuda-graph-max-bs",
52-
"2",
51+
"4",
5352
"--attention-backend",
5453
"fa3",
5554
]
@@ -60,7 +59,7 @@
6059

6160

6261
@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
63-
class BaseFlashAttentionTest(unittest.TestCase):
62+
class BaseFlashAttentionTest(CustomTestCase):
6463
"""Base class for testing FlashAttention3."""
6564

6665
model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -78,20 +77,22 @@ def get_server_args(cls):
7877
def setUpClass(cls):
7978
# disable deep gemm precompile to make launch server faster
8079
# please don't do this if you want to make your inference workload faster
81-
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "False"
80+
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
81+
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
8282
cls.process = popen_launch_server(
8383
cls.model,
8484
cls.base_url,
8585
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
8686
other_args=cls.get_server_args(),
87-
env=os.environ,
8887
)
8988

9089
@classmethod
9190
def tearDownClass(cls):
9291
kill_process_tree(cls.process.pid)
9392

9493
def test_gsm8k(self):
94+
requests.get(self.base_url + "/flush_cache")
95+
9596
args = SimpleNamespace(
9697
num_shots=4,
9798
num_questions=100,
@@ -102,7 +103,7 @@ def test_gsm8k(self):
102103
data_path=GSM_DATASET_PATH,
103104
)
104105
metrics = run_eval_few_shot_gsm8k(args)
105-
print(metrics)
106+
print(f"{metrics=}")
106107

107108
# Use the appropriate metric key based on the test class
108109
metric_key = "accuracy"
@@ -192,60 +193,6 @@ def get_server_args(cls):
192193
return args
193194

194195

195-
class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
196-
"""Test FlashAttention3 with speculative decode enabled, topk > 1"""
197-
198-
model = DEFAULT_MODEL_NAME_FOR_TEST
199-
200-
@classmethod
201-
def get_server_args(cls):
202-
args = super().get_server_args()
203-
args.extend(
204-
[
205-
"--cuda-graph-max-bs",
206-
"2",
207-
"--speculative-algorithm",
208-
"EAGLE3",
209-
"--speculative-draft",
210-
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
211-
"--speculative-num-steps",
212-
"5",
213-
"--speculative-eagle-topk",
214-
"4",
215-
"--speculative-num-draft-tokens",
216-
"8",
217-
"--dtype",
218-
"float16",
219-
]
220-
)
221-
return args
222-
223-
def test_gsm8k(self):
224-
"""
225-
Override the test_gsm8k to further test for average speculative accept length.
226-
"""
227-
requests.get(self.base_url + "/flush_cache")
228-
229-
args = SimpleNamespace(
230-
num_shots=5,
231-
data_path=GSM_DATASET_PATH,
232-
num_questions=200,
233-
max_new_tokens=512,
234-
parallel=128,
235-
host="http://127.0.0.1",
236-
port=int(self.base_url.split(":")[-1]),
237-
)
238-
metrics = run_eval_few_shot_gsm8k(args)
239-
print(metrics)
240-
241-
self.assertGreater(metrics["accuracy"], 0.60)
242-
243-
server_info = requests.get(self.base_url + "/get_server_info")
244-
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
245-
print(f"{avg_spec_accept_length=}")
246-
self.assertGreater(avg_spec_accept_length, 1.8)
247-
248-
249196
class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
250197
"""Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model"""
251198

test/srt/test_local_attn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,20 @@
1010
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION,
1111
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
1212
DEFAULT_URL_FOR_TEST,
13+
CustomTestCase,
1314
popen_launch_server,
1415
)
1516

1617

1718
@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
18-
class TestFlashAttention3LocalAttn(unittest.TestCase):
19+
class TestFlashAttention3LocalAttn(CustomTestCase):
1920
model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
2021
base_url = DEFAULT_URL_FOR_TEST
2122
accuracy_threshold = 0.90
2223

2324
@classmethod
2425
def get_server_args(cls):
2526
return [
26-
"--trust-remote-code",
2727
"--cuda-graph-max-bs",
2828
"2",
2929
"--attention-backend",
@@ -36,8 +36,6 @@ def get_server_args(cls):
3636

3737
@classmethod
3838
def setUpClass(cls):
39-
# disable deep gemm precompile to make launch server faster
40-
# please don't do this if you want to make your inference workload faster
4139
cls.process = popen_launch_server(
4240
cls.model,
4341
cls.base_url,
@@ -51,6 +49,8 @@ def tearDownClass(cls):
5149
kill_process_tree(cls.process.pid)
5250

5351
def test_gsm8k(self):
52+
requests.get(self.base_url + "/flush_cache")
53+
5454
args = SimpleNamespace(
5555
num_shots=4,
5656
num_questions=100,

0 commit comments

Comments
 (0)