3
3
from types import SimpleNamespace
4
4
5
5
import requests
6
- import torch
7
6
8
7
from sglang .srt .utils import get_device_sm , kill_process_tree
9
8
from sglang .test .few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
14
13
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN ,
15
14
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ,
16
15
DEFAULT_URL_FOR_TEST ,
16
+ CustomTestCase ,
17
17
popen_launch_server ,
18
18
)
19
19
47
47
# Default server arguments shared across all tests
48
48
DEFAULT_SERVER_ARGS = [
49
49
"--trust-remote-code" ,
50
- "--enable-torch-compile" ,
51
50
"--cuda-graph-max-bs" ,
52
- "2 " ,
51
+ "4 " ,
53
52
"--attention-backend" ,
54
53
"fa3" ,
55
54
]
60
59
61
60
62
61
@unittest .skipIf (get_device_sm () < 90 , "Test requires CUDA SM 90 or higher" )
63
- class BaseFlashAttentionTest (unittest . TestCase ):
62
+ class BaseFlashAttentionTest (CustomTestCase ):
64
63
"""Base class for testing FlashAttention3."""
65
64
66
65
model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -78,20 +77,22 @@ def get_server_args(cls):
78
77
def setUpClass (cls ):
79
78
# disable deep gemm precompile to make launch server faster
80
79
# please don't do this if you want to make your inference workload faster
81
- os .environ ["SGL_JIT_DEEPGEMM_PRECOMPILE" ] = "False"
80
+ os .environ ["SGL_JIT_DEEPGEMM_PRECOMPILE" ] = "false"
81
+ os .environ ["SGL_ENABLE_JIT_DEEPGEMM" ] = "false"
82
82
cls .process = popen_launch_server (
83
83
cls .model ,
84
84
cls .base_url ,
85
85
timeout = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ,
86
86
other_args = cls .get_server_args (),
87
- env = os .environ ,
88
87
)
89
88
90
89
@classmethod
91
90
def tearDownClass (cls ):
92
91
kill_process_tree (cls .process .pid )
93
92
94
93
def test_gsm8k (self ):
94
+ requests .get (self .base_url + "/flush_cache" )
95
+
95
96
args = SimpleNamespace (
96
97
num_shots = 4 ,
97
98
num_questions = 100 ,
@@ -102,7 +103,7 @@ def test_gsm8k(self):
102
103
data_path = GSM_DATASET_PATH ,
103
104
)
104
105
metrics = run_eval_few_shot_gsm8k (args )
105
- print (metrics )
106
+ print (f" { metrics = } " )
106
107
107
108
# Use the appropriate metric key based on the test class
108
109
metric_key = "accuracy"
@@ -192,60 +193,6 @@ def get_server_args(cls):
192
193
return args
193
194
194
195
195
- class TestFlashAttention3SpeculativeDecodeTopk (BaseFlashAttentionTest ):
196
- """Test FlashAttention3 with speculative decode enabled, topk > 1"""
197
-
198
- model = DEFAULT_MODEL_NAME_FOR_TEST
199
-
200
- @classmethod
201
- def get_server_args (cls ):
202
- args = super ().get_server_args ()
203
- args .extend (
204
- [
205
- "--cuda-graph-max-bs" ,
206
- "2" ,
207
- "--speculative-algorithm" ,
208
- "EAGLE3" ,
209
- "--speculative-draft" ,
210
- DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 ,
211
- "--speculative-num-steps" ,
212
- "5" ,
213
- "--speculative-eagle-topk" ,
214
- "4" ,
215
- "--speculative-num-draft-tokens" ,
216
- "8" ,
217
- "--dtype" ,
218
- "float16" ,
219
- ]
220
- )
221
- return args
222
-
223
- def test_gsm8k (self ):
224
- """
225
- Override the test_gsm8k to further test for average speculative accept length.
226
- """
227
- requests .get (self .base_url + "/flush_cache" )
228
-
229
- args = SimpleNamespace (
230
- num_shots = 5 ,
231
- data_path = GSM_DATASET_PATH ,
232
- num_questions = 200 ,
233
- max_new_tokens = 512 ,
234
- parallel = 128 ,
235
- host = "http://127.0.0.1" ,
236
- port = int (self .base_url .split (":" )[- 1 ]),
237
- )
238
- metrics = run_eval_few_shot_gsm8k (args )
239
- print (metrics )
240
-
241
- self .assertGreater (metrics ["accuracy" ], 0.60 )
242
-
243
- server_info = requests .get (self .base_url + "/get_server_info" )
244
- avg_spec_accept_length = server_info .json ()["avg_spec_accept_length" ]
245
- print (f"{ avg_spec_accept_length = } " )
246
- self .assertGreater (avg_spec_accept_length , 1.8 )
247
-
248
-
249
196
class TestFlashAttention3MLASpeculativeDecode (BaseFlashAttentionTest ):
250
197
"""Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model"""
251
198
0 commit comments