Skip to content

Commit 9ecb185

Browse files
authored
Fix triton sliding window test case (sgl-project#6981)
1 parent cc74499 commit 9ecb185

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

test/srt/test_triton_sliding_window.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import time
21
import unittest
32
from types import SimpleNamespace
43

@@ -10,6 +9,7 @@
109
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
1110
DEFAULT_URL_FOR_TEST,
1211
CustomTestCase,
12+
is_in_ci,
1313
popen_launch_server,
1414
)
1515

@@ -45,10 +45,6 @@ def setUpClass(cls):
4545
)
4646
cls.long_context_prompt += "\nNow, summarize the story in one sentence:"
4747

48-
@classmethod
49-
def tearDownClass(cls):
50-
pass
51-
5248
def _test_mmlu(self):
5349
args = SimpleNamespace(
5450
base_url=self.base_url,
@@ -61,7 +57,7 @@ def _test_mmlu(self):
6157
metrics = run_eval(args)
6258
print(f"MMLU metrics with sliding window: {metrics}")
6359

64-
self.assertGreaterEqual(metrics["score"], 0.61)
60+
self.assertGreaterEqual(metrics["score"], 0.60)
6561

6662
def _test_short_context_generation(self):
6763
response = requests.post(
@@ -97,6 +93,7 @@ def _test_long_context_generation(self):
9793
self.assertGreater(len(result["text"].strip()), 0)
9894
print(f"Long context generation result: {result['text'][:100]}...")
9995

96+
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
10097
def test_no_cuda_graph(self):
10198
self.no_cuda_graph_process = popen_launch_server(
10299
self.model,
@@ -105,12 +102,12 @@ def test_no_cuda_graph(self):
105102
other_args=self.common_args + ["--disable-cuda-graph"],
106103
)
107104

108-
self._test_short_context_generation()
109-
self._test_long_context_generation()
110-
self._test_mmlu()
111-
112-
kill_process_tree(self.no_cuda_graph_process.pid)
113-
time.sleep(5)
105+
try:
106+
self._test_short_context_generation()
107+
self._test_long_context_generation()
108+
self._test_mmlu()
109+
finally:
110+
kill_process_tree(self.no_cuda_graph_process.pid)
114111

115112
def test_cuda_graph(self):
116113
self.cuda_graph_process = popen_launch_server(
@@ -120,12 +117,12 @@ def test_cuda_graph(self):
120117
other_args=self.common_args,
121118
)
122119

123-
self._test_short_context_generation()
124-
self._test_long_context_generation()
125-
self._test_mmlu()
126-
127-
kill_process_tree(self.cuda_graph_process.pid)
128-
time.sleep(5)
120+
try:
121+
self._test_short_context_generation()
122+
self._test_long_context_generation()
123+
self._test_mmlu()
124+
finally:
125+
kill_process_tree(self.cuda_graph_process.pid)
129126

130127

131128
if __name__ == "__main__":

0 commit comments

Comments
 (0)