Skip to content

Commit bd4d81f

Browse files
ShangmingCai authored and lifuhuang committed
[PD] Add simple unit test for disaggregation feature (sgl-project#5654)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
1 parent 9396732 commit bd4d81f

File tree

5 files changed

+241
-0
lines changed

5 files changed

+241
-0
lines changed

.github/workflows/pr-test.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,24 @@ jobs:
252252
cd test/srt
253253
python3 test_moe_eval_accuracy_large.py
254254
255+
unit-test-backend-pd:
256+
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
257+
github.event.pull_request.draft == false
258+
runs-on: 8-gpu-runner
259+
steps:
260+
- name: Checkout code
261+
uses: actions/checkout@v4
262+
263+
- name: Install dependencies
264+
run: |
265+
bash scripts/ci_install_dependency_8_gpu.sh
266+
267+
- name: Run test
268+
timeout-minutes: 10
269+
run: |
270+
cd test/srt
271+
python3 -m unittest test_disaggregation.TestDisaggregationMooncake.test_gsm8k
272+
255273
large-scale-test-8-gpu:
256274
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
257275
github.event.pull_request.draft == false

python/sglang/test/test_utils.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,81 @@ def popen_launch_server(
478478
raise TimeoutError("Server failed to start within the timeout period.")
479479

480480

481+
def popen_launch_pd_server(
    model: str,
    base_url: str,
    timeout: float,
    api_key: Optional[str] = None,
    other_args: list[str] = (),
    env: Optional[dict] = None,
    return_stdout_stderr: Optional[tuple] = None,
):
    """Launch ``sglang.launch_server`` as a subprocess and wait until it is healthy.

    The server is considered ready once ``GET {base_url}/health`` answers
    with HTTP 200.

    Args:
        model: Value passed to ``--model-path``.
        base_url: URL of the form ``scheme://host:port``; host and port are
            extracted from it and passed as ``--host`` / ``--port``.
        timeout: Maximum number of seconds to wait for the health check.
        api_key: If truthy, passed as ``--api-key`` and sent as a bearer
            token on the health-check requests.
        other_args: Extra command-line arguments; each is converted with
            ``str`` and placed before ``--host`` / ``--port``.
        env: Environment for the child process (``None`` inherits).
        return_stdout_stderr: Optional ``(stdout, stderr)`` pair handed to
            ``subprocess.Popen``; when given, the child runs in text mode.

    Returns:
        The ``subprocess.Popen`` handle of the running server.

    Raises:
        Exception: If the server process exits before becoming healthy.
        TimeoutError: If the server is not healthy within ``timeout``
            seconds; the process tree is killed first.
    """
    # Upper bound for one health probe. Without it a connection that is
    # accepted but never answered would block forever and defeat ``timeout``.
    health_request_timeout = 10

    # "http://127.0.0.1:30000" -> ("http", "//127.0.0.1", "30000")
    _, host, port = base_url.split(":")
    host = host[2:]  # drop the leading "//"

    command = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        *[str(x) for x in other_args],
        "--host",
        host,
        "--port",
        port,
    ]

    if api_key:
        command += ["--api-key", api_key]

    print(f"command={' '.join(command)}")

    if return_stdout_stderr:
        process = subprocess.Popen(
            command,
            stdout=return_stdout_stderr[0],
            stderr=return_stdout_stderr[1],
            env=env,
            text=True,
        )
    else:
        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)

    headers = {"Content-Type": "application/json; charset=utf-8"}
    if api_key:
        # Only authenticate when a key was configured; otherwise the header
        # would carry the literal string "Bearer None".
        headers["Authorization"] = f"Bearer {api_key}"

    start_time = time.time()
    with requests.Session() as session:
        while time.time() - start_time < timeout:
            try:
                response = session.get(
                    f"{base_url}/health",
                    headers=headers,
                    timeout=health_request_timeout,
                )
                if response.status_code == 200:
                    return process
            except requests.RequestException:
                # Server not accepting connections yet (or the probe timed
                # out); keep polling until the overall deadline.
                pass

            # Fail fast if the child died instead of waiting out the timeout.
            return_code = process.poll()
            if return_code is not None:
                raise Exception(f"Server unexpectedly exits ({return_code=}).")

            time.sleep(10)

    kill_process_tree(process.pid)
    raise TimeoutError("Server failed to start within the timeout period.")
554+
555+
481556
def run_with_timeout(
482557
func: Callable,
483558
args: tuple = (),

scripts/ci_install_dependency_8_gpu.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ pip install -e lmms-eval/
5353
# Install FlashMLA for attention backend tests
5454
pip install git+https://github.com/deepseek-ai/FlashMLA.git
5555

56+
# Install mooncake-transfer-engine
57+
pip install mooncake-transfer-engine
58+
5659
# Install system dependencies
5760
# apt-get update && apt-get install -y libibverbs-dev infiniband-diags libmlx5-1 rdma-core openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 rdma-core-dev infiniband-diags-dev libibverbs-dev libibverbs-utils librdmacm-dev librdmacm-utils ibverbs-utils rdma-core-utils
5861
apt install curl wget git sudo libibverbs-dev -y

test/srt/run_suite.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ class TestFile:
8585
TestFile("test_w8a8_quantization.py", 46),
8686
TestFile("models/lora/test_lora_cuda_graph.py", 250),
8787
],
88+
"per-commit-pd": [
89+
TestFile("test_disaggregation.py", 90),
90+
],
8891
"per-commit-2-gpu": [
8992
TestFile("models/lora/test_lora_tp.py", 116),
9093
TestFile("test_data_parallelism.py", 73),

test/srt/test_disaggregation.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import subprocess
2+
import threading
3+
import time
4+
import unittest
5+
from types import SimpleNamespace
6+
7+
import requests
8+
import torch
9+
10+
from sglang.srt.utils import kill_process_tree
11+
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
12+
from sglang.test.test_utils import (
13+
DEFAULT_MODEL_NAME_FOR_TEST,
14+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
15+
DEFAULT_URL_FOR_TEST,
16+
CustomTestCase,
17+
popen_launch_pd_server,
18+
run_with_timeout,
19+
)
20+
21+
22+
class TestDisaggregationMooncake(CustomTestCase):
    """End-to-end check of prefill/decode disaggregation behind a load balancer.

    ``setUpClass`` launches one prefill server, one decode server and the
    ``sglang.srt.disaggregation.mini_lb`` load balancer; ``test_gsm8k`` then
    runs the few-shot GSM8K evaluation against the load balancer.
    """

    # Process handles default to None so that tearDownClass is safe to call
    # even when setUpClass failed before all of them were started.
    process_lb = None
    process_decode = None
    process_prefill = None

    @classmethod
    def setUpClass(cls):
        """Start the prefill server, the decode server and the load balancer."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_host = "127.0.0.1"
        cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
        cls.lb_url = DEFAULT_URL_FOR_TEST
        # Prefill and decode listen on fixed offsets from the base port.
        cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
        cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"

        try:
            run_with_timeout(
                cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
            )
            run_with_timeout(
                cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
            )

            cls.wait_server_ready(cls.prefill_url + "/health")
            cls.wait_server_ready(cls.decode_url + "/health")

            lb_command = [
                "python3",
                "-m",
                "sglang.srt.disaggregation.mini_lb",
                "--prefill",
                cls.prefill_url,
                "--decode",
                cls.decode_url,
                "--host",
                cls.base_host,
                "--port",
                str(cls.base_port),
            ]

            print("Starting load balancer:", " ".join(lb_command))
            # Let the load balancer inherit stdout/stderr. Capturing them
            # with PIPE without ever reading would block the child once the
            # OS pipe buffer fills up.
            cls.process_lb = subprocess.Popen(lb_command)
            cls.wait_server_ready(cls.lb_url + "/health")
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises,
            # so release whatever was already started before re-raising.
            cls.tearDownClass()
            raise

    @classmethod
    def start_prefill(cls):
        """Launch the prefill-mode server (``--tp 4``) on ``base_port + 100``."""
        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--host",
            cls.base_host,
            "--port",
            str(cls.base_port + 100),
            "--tp",
            "4",
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode-mode server (``--tp 4``) on ``base_port + 200``."""
        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--host",
            cls.base_host,
            "--port",
            str(cls.base_port + 200),
            "--tp",
            "4",
            # Presumably places decode on GPUs 4-7 so it does not overlap
            # with the prefill server -- confirm against the launcher.
            "--base-gpu-id",
            "4",
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    @classmethod
    def wait_server_ready(cls, url, timeout=60):
        """Poll ``url`` once per second until it returns HTTP 200.

        Raises:
            RuntimeError: If no 200 response arrives within ``timeout`` seconds.
        """
        start_time = time.time()
        while True:
            try:
                # Bound each probe so a stalled connection cannot outlive
                # the overall deadline indefinitely.
                response = requests.get(url, timeout=5)
                if response.status_code == 200:
                    print(f"Server {url} is ready")
                    return
            except requests.RequestException:
                # Not reachable yet; retry until the deadline.
                pass

            if time.time() - start_time > timeout:
                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
            time.sleep(1)

    @classmethod
    def tearDownClass(cls):
        """Kill the load balancer, decode and prefill process trees (best effort)."""
        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
            if process is not None:
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    # Keep going so one failed kill does not leak the others.
                    print(f"Error killing process {process.pid}: {e}")

    def test_gsm8k(self):
        """Run few-shot GSM8K through the load balancer and check accuracy."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.lb_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Evaluation metrics: {metrics}")

        self.assertGreater(metrics["accuracy"], 0.62)
139+
140+
141+
if __name__ == "__main__":
142+
unittest.main()

0 commit comments

Comments
 (0)