From a534d7de44a043ef1846cf92f6ee02c795e8a636 Mon Sep 17 00:00:00 2001 From: jokerwyt <914554688@qq.com> Date: Thu, 17 Apr 2025 06:15:30 +0000 Subject: [PATCH 1/3] Support PD bootstrap fields on /v1/chat/completions endpoint --- python/sglang/srt/disaggregation/mini_lb.py | 38 +++++++++++++++++---- python/sglang/srt/openai_api/adapter.py | 2 ++ python/sglang/srt/openai_api/protocol.py | 4 +++ 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py index d90277a774c..9832444f03b 100644 --- a/python/sglang/srt/disaggregation/mini_lb.py +++ b/python/sglang/srt/disaggregation/mini_lb.py @@ -23,13 +23,14 @@ def select_pair(self): return random.choice(self.prefill_servers), random.choice(self.decode_servers) async def generate( - self, modified_request, prefill_server, decode_server + self, modified_request, prefill_server, decode_server, endpoint ) -> ORJSONResponse: + assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" async with aiohttp.ClientSession() as session: tasks = [ - session.post(f"{prefill_server}/generate", json=modified_request), - session.post(f"{decode_server}/generate", json=modified_request), + session.post(f"{prefill_server}/{endpoint}", json=modified_request), + session.post(f"{decode_server}/{endpoint}", json=modified_request), ] # Wait for both responses to complete. Prefill should end first. prefill_response, decode_response = await asyncio.gather(*tasks) @@ -39,7 +40,9 @@ async def generate( status_code=decode_response.status, ) - async def generate_stream(self, modified_request, prefill_server, decode_server): + async def generate_stream(self, modified_request, prefill_server, decode_server, endpoint = "generate"): + assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" + async def stream_results(): async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout( @@ -50,10 +53,10 @@ async def stream_results(): # Create the tasks for both prefill and decode requests tasks = [ session.post( - f"{prefill_server}/generate", json=modified_request + f"{prefill_server}/{endpoint}", json=modified_request ), session.post( - f"{decode_server}/generate", json=modified_request + f"{decode_server}/{endpoint}", json=modified_request ), ] # Wait for both responses to complete. Since this is streaming, they return immediately. @@ -173,6 +176,29 @@ async def handle_generate_request(request_data: dict): modified_request, prefill_server, decode_server ) +@app.post("/v1/chat/completions") +async def handle_completion_request(request_data: dict): + prefill_server, decode_server = load_balancer.select_pair() + + # Parse and transform prefill_server for bootstrap data + parsed_url = urllib.parse.urlparse(prefill_server) + hostname = parsed_url.hostname + modified_request = request_data.copy() + modified_request.update( + { + "bootstrap_host": hostname, + "bootstrap_room": random.randint(0, 2**63 - 1), + } + ) + + if request_data.get("stream", False): + return await load_balancer.request_stream( + modified_request, prefill_server, decode_server, endpoint="v1/chat/completions" + ) + else: + return await load_balancer.request( + modified_request, prefill_server, decode_server, endpoint="v1/chat/completions" + ) @app.get("/v1/models") async def get_models(): diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 8f74007fdaf..5e86e8f9def 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -1109,6 +1109,8 @@ def v1_chat_generate_request( rid=request_ids, modalities=modalities_list, lora_path=lora_paths, + bootstrap_host=all_requests[0].bootstrap_host, + bootstrap_room=all_requests[0].bootstrap_room, ) return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 4318afea624..db0420b9533 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -361,6 +361,10 @@ def set_tool_choice_default(cls, values): separate_reasoning: bool = True stream_reasoning: bool = True + # For PD disaggregation + bootstrap_host: Optional[str] = None + bootstrap_room: Optional[int] = None + class FunctionResponse(BaseModel): """Function response.""" From 34b40bbafa5f97dd7525a9772fd1f900766bf7c9 Mon Sep 17 00:00:00 2001 From: jokerwyt <914554688@qq.com> Date: Fri, 18 Apr 2025 11:36:15 +0000 Subject: [PATCH 2/3] fix method name bug on /v1/chat/completions --- python/sglang/srt/disaggregation/mini_lb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py index 9832444f03b..afc924492b2 100644 --- a/python/sglang/srt/disaggregation/mini_lb.py +++ b/python/sglang/srt/disaggregation/mini_lb.py @@ -192,11 +192,11 @@ async def handle_completion_request(request_data: dict): ) if request_data.get("stream", False): - return await load_balancer.request_stream( + return await load_balancer.generate_stream( modified_request, prefill_server, decode_server, endpoint="v1/chat/completions" ) else: - return await load_balancer.request( + return await load_balancer.generate( modified_request, prefill_server, decode_server, endpoint="v1/chat/completions" ) From 9a122d25fa3792064c796237ce921dedb80ce2a6 Mon Sep 17 00:00:00 2001 From: jokerwyt <914554688@qq.com> Date: Mon, 21 Apr 2025 01:46:23 +0000 Subject: [PATCH 3/3] lint --- python/sglang/srt/disaggregation/mini_lb.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py index afc924492b2..2ac94faa181 100644 --- a/python/sglang/srt/disaggregation/mini_lb.py +++ b/python/sglang/srt/disaggregation/mini_lb.py @@ -40,7 +40,9 @@ async def generate( status_code=decode_response.status, ) - async def generate_stream(self, modified_request, prefill_server, decode_server, endpoint = "generate"): + async def generate_stream( + self, modified_request, prefill_server, decode_server, endpoint="generate" + ): assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" async def stream_results(): @@ -176,6 +178,7 @@ async def handle_generate_request(request_data: dict): modified_request, prefill_server, decode_server ) + @app.post("/v1/chat/completions") async def handle_completion_request(request_data: dict): prefill_server, decode_server = load_balancer.select_pair() @@ -193,13 +196,20 @@ async def handle_completion_request(request_data: dict): if request_data.get("stream", False): return await load_balancer.generate_stream( - modified_request, prefill_server, decode_server, endpoint="v1/chat/completions" + modified_request, + prefill_server, + decode_server, + endpoint="v1/chat/completions", ) else: return await load_balancer.generate( - modified_request, prefill_server, decode_server, endpoint="v1/chat/completions" + modified_request, + prefill_server, + decode_server, + endpoint="v1/chat/completions", ) + @app.get("/v1/models") async def get_models(): prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server