From 1867646ef906545a668ac3e029e5bed428951806 Mon Sep 17 00:00:00 2001
From: ltd0924
Date: Thu, 14 Aug 2025 10:33:06 +0800
Subject: [PATCH 01/16] [BugFix] fix control signal release failed

---
 fastdeploy/entrypoints/openai/api_server.py   |  10 +-
 fastdeploy/entrypoints/openai/serving_chat.py | 103 ++++++++----------
 .../entrypoints/openai/serving_completion.py  |  16 +--
 fastdeploy/inter_communicator/zmq_client.py   |  14 ++-
 4 files changed, 71 insertions(+), 72 deletions(-)

diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 2a4c0e7aba..bf4f909532 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -171,10 +171,10 @@ async def connection_manager():
         await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001)
         yield
     except asyncio.TimeoutError:
-        api_server_logger.info(f"Reach max request release: {connection_semaphore.status()}")
-        if connection_semaphore.locked():
-            connection_semaphore.release()
-        raise HTTPException(status_code=429, detail="Too many requests")
+        api_server_logger.info(f"Reach max request concurrency, semaphore status: {connection_semaphore.status()}")
+        raise HTTPException(
+            status_code=429, detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
+        )


 # TODO: pass the real engine values; query status via pid
@@ -261,9 +261,11 @@ async def create_chat_completion(request: ChatCompletionRequest):
         inject_to_metadata(request)
         generator = await app.state.chat_handler.create_chat_completion(request)
         if isinstance(generator, ErrorResponse):
+            api_server_logger.debug(f"release: {connection_semaphore.status()}")
             connection_semaphore.release()
             return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
         elif isinstance(generator, ChatCompletionResponse):
+            api_server_logger.debug(f"release: {connection_semaphore.status()}")
             connection_semaphore.release()
             return JSONResponse(content=generator.model_dump())
         else:
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index b14f28e627..102b494f57 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -78,44 +78,45 @@ async def create_chat_completion(self, request: ChatCompletionRequest):
             api_server_logger.error(err_msg)
             return ErrorResponse(message=err_msg, code=400)

-        if request.user is not None:
-            request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}"
-        else:
-            request_id = f"chatcmpl-{uuid.uuid4()}"
-        api_server_logger.info(f"create chat completion request: {request_id}")
-        text_after_process = None
-        try:
-            current_req_dict = request.to_dict_for_infer(request_id)
-            current_req_dict["arrival_time"] = time.time()
-            prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
-            text_after_process = current_req_dict.get("text_after_process")
-            if isinstance(prompt_token_ids, np.ndarray):
-                prompt_token_ids = prompt_token_ids.tolist()
-        except Exception as e:
-            return ErrorResponse(code=400, message=str(e))
-
-        del current_req_dict
         try:
-            api_server_logger.debug(f"{self.engine_client.semaphore.status()}")
             if self.max_waiting_time < 0:
                 await self.engine_client.semaphore.acquire()
             else:
                 await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
-        except Exception:
-            return ErrorResponse(code=408, message=f"Request queued time exceeded {self.max_waiting_time}")
+            api_server_logger.info(f"current {self.engine_client.semaphore.status()}")
-        if request.stream:
-            return self.chat_completion_stream_generator(
-                request, request_id, request.model, prompt_token_ids, text_after_process
-            )
-        else:
+            if request.user is not None:
+                request_id = f"chatcmpl-{request.user}-{uuid.uuid4()}"
+            else:
+                request_id = f"chatcmpl-{uuid.uuid4()}"
+            api_server_logger.info(f"create chat completion request: {request_id}")
+            text_after_process = None
             try:
-                return await self.chat_completion_full_generator(
-                    request, request_id, request.model, prompt_token_ids, text_after_process
-                )
+                current_req_dict = request.to_dict_for_infer(request_id)
+                current_req_dict["arrival_time"] = time.time()
+                prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
+                text_after_process = current_req_dict.get("text_after_process")
+                if isinstance(prompt_token_ids, np.ndarray):
+                    prompt_token_ids = prompt_token_ids.tolist()
             except Exception as e:
                 return ErrorResponse(code=400, message=str(e))

+            del current_req_dict
+
+            if request.stream:
+                return self.chat_completion_stream_generator(
+                    request, request_id, request.model, prompt_token_ids, text_after_process
+                )
+            else:
+                try:
+                    return await self.chat_completion_full_generator(
+                        request, request_id, request.model, prompt_token_ids, text_after_process
+                    )
+                except Exception as e:
+                    return ErrorResponse(code=400, message=str(e))
+        except Exception:
+            return ErrorResponse(code=408, message=f"Request queued time exceeded {self.max_waiting_time}")
+
     def _create_streaming_error_response(self, message: str) -> str:
         error_response = ErrorResponse(
             code=400,
@@ -140,7 +141,6 @@ async def chat_completion_stream_generator(
         previous_num_tokens = 0
         num_prompt_tokens = 0
         num_choices = 1
-        tool_called = False
         max_streaming_response_tokens = (
             request.max_streaming_response_tokens
             if request.max_streaming_response_tokens is not None
@@ -239,34 +239,25 @@ async def chat_completion_stream_generator(
                             prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens),
                         )
                     yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n"
-                    api_server_logger.info(f"Chat Streaming response send_idx 0: {chunk.model_dump_json()}")
                     first_iteration = False

                 output = res["outputs"]
                 delta_text = output["text"]
                 output_top_logprobs = output["top_logprobs"]
-                previous_num_tokens += len(output["token_ids"])
                 logprobs_res: Optional[LogProbs] = None
                 if request.logprobs and output_top_logprobs is not None:
                     logprobs_res = self._create_chat_logprobs(
                         output_top_logprobs, request.logprobs, request.top_logprobs
                     )
-                if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
-                    tool_delta_message = output["tool_delta_message"]
-                    if tool_delta_message is None:
-                        continue
-                    delta_message = tool_delta_message
-                    delta_message.reasoning_content = output.get("reasoning_content")
-                    if delta_message.tool_calls:
-                        tool_called = True
-                else:
-                    delta_message = DeltaMessage(
-                        content=delta_text,
-                        reasoning_content=output.get("reasoning_content"),
-                        prompt_token_ids=None,
-                        completion_token_ids=None,
-                        tool_calls=None,
-                    )
+
+                previous_num_tokens += len(output["token_ids"])
+                delta_message = DeltaMessage(
+                    content=delta_text,
+                    reasoning_content=output.get("reasoning_content"),
+                    prompt_token_ids=None,
+                    completion_token_ids=None,
+                    tool_calls=output.get("tool_call_content", []),
+                )

                 choice = ChatCompletionResponseStreamChoice(
                     index=0,
                     delta=delta_message,
                     logprobs=logprobs_res,
                     arrival_time=arrival_time,
                 )
-
                 if res["finished"]:
                     num_choices -= 1
                     work_process_metrics.e2e_request_latency.observe(
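The restructuring in this hunk preserves one invariant that the series' "control signal release" fixes keep restoring: a request slot is acquired exactly once before any work starts, and released exactly once on every exit path. A minimal sketch of that pattern, with hypothetical names (`sem`, `handler`) standing in for the FastDeploy specifics:

```python
import asyncio


async def run_with_slot(sem: asyncio.Semaphore, handler, max_waiting_time: float):
    # Acquire a slot first; a timeout here becomes a 408/429-style rejection
    # *without* a matching release, because nothing was acquired yet.
    try:
        await asyncio.wait_for(sem.acquire(), timeout=max_waiting_time)
    except asyncio.TimeoutError:
        return {"code": 408, "message": f"queued longer than {max_waiting_time}s"}
    # From here on, every path (success, error, cancellation) must release
    # the same slot exactly once, hence the single finally block.
    try:
        return await handler()
    finally:
        sem.release()
```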
@@ -284,7 +274,10 @@ async def chat_completion_stream_generator(
                     max_tokens = request.max_completion_tokens or request.max_tokens
                     if has_no_token_limit or previous_num_tokens != max_tokens:
                         choice.finish_reason = "stop"
-                        if tool_called:
+                        if (
+                            self.engine_client.reasoning_parser == "ernie_x1"
+                            and output.get("finish_reason", "") == "tool_calls"
+                        ):
                             choice.finish_reason = "tool_calls"
                     else:
                         choice.finish_reason = "length"
@@ -306,9 +299,6 @@ async def chat_completion_stream_generator(
                 if len(choices) == max_streaming_response_tokens or res["finished"]:
                     chunk.choices = choices
                     yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
-                    # log the final chunk
-                    if res["finished"]:
-                        api_server_logger.info(f"Chat Streaming response last send: {chunk.model_dump_json()}")
                     choices = []

             if choices:
@@ -414,8 +404,9 @@ async def chat_completion_full_generator(
                 if task_is_finished:
                     break
         finally:
-            self.engine_client.semaphore.release()
             dealer.close()
+            self.engine_client.semaphore.release()
+            api_server_logger.info(f"release {self.engine_client.semaphore.status()}")

         choices = []
         output = final_res["outputs"]
@@ -423,7 +414,7 @@ async def chat_completion_full_generator(
             role="assistant",
             content=output["text"],
             reasoning_content=output.get("reasoning_content"),
-            tool_calls=output.get("tool_call"),
+            tool_calls=output.get("tool_call_content"),
             prompt_token_ids=prompt_token_ids if request.return_token_ids else None,
             completion_token_ids=completion_token_ids if request.return_token_ids else None,
             text_after_process=text_after_process if request.return_token_ids else None,
@@ -461,15 +452,13 @@ async def chat_completion_full_generator(
             prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=final_res.get("num_cached_tokens", 0)),
         )
         work_process_metrics.e2e_request_latency.observe(time.time() - final_res["metrics"]["request_start_time"])
-        res = ChatCompletionResponse(
+        return ChatCompletionResponse(
             id=request_id,
             created=created_time,
             model=model_name,
             choices=choices,
             usage=usage,
         )
-        api_server_logger.info(f"Chat response: {res.model_dump_json()}")
-        return res

     def _create_chat_logprobs(
         self,
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index a6aadcf060..64cc969039 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -101,6 +101,14 @@ async def create_completion(self, request: CompletionRequest):
         api_server_logger.info(f"start inference for request {num_choices}")
         prompt_batched_token_ids = []
         text_after_process_list = []
+        try:
+            if self.max_waiting_time < 0:
+                await self.engine_client.semaphore.acquire()
+            else:
+                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
+        except Exception:
+            return ErrorResponse(code=408, message=f"Request queued time exceeded {self.max_waiting_time}")
+
         try:
             for idx, prompt in enumerate(request_prompts):
                 request_id_idx = f"{request_id}-{idx}"

             del current_req_dict

-            try:
-                if self.max_waiting_time < 0:
-                    await self.engine_client.semaphore.acquire()
-                else:
-                    await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
-            except Exception:
-                return ErrorResponse(code=408, message=f"Request queued time exceeded {self.max_waiting_time}")
-
             if request.stream:
                 return self.completion_stream_generator(
                     request=request,
diff --git
a/fastdeploy/inter_communicator/zmq_client.py b/fastdeploy/inter_communicator/zmq_client.py index 05e55929dd..5bbfa33ba0 100644 --- a/fastdeploy/inter_communicator/zmq_client.py +++ b/fastdeploy/inter_communicator/zmq_client.py @@ -31,10 +31,10 @@ class ZmqClient: """ def __init__(self, name, mode): - self.context = zmq.Context() + self.context = zmq.Context(4) self.socket = self.context.socket(mode) self.file_name = f"/dev/shm/{name}.socket" - self.router_path = f"/dev/shm/router_{name}.ipc" + self.router_path = f"./router_{name}.ipc" self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) self.aggregate_send = envs.FD_USE_AGGREGATE_SEND @@ -67,6 +67,7 @@ def create_router(self): """ self.router = self.context.socket(zmq.ROUTER) self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM) + self.router.setsockopt(zmq.ROUTER_MANDATORY, 1) self.router.setsockopt(zmq.SNDTIMEO, -1) self.router.bind(f"ipc://{self.router_path}") @@ -125,6 +126,11 @@ def send_multipart(self, req_id, data): else: break + if self.req_dict[req_id] == -1: + if data[-1].finished: + with self.mutex: + self.req_dict.pop(req_id, None) + return try: start_send = time.time() if self.aggregate_send: @@ -133,7 +139,9 @@ def send_multipart(self, req_id, data): result = msgpack.packb([response.to_dict() for response in data]) self.router.send_multipart([self.req_dict[req_id], b"", result]) llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") - + except zmq.ZMQError as e: + llm_logger.error(f"[{req_id}] zmq error: {e}") + self.req_dict[req_id] = -1 except Exception as e: llm_logger.error(f"Send result to zmq client failed: {e}") From f47990247cfe7e727e6f2c4e81f6deb08f7f6471 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Thu, 14 Aug 2025 10:38:41 +0800 Subject: [PATCH 02/16] [BugFix] fix control signal release failed --- fastdeploy/entrypoints/openai/serving_chat.py | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 102b494f57..765fc2aa53 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -141,6 +141,7 @@ async def chat_completion_stream_generator( previous_num_tokens = 0 num_prompt_tokens = 0 num_choices = 1 + tool_called = False max_streaming_response_tokens = ( request.max_streaming_response_tokens if request.max_streaming_response_tokens is not None @@ -239,25 +240,35 @@ async def chat_completion_stream_generator( prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens), ) yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n" + api_server_logger.info(f"Chat Streaming response send_idx 0: {chunk.model_dump_json()}") first_iteration = False output = res["outputs"] delta_text = output["text"] output_top_logprobs = output["top_logprobs"] + previous_num_tokens += len(output["token_ids"]) logprobs_res: Optional[LogProbs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) - previous_num_tokens += len(output["token_ids"]) - delta_message = DeltaMessage( - content=delta_text, - reasoning_content=output.get("reasoning_content"), - prompt_token_ids=None, - completion_token_ids=None, - tool_calls=output.get("tool_call_content", []), - ) + if self.engine_client.data_processor.tool_parser_obj and not res["finished"]: + tool_delta_message = output["tool_delta_message"] + if 
tool_delta_message is None: + continue + delta_message = tool_delta_message + delta_message.reasoning_content = output.get("reasoning_content") + if delta_message.tool_calls: + tool_called = True + else: + delta_message = DeltaMessage( + content=delta_text, + reasoning_content=output.get("reasoning_content"), + prompt_token_ids=None, + completion_token_ids=None, + tool_calls=None, + ) choice = ChatCompletionResponseStreamChoice( index=0, @@ -274,10 +285,7 @@ async def chat_completion_stream_generator( max_tokens = request.max_completion_tokens or request.max_tokens if has_no_token_limit or previous_num_tokens != max_tokens: choice.finish_reason = "stop" - if ( - self.engine_client.reasoning_parser == "ernie_x1" - and output.get("finish_reason", "") == "tool_calls" - ): + if tool_called: choice.finish_reason = "tool_calls" else: choice.finish_reason = "length" @@ -304,6 +312,8 @@ async def chat_completion_stream_generator( if choices: chunk.choices = choices yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + if res["finished"]: + api_server_logger.info(f"Chat Streaming response last send: {chunk.model_dump_json()}") choices = [] if include_usage: From 64ba4bb87080895a87eee418aaf896302f367655 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Thu, 14 Aug 2025 10:43:53 +0800 Subject: [PATCH 03/16] update --- fastdeploy/entrypoints/openai/serving_chat.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 765fc2aa53..d950382acf 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -253,22 +253,22 @@ async def chat_completion_stream_generator( output_top_logprobs, request.logprobs, request.top_logprobs ) - if self.engine_client.data_processor.tool_parser_obj and not res["finished"]: - tool_delta_message = output["tool_delta_message"] - if tool_delta_message is None: - continue - delta_message = tool_delta_message - delta_message.reasoning_content = output.get("reasoning_content") - if delta_message.tool_calls: - tool_called = True - else: - delta_message = DeltaMessage( - content=delta_text, - reasoning_content=output.get("reasoning_content"), - prompt_token_ids=None, - completion_token_ids=None, - tool_calls=None, - ) + if self.engine_client.data_processor.tool_parser_obj and not res["finished"]: + tool_delta_message = output["tool_delta_message"] + if tool_delta_message is None: + continue + delta_message = tool_delta_message + delta_message.reasoning_content = output.get("reasoning_content") + if delta_message.tool_calls: + tool_called = True + else: + delta_message = DeltaMessage( + content=delta_text, + reasoning_content=output.get("reasoning_content"), + prompt_token_ids=None, + completion_token_ids=None, + tool_calls=None, + ) choice = ChatCompletionResponseStreamChoice( index=0, From 1b775ad6e83ddc9f1a318f20e93623cf2af7499f Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Thu, 14 Aug 2025 10:47:04 +0800 Subject: [PATCH 04/16] update --- fastdeploy/entrypoints/openai/serving_chat.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index d950382acf..13e26eb9c8 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -307,13 +307,13 @@ async def chat_completion_stream_generator( if len(choices) == 
max_streaming_response_tokens or res["finished"]: chunk.choices = choices yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + if res["finished"]: + api_server_logger.info(f"Chat Streaming response last send: {chunk.model_dump_json()}") choices = [] if choices: chunk.choices = choices yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" - if res["finished"]: - api_server_logger.info(f"Chat Streaming response last send: {chunk.model_dump_json()}") choices = [] if include_usage: @@ -424,7 +424,7 @@ async def chat_completion_full_generator( role="assistant", content=output["text"], reasoning_content=output.get("reasoning_content"), - tool_calls=output.get("tool_call_content"), + tool_calls=output.get("tool_call"), prompt_token_ids=prompt_token_ids if request.return_token_ids else None, completion_token_ids=completion_token_ids if request.return_token_ids else None, text_after_process=text_after_process if request.return_token_ids else None, @@ -462,13 +462,15 @@ async def chat_completion_full_generator( prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=final_res.get("num_cached_tokens", 0)), ) work_process_metrics.e2e_request_latency.observe(time.time() - final_res["metrics"]["request_start_time"]) - return ChatCompletionResponse( + res = ChatCompletionResponse( id=request_id, created=created_time, model=model_name, choices=choices, usage=usage, ) + api_server_logger.info(f"Chat response: {res.model_dump_json()}") + return res def _create_chat_logprobs( self, From 6ba261846c5d99bb5cbfdc667344c4db2726c982 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Thu, 14 Aug 2025 10:49:46 +0800 Subject: [PATCH 05/16] update --- fastdeploy/inter_communicator/zmq_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/inter_communicator/zmq_client.py b/fastdeploy/inter_communicator/zmq_client.py index 5bbfa33ba0..5143d9d47a 100644 --- a/fastdeploy/inter_communicator/zmq_client.py +++ b/fastdeploy/inter_communicator/zmq_client.py @@ -34,7 +34,7 @@ def __init__(self, name, mode): self.context = zmq.Context(4) self.socket = self.context.socket(mode) self.file_name = f"/dev/shm/{name}.socket" - self.router_path = f"./router_{name}.ipc" + self.router_path = f"/dev/shm/router_{name}.ipc" self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) self.aggregate_send = envs.FD_USE_AGGREGATE_SEND From 7ccd2410cf89c3b60c9f628f2b3428a46a666d15 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Tue, 19 Aug 2025 16:34:30 +0800 Subject: [PATCH 06/16] [Feature] add dealer manager to reuse the connection --- fastdeploy/entrypoints/engine_client.py | 7 + fastdeploy/entrypoints/openai/api_server.py | 2 +- fastdeploy/entrypoints/openai/serving_chat.py | 28 ++-- .../entrypoints/openai/serving_completion.py | 29 ++-- fastdeploy/entrypoints/openai/utils.py | 152 ++++++++++++++++++ fastdeploy/envs.py | 2 +- 6 files changed, 195 insertions(+), 25 deletions(-) create mode 100644 fastdeploy/entrypoints/openai/utils.py diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index daed93b8f9..2dba67e70e 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -14,6 +14,7 @@ # limitations under the License. 
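Two sizing decisions land in the engine_client.py hunk below: the per-worker semaphore is the global connection budget split across API worker processes with ceiling division, and the dealer pool size comes from `FD_DEALER_CONNECTIONS`. A small sketch of the ceiling-division arithmetic (the numbers are illustrative, not defaults):

```python
def per_worker_slots(total_connections: int, workers: int) -> int:
    # Ceiling division: round up so the per-worker budgets sum to at least
    # the global budget and no connection slot is silently dropped.
    return (total_connections + workers - 1) // workers


assert per_worker_slots(1024, 8) == 128  # divides evenly
assert per_worker_slots(1024, 6) == 171  # 170.67 rounds up; 171 * 6 >= 1024
```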
""" +import os import time import uuid @@ -21,6 +22,7 @@ from fastdeploy import envs from fastdeploy.engine.config import ModelConfig +from fastdeploy.entrypoints.openai.utils import DealerConnectionManager from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.inter_communicator import IPCSignal, ZmqClient @@ -90,6 +92,11 @@ def __init__( suffix=pid, create=False, ) + self.semaphore = StatefulSemaphore((envs.FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers) + self.connection_manager = DealerConnectionManager( + pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50)) + ) + self.connection_initialized = False def create_zmq_client(self, model, mode): """ diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 98bb071d31..b180d26e76 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -153,9 +153,9 @@ async def lifespan(app: FastAPI): yield # close zmq try: + await engine_client.connection_manager.close() engine_client.zmq_client.close() from prometheus_client import multiprocess - multiprocess.mark_process_dead(os.getpid()) api_server_logger.info(f"Closing metrics client pid: {pid}") except Exception as e: diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index e941719703..d639e3760a 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -20,10 +20,7 @@ import uuid from typing import List, Optional -import aiozmq -import msgpack import numpy as np -from aiozmq import zmq from fastdeploy.entrypoints.openai.protocol import ( ChatCompletionRequest, @@ -62,6 +59,12 @@ def __init__(self, engine_client, pid, ips, max_waiting_time, chat_template): else: self.master_ip = self.master_ip.split(",")[0] + async def _ensure_connection_manager(self): + """ensure connection manager initialized""" + if not self.engine_client.connection_initialized: + await self.engine_client.connection_manager.initialize() + self.engine_client.connection_initialized = True + def _check_master(self): if self.master_ip is None: return True @@ -170,14 +173,16 @@ async def chat_completion_stream_generator( choices=[], model=model_name, ) + try: - dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc") + await self._ensure_connection_manager() + dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id) dealer.write([b"", request_id.encode("utf-8")]) choices = [] current_waiting_time = 0 while num_choices > 0: try: - raw_data = await asyncio.wait_for(dealer.read(), timeout=10) + response = await asyncio.wait_for(response_queue.get(), timeout=10) current_waiting_time = 0 except asyncio.TimeoutError: current_waiting_time += 10 @@ -192,7 +197,6 @@ async def chat_completion_stream_generator( current_waiting_time = 0 await asyncio.sleep(0.01) continue - response = msgpack.unpackb(raw_data[-1]) for res in response: if res.get("error_code", 200) != 200: raise ValueError("{}".format(res["error_msg"])) @@ -339,9 +343,9 @@ async def chat_completion_stream_generator( error_data = self._create_streaming_error_response(str(e)) yield f"data: {error_data}\n\n" finally: - dealer.close() + await self.engine_client.connection_manager.cleanup_request(request_id) self.engine_client.semaphore.release() - api_server_logger.info(f"release 
{self.engine_client.semaphore.status()}") + api_server_logger.info(f"release {request_id} {self.engine_client.semaphore.status()}") yield "data: [DONE]\n\n" async def chat_completion_full_generator( @@ -364,7 +368,8 @@ async def chat_completion_full_generator( include_stop_str_in_output = request.include_stop_str_in_output try: - dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc") + await self._ensure_connection_manager() + dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id) dealer.write([b"", request_id.encode("utf-8")]) final_res = None previous_num_tokens = 0 @@ -373,7 +378,7 @@ async def chat_completion_full_generator( completion_token_ids = [] while True: try: - raw_data = await asyncio.wait_for(dealer.read(), timeout=10) + response = await asyncio.wait_for(response_queue.get(), timeout=10) current_waiting_time = 0 except asyncio.TimeoutError: current_waiting_time += 10 @@ -386,7 +391,6 @@ async def chat_completion_full_generator( await asyncio.sleep(0.1) continue - response = msgpack.unpackb(raw_data[-1]) task_is_finished = False for data in response: if data.get("error_code", 200) != 200: @@ -416,7 +420,7 @@ async def chat_completion_full_generator( if task_is_finished: break finally: - dealer.close() + await self.engine_client.connection_manager.cleanup_request(request_id) self.engine_client.semaphore.release() api_server_logger.info(f"release {self.engine_client.semaphore.status()}") diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 43336dac69..cc07787e2c 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -19,10 +19,7 @@ import uuid from typing import List, Optional -import aiozmq -import msgpack import numpy as np -from aiozmq import zmq from fastdeploy.engine.request import RequestOutput from fastdeploy.entrypoints.openai.protocol import ( @@ -52,6 +49,12 @@ def __init__(self, engine_client, pid, ips, max_waiting_time): else: self.master_ip = self.master_ip.split(",")[0] + async def _ensure_connection_manager(self): + """ensure connection manager initialized""" + if not self.engine_client.connection_initialized: + await self.engine_client.connection_manager.initialize() + self.engine_client.connection_initialized = True + def _check_master(self): if self.master_ip is None: return True @@ -169,7 +172,8 @@ async def completion_full_generator( try: request_ids = [f"{request_id}-{i}" for i in range(num_choices)] # create dealer - dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc") + await self._ensure_connection_manager() + dealer, response_queue = await self.engine.connection_manager.get_connection(request_id) for rid in request_ids: dealer.write([b"", rid.encode("utf-8")]) @@ -182,7 +186,7 @@ async def completion_full_generator( current_waiting_time = 0 while num_choices > 0: try: - raw_data = await asyncio.wait_for(dealer.read(), timeout=10) + response = await asyncio.wait_for(response_queue.get(), timeout=10) current_waiting_time = 0 except asyncio.TimeoutError: current_waiting_time += 10 @@ -194,7 +198,7 @@ async def completion_full_generator( current_waiting_time = 0 await asyncio.sleep(0.1) continue - response = msgpack.unpackb(raw_data[-1]) + for data in response: rid = int(data["request_id"].split("-")[-1]) if data.get("error_code", 200) != 200: @@ -239,7 +243,8 @@ async def 
completion_full_generator( finally: self.engine_client.semaphore.release() if dealer is not None: - dealer.close() + await self.engine_client.connection_manager.cleanup_request(request_id) + self.engine_client.semaphore.release() async def _echo_back_prompt(self, request, res, idx): if res["outputs"].get("send_idx", -1) == 0 and request.echo: @@ -272,7 +277,9 @@ async def completion_stream_generator( Process the stream completion request. """ try: - dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc") + await self._ensure_connection_manager() + dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id) + dealer.write([b"", request_id.encode("utf-8")]) for i in range(num_choices): req_id = f"{request_id}-{i}" @@ -296,7 +303,7 @@ async def completion_stream_generator( current_waiting_time = 0 while num_choices > 0: try: - raw_data = await asyncio.wait_for(dealer.read(), timeout=10) + response = await asyncio.wait_for(response_queue.get(), timeout=10) current_waiting_time = 0 except asyncio.TimeoutError: current_waiting_time += 10 @@ -309,7 +316,6 @@ async def completion_stream_generator( await asyncio.sleep(0.1) continue - response = msgpack.unpackb(raw_data[-1]) for res in response: idx = int(res["request_id"].split("-")[-1]) if res.get("error_code", 200) != 200: @@ -436,7 +442,8 @@ async def completion_stream_generator( del request self.engine_client.semaphore.release() if dealer is not None: - dealer.close() + await self.engine_client.connection_manager.cleanup_request(request_id) + self.engine_client.semaphore.release() yield "data: [DONE]\n\n" def request_output_to_completion_response( diff --git a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py new file mode 100644 index 0000000000..360c0f6b6c --- /dev/null +++ b/fastdeploy/entrypoints/openai/utils.py @@ -0,0 +1,152 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
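The handlers above now consume engine output through a per-request `asyncio.Queue` handed out by the manager defined in this new file, polling with a 10-second wake-up so a silent engine can be detected instead of blocking forever. A condensed sketch of that consumption loop (the 300-second cap is an assumption for illustration):

```python
import asyncio


async def read_until_finished(response_queue: asyncio.Queue, max_idle: int = 300):
    """Yield response items until one is marked finished, as in the handlers above."""
    idle = 0
    while True:
        try:
            batch = await asyncio.wait_for(response_queue.get(), timeout=10)
            idle = 0
        except asyncio.TimeoutError:
            # No output for another 10 s; give up only after max_idle seconds.
            idle += 10
            if idle >= max_idle:
                raise TimeoutError(f"no engine output for {max_idle}s")
            continue
        for data in batch:
            yield data
            if data.get("finished"):
                return
```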
+""" + +import asyncio +import heapq +import random + +import aiozmq +import msgpack +import zmq + +from fastdeploy.utils import api_server_logger + + +class DealerConnectionManager: + """ + Manager for dealer connections, supporting multiplexing and connection reuse + """ + + def __init__(self, pid, max_connections=10): + self.pid = pid + self.max_connections = max(max_connections, 10) + self.connections = [] + self.connection_load = [] + self.connection_heap = [] + self.request_map = {} # request_id -> response_queue + self.lock = asyncio.Lock() + self.connection_tasks = [] + self.running = False + + async def initialize(self): + """initialize all connections""" + self.running = True + for index in range(self.max_connections): + await self._add_connection(index) + api_server_logger.info(f"Started {self.max_connections} connections") + + async def _add_connection(self, index): + """create a new connection and start listening task""" + try: + dealer = await aiozmq.create_zmq_stream( + zmq.DEALER, + connect=f"ipc:///dev/shm/router_{self.pid}.ipc", + ) + async with self.lock: + self.connections.append(dealer) + self.connection_load.append(0) + heapq.heappush(self.connection_heap, (0, index)) + + # start listening + task = asyncio.create_task(self._listen_connection(dealer, index)) + self.connection_tasks.append(task) + return True + except Exception as e: + api_server_logger.error(f"Failed to create dealer: {str(e)}") + return False + + async def _listen_connection(self, dealer, conn_index): + """ + listen for messages from the dealer connection + """ + while self.running: + try: + raw_data = await dealer.read() + response = msgpack.unpackb(raw_data[-1]) + request_id = response[-1]["request_id"] + async with self.lock: + if request_id in self.request_map: + await self.request_map[request_id].put(response) + if response[-1]["finished"]: + self._update_load(conn_index, -1) + except Exception as e: + api_server_logger.error(f"Listener error: {str(e)}") + break + + def _update_load(self, conn_index, delta): + """Update connection load and maintain the heap""" + self.connection_load[conn_index] += delta + heapq.heapify(self.connection_heap) + + # For Debugging purposes + if random.random() < 0.01: + min_load = self.connection_heap[0][0] if self.connection_heap else 0 + max_load = max(self.connection_load) if self.connection_load else 0 + api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}") + + def _get_least_loaded_connection(self): + """ + Get the least loaded connection + """ + if not self.connection_heap: + return None + + load, conn_index = self.connection_heap[0] + self._update_load(conn_index, 1) + + return self.connections[conn_index] + + async def get_connection(self, request_id): + """get a connection for the request""" + + response_queue = asyncio.Queue() + + async with self.lock: + self.request_map[request_id] = response_queue + dealer = self._get_least_loaded_connection() + if not dealer: + raise RuntimeError("No available connections") + + return dealer, response_queue + + async def cleanup_request(self, request_id): + """ + clean up the request after it is finished + """ + async with self.lock: + if request_id in self.request_map: + del self.request_map[request_id] + + async def close(self): + """ + close all connections and tasks + """ + self.running = False + + for task in self.connection_tasks: + task.cancel() + + async with self.lock: + for dealer in self.connections: + try: + dealer.close() + except: + pass + self.connections.clear() + 
self.connection_load.clear() + self.request_map.clear() + + api_server_logger.info("All connections and tasks closed") diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 1c310961cb..0155e260f0 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -85,7 +85,7 @@ # set trace attribute job_id. "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"), # support max connections - "FD_SUPPORT_MAX_CONNECTIONS": lambda: 768, + "FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")), } From 4d4c4cbd4b134dc913a746b8036f29b20e6f0505 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Tue, 19 Aug 2025 16:50:55 +0800 Subject: [PATCH 07/16] fix --- fastdeploy/entrypoints/openai/api_server.py | 1 + .../openai/test_dealer_connection_manager.py | 115 ++++++++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 test/entrypoints/openai/test_dealer_connection_manager.py diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index b180d26e76..116d8a9c87 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -156,6 +156,7 @@ async def lifespan(app: FastAPI): await engine_client.connection_manager.close() engine_client.zmq_client.close() from prometheus_client import multiprocess + multiprocess.mark_process_dead(os.getpid()) api_server_logger.info(f"Closing metrics client pid: {pid}") except Exception as e: diff --git a/test/entrypoints/openai/test_dealer_connection_manager.py b/test/entrypoints/openai/test_dealer_connection_manager.py new file mode 100644 index 0000000000..28fea74208 --- /dev/null +++ b/test/entrypoints/openai/test_dealer_connection_manager.py @@ -0,0 +1,115 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
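A note on the load-balancing scheme exercised below: `heapq.heapify` only reorders the `(load, index)` tuples already sitting in `connection_heap`, and since tuples are immutable, bumping `connection_load` never changes the keys they were pushed with. A selection that stays correct therefore has to rebuild (or re-push) the entries from the live counters; a self-contained sketch of that variant:

```python
import heapq


def pick_least_loaded(connection_load: list) -> int:
    # Rebuild the heap from the live counters so the minimum is current,
    # then charge the chosen connection with one more in-flight request.
    heap = [(load, index) for index, load in enumerate(connection_load)]
    heapq.heapify(heap)
    _, index = heap[0]
    connection_load[index] += 1
    return index


loads = [2, 0, 1]
assert pick_least_loaded(loads) == 1
assert loads == [2, 1, 1]
```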
+""" + +import unittest +from unittest import mock + +import msgpack + +from fastdeploy.entrypoints.openai.utils import DealerConnectionManager + + +class TestDealerConnectionManager(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.patchers = [mock.patch("aiozmq.create_zmq_stream"), mock.patch("fastdeploy.utils.api_server_logger")] + for p in self.patchers: + p.start() + self.addCleanup(p.stop) + + self.mock_create_stream = self.patchers[0].start() + self.mock_logger = self.patchers[1].start() + + async def test_initialize(self): + """Test initialization of connections""" + manager = DealerConnectionManager(pid=1, max_connections=5) + + # Mock the stream creation + mock_stream = mock.AsyncMock() + self.mock_create_stream.return_value = mock_stream + + await manager.initialize() + + # Verify connections were created + self.assertEqual(len(manager.connections), 5) + self.mock_logger.info.assert_called_with("Started 5 connections") + + async def test_get_connection(self): + """Test getting a connection with load balancing""" + manager = DealerConnectionManager(pid=1, max_connections=2) + + # Mock the stream creation + mock_stream1 = mock.AsyncMock() + mock_stream2 = mock.AsyncMock() + self.mock_create_stream.side_effect = [mock_stream1, mock_stream2] + + await manager.initialize() + + # First request + conn1, queue1 = await manager.get_connection("req1") + self.assertIs(conn1, mock_stream1) + + # Second request should use different connection + conn2, queue2 = await manager.get_connection("req2") + self.assertIs(conn2, mock_stream2) + + # Third request should go back to first connection (least loaded) + conn3, queue3 = await manager.get_connection("req3") + self.assertIs(conn3, mock_stream1) + + async def test_listen_connection(self): + """Test message listening""" + manager = DealerConnectionManager(pid=1) + manager.running = True + + # Mock connection + mock_stream = mock.AsyncMock() + mock_stream.read.return_value = [b"", msgpack.packb({"request_id": "req1", "finished": True})] + + # Mock response queue + mock_queue = mock.AsyncMock() + manager.request_map["req1"] = mock_queue + + await manager._listen_connection(mock_stream, 0) + + # Verify message was processed + mock_queue.put.assert_called_once() + self.assertEqual(manager.connection_load[0], -1) + + async def test_close(self): + """Test cleanup on close""" + manager = DealerConnectionManager(pid=1) + manager.running = True + + # Mock connection + mock_stream = mock.MagicMock() + mock_task = mock.MagicMock() + manager.connections.append(mock_stream) + manager.connection_tasks.append(mock_task) + manager.request_map["req1"] = mock.AsyncMock() + + await manager.close() + + # Verify cleanup + self.assertFalse(manager.running) + mock_stream.close.assert_called_once() + mock_task.cancel.assert_called_once() + self.assertEqual(len(manager.connections), 0) + self.assertEqual(len(manager.request_map), 0) + self.mock_logger.info.assert_called_with("All connections and tasks closed") + + +if __name__ == "__main__": + unittest.main() From 91d8a5f590e707d76808159f769717e48402c0bd Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Tue, 19 Aug 2025 23:28:54 +0800 Subject: [PATCH 08/16] fix --- fastdeploy/entrypoints/engine_client.py | 1 - fastdeploy/entrypoints/openai/serving_completion.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 2dba67e70e..d78da82079 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ 
b/fastdeploy/entrypoints/engine_client.py @@ -92,7 +92,6 @@ def __init__( suffix=pid, create=False, ) - self.semaphore = StatefulSemaphore((envs.FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers) self.connection_manager = DealerConnectionManager( pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50)) ) diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index cc07787e2c..2bea6458e5 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -244,7 +244,6 @@ async def completion_full_generator( self.engine_client.semaphore.release() if dealer is not None: await self.engine_client.connection_manager.cleanup_request(request_id) - self.engine_client.semaphore.release() async def _echo_back_prompt(self, request, res, idx): if res["outputs"].get("send_idx", -1) == 0 and request.echo: @@ -279,7 +278,6 @@ async def completion_stream_generator( try: await self._ensure_connection_manager() dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id) - dealer.write([b"", request_id.encode("utf-8")]) for i in range(num_choices): req_id = f"{request_id}-{i}" @@ -440,7 +438,6 @@ async def completion_stream_generator( yield f"data: {ErrorResponse(message=str(e), code=400).model_dump_json(exclude_unset=True)}\n\n" finally: del request - self.engine_client.semaphore.release() if dealer is not None: await self.engine_client.connection_manager.cleanup_request(request_id) self.engine_client.semaphore.release() From faef052053985275070a9c1fa39951e7147da105 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Tue, 19 Aug 2025 23:51:27 +0800 Subject: [PATCH 09/16] fix --- fastdeploy/entrypoints/openai/serving_completion.py | 6 ++++-- fastdeploy/entrypoints/openai/utils.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 2bea6458e5..4c78207ec6 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -173,7 +173,7 @@ async def completion_full_generator( request_ids = [f"{request_id}-{i}" for i in range(num_choices)] # create dealer await self._ensure_connection_manager() - dealer, response_queue = await self.engine.connection_manager.get_connection(request_id) + dealer, response_queue = await self.engine.connection_manager.get_connection(request_id, num_choices) for rid in request_ids: dealer.write([b"", rid.encode("utf-8")]) @@ -277,7 +277,9 @@ async def completion_stream_generator( """ try: await self._ensure_connection_manager() - dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id) + dealer, response_queue = await self.engine_client.connection_manager.get_connection( + request_id, num_choices + ) for i in range(num_choices): req_id = f"{request_id}-{i}" diff --git a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py index 360c0f6b6c..60b68f68d6 100644 --- a/fastdeploy/entrypoints/openai/utils.py +++ b/fastdeploy/entrypoints/openai/utils.py @@ -37,6 +37,7 @@ def __init__(self, pid, max_connections=10): self.connection_load = [] self.connection_heap = [] self.request_map = {} # request_id -> response_queue + self.request_num = {} # request_id -> num_choices self.lock = asyncio.Lock() self.connection_tasks = [] self.running = False @@ -77,11 +78,15 @@ async def 
_listen_connection(self, dealer, conn_index): raw_data = await dealer.read() response = msgpack.unpackb(raw_data[-1]) request_id = response[-1]["request_id"] + if "cmpl" in request_id: + request_id = request_id.rsplit("-", 1)[0] async with self.lock: if request_id in self.request_map: await self.request_map[request_id].put(response) if response[-1]["finished"]: - self._update_load(conn_index, -1) + self.request_num[request_id] -= 1 + if self.request_num[request_id] == 0: + self._update_load(conn_index, -1) except Exception as e: api_server_logger.error(f"Listener error: {str(e)}") break @@ -109,13 +114,14 @@ def _get_least_loaded_connection(self): return self.connections[conn_index] - async def get_connection(self, request_id): + async def get_connection(self, request_id, num_choices=1): """get a connection for the request""" response_queue = asyncio.Queue() async with self.lock: self.request_map[request_id] = response_queue + self.request_num[request_id] = num_choices dealer = self._get_least_loaded_connection() if not dealer: raise RuntimeError("No available connections") @@ -129,6 +135,7 @@ async def cleanup_request(self, request_id): async with self.lock: if request_id in self.request_map: del self.request_map[request_id] + del self.request_num[request_id] async def close(self): """ From a18cdb61e45dd313a5463426d83c1da82b1e6bf1 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Wed, 20 Aug 2025 10:46:09 +0800 Subject: [PATCH 10/16] fix --- fastdeploy/entrypoints/openai/serving_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 4c78207ec6..fb10e0e1f5 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -173,7 +173,7 @@ async def completion_full_generator( request_ids = [f"{request_id}-{i}" for i in range(num_choices)] # create dealer await self._ensure_connection_manager() - dealer, response_queue = await self.engine.connection_manager.get_connection(request_id, num_choices) + dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id, num_choices) for rid in request_ids: dealer.write([b"", rid.encode("utf-8")]) From b9841a9ce5371915ba5249aa0ab9c688bb1f62f4 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Wed, 20 Aug 2025 11:11:49 +0800 Subject: [PATCH 11/16] fix --- fastdeploy/entrypoints/openai/serving_completion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index e39aec46e3..a0cfb89704 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -189,7 +189,9 @@ async def completion_full_generator( request_ids = [f"{request_id}-{i}" for i in range(num_choices)] # create dealer await self._ensure_connection_manager() - dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id, num_choices) + dealer, response_queue = await self.engine_client.connection_manager.get_connection( + request_id, num_choices + ) for rid in request_ids: dealer.write([b"", rid.encode("utf-8")]) From 23e922a67bfd451ca147b444bf36f5a09269d6af Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Wed, 20 Aug 2025 13:05:16 +0800 Subject: [PATCH 12/16] fix --- fastdeploy/entrypoints/openai/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py index 60b68f68d6..d33eb01c2b 100644 --- a/fastdeploy/entrypoints/openai/utils.py +++ b/fastdeploy/entrypoints/openai/utils.py @@ -78,7 +78,7 @@ async def _listen_connection(self, dealer, conn_index): raw_data = await dealer.read() response = msgpack.unpackb(raw_data[-1]) request_id = response[-1]["request_id"] - if "cmpl" in request_id: + if "cmpl" == request_id[:4]: request_id = request_id.rsplit("-", 1)[0] async with self.lock: if request_id in self.request_map: From 51bd8a241aad7957006fd9dc62e074ce45e59125 Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:28:23 +0800 Subject: [PATCH 13/16] Create test_dealer_connection_manager.py --- .../openai/test_dealer_connection_manager.py | 115 ++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 tests/entrypoints/openai/test_dealer_connection_manager.py diff --git a/tests/entrypoints/openai/test_dealer_connection_manager.py b/tests/entrypoints/openai/test_dealer_connection_manager.py new file mode 100644 index 0000000000..28fea74208 --- /dev/null +++ b/tests/entrypoints/openai/test_dealer_connection_manager.py @@ -0,0 +1,115 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
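Worth keeping in mind when reading these tests: the listener demultiplexes responses by request id, and completion requests fan out as `"<base id>-<choice index>"`, so responses have to be folded back onto the base id before the queue lookup. The `request_id[:4] == "cmpl"` form from the previous patch exists precisely so chat ids (`chatcmpl-...`) are not truncated. A sketch of that normalization:

```python
def base_request_id(request_id: str) -> str:
    # Only bare completion ids carry a trailing "-<choice index>".
    # "chatcmpl-..." must pass through untouched, which is why the earlier
    # substring test ('"cmpl" in request_id') was too loose.
    if request_id[:4] == "cmpl":
        return request_id.rsplit("-", 1)[0]
    return request_id


assert base_request_id("cmpl-7f3a-1") == "cmpl-7f3a"
assert base_request_id("chatcmpl-7f3a") == "chatcmpl-7f3a"
```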
+""" + +import unittest +from unittest import mock + +import msgpack + +from fastdeploy.entrypoints.openai.utils import DealerConnectionManager + + +class TestDealerConnectionManager(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.patchers = [mock.patch("aiozmq.create_zmq_stream"), mock.patch("fastdeploy.utils.api_server_logger")] + for p in self.patchers: + p.start() + self.addCleanup(p.stop) + + self.mock_create_stream = self.patchers[0].start() + self.mock_logger = self.patchers[1].start() + + async def test_initialize(self): + """Test initialization of connections""" + manager = DealerConnectionManager(pid=1, max_connections=5) + + # Mock the stream creation + mock_stream = mock.AsyncMock() + self.mock_create_stream.return_value = mock_stream + + await manager.initialize() + + # Verify connections were created + self.assertEqual(len(manager.connections), 5) + self.mock_logger.info.assert_called_with("Started 5 connections") + + async def test_get_connection(self): + """Test getting a connection with load balancing""" + manager = DealerConnectionManager(pid=1, max_connections=2) + + # Mock the stream creation + mock_stream1 = mock.AsyncMock() + mock_stream2 = mock.AsyncMock() + self.mock_create_stream.side_effect = [mock_stream1, mock_stream2] + + await manager.initialize() + + # First request + conn1, queue1 = await manager.get_connection("req1") + self.assertIs(conn1, mock_stream1) + + # Second request should use different connection + conn2, queue2 = await manager.get_connection("req2") + self.assertIs(conn2, mock_stream2) + + # Third request should go back to first connection (least loaded) + conn3, queue3 = await manager.get_connection("req3") + self.assertIs(conn3, mock_stream1) + + async def test_listen_connection(self): + """Test message listening""" + manager = DealerConnectionManager(pid=1) + manager.running = True + + # Mock connection + mock_stream = mock.AsyncMock() + mock_stream.read.return_value = [b"", msgpack.packb({"request_id": "req1", "finished": True})] + + # Mock response queue + mock_queue = mock.AsyncMock() + manager.request_map["req1"] = mock_queue + + await manager._listen_connection(mock_stream, 0) + + # Verify message was processed + mock_queue.put.assert_called_once() + self.assertEqual(manager.connection_load[0], -1) + + async def test_close(self): + """Test cleanup on close""" + manager = DealerConnectionManager(pid=1) + manager.running = True + + # Mock connection + mock_stream = mock.MagicMock() + mock_task = mock.MagicMock() + manager.connections.append(mock_stream) + manager.connection_tasks.append(mock_task) + manager.request_map["req1"] = mock.AsyncMock() + + await manager.close() + + # Verify cleanup + self.assertFalse(manager.running) + mock_stream.close.assert_called_once() + mock_task.cancel.assert_called_once() + self.assertEqual(len(manager.connections), 0) + self.assertEqual(len(manager.request_map), 0) + self.mock_logger.info.assert_called_with("All connections and tasks closed") + + +if __name__ == "__main__": + unittest.main() From 99bec19f5df4749d10352b9b2597345d7067504e Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:29:11 +0800 Subject: [PATCH 14/16] Delete test/entrypoints/openai directory --- .../openai/test_dealer_connection_manager.py | 115 ------------------ 1 file changed, 115 deletions(-) delete mode 100644 test/entrypoints/openai/test_dealer_connection_manager.py diff --git a/test/entrypoints/openai/test_dealer_connection_manager.py 
b/test/entrypoints/openai/test_dealer_connection_manager.py deleted file mode 100644 index 28fea74208..0000000000 --- a/test/entrypoints/openai/test_dealer_connection_manager.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -import unittest -from unittest import mock - -import msgpack - -from fastdeploy.entrypoints.openai.utils import DealerConnectionManager - - -class TestDealerConnectionManager(unittest.IsolatedAsyncioTestCase): - def setUp(self): - self.patchers = [mock.patch("aiozmq.create_zmq_stream"), mock.patch("fastdeploy.utils.api_server_logger")] - for p in self.patchers: - p.start() - self.addCleanup(p.stop) - - self.mock_create_stream = self.patchers[0].start() - self.mock_logger = self.patchers[1].start() - - async def test_initialize(self): - """Test initialization of connections""" - manager = DealerConnectionManager(pid=1, max_connections=5) - - # Mock the stream creation - mock_stream = mock.AsyncMock() - self.mock_create_stream.return_value = mock_stream - - await manager.initialize() - - # Verify connections were created - self.assertEqual(len(manager.connections), 5) - self.mock_logger.info.assert_called_with("Started 5 connections") - - async def test_get_connection(self): - """Test getting a connection with load balancing""" - manager = DealerConnectionManager(pid=1, max_connections=2) - - # Mock the stream creation - mock_stream1 = mock.AsyncMock() - mock_stream2 = mock.AsyncMock() - self.mock_create_stream.side_effect = [mock_stream1, mock_stream2] - - await manager.initialize() - - # First request - conn1, queue1 = await manager.get_connection("req1") - self.assertIs(conn1, mock_stream1) - - # Second request should use different connection - conn2, queue2 = await manager.get_connection("req2") - self.assertIs(conn2, mock_stream2) - - # Third request should go back to first connection (least loaded) - conn3, queue3 = await manager.get_connection("req3") - self.assertIs(conn3, mock_stream1) - - async def test_listen_connection(self): - """Test message listening""" - manager = DealerConnectionManager(pid=1) - manager.running = True - - # Mock connection - mock_stream = mock.AsyncMock() - mock_stream.read.return_value = [b"", msgpack.packb({"request_id": "req1", "finished": True})] - - # Mock response queue - mock_queue = mock.AsyncMock() - manager.request_map["req1"] = mock_queue - - await manager._listen_connection(mock_stream, 0) - - # Verify message was processed - mock_queue.put.assert_called_once() - self.assertEqual(manager.connection_load[0], -1) - - async def test_close(self): - """Test cleanup on close""" - manager = DealerConnectionManager(pid=1) - manager.running = True - - # Mock connection - mock_stream = mock.MagicMock() - mock_task = mock.MagicMock() - manager.connections.append(mock_stream) - manager.connection_tasks.append(mock_task) - manager.request_map["req1"] = mock.AsyncMock() - - await manager.close() - - # Verify cleanup - 
self.assertFalse(manager.running) - mock_stream.close.assert_called_once() - mock_task.cancel.assert_called_once() - self.assertEqual(len(manager.connections), 0) - self.assertEqual(len(manager.request_map), 0) - self.mock_logger.info.assert_called_with("All connections and tasks closed") - - -if __name__ == "__main__": - unittest.main() From 62115e5d683db025e6e7c56a3eca732bcd07fe11 Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Wed, 20 Aug 2025 23:46:19 +0800 Subject: [PATCH 15/16] Update test_dealer_connection_manager.py --- .../openai/test_dealer_connection_manager.py | 215 ++++++++++-------- 1 file changed, 126 insertions(+), 89 deletions(-) diff --git a/tests/entrypoints/openai/test_dealer_connection_manager.py b/tests/entrypoints/openai/test_dealer_connection_manager.py index 28fea74208..ec7dcfdc4d 100644 --- a/tests/entrypoints/openai/test_dealer_connection_manager.py +++ b/tests/entrypoints/openai/test_dealer_connection_manager.py @@ -15,101 +15,138 @@ """ import unittest -from unittest import mock - -import msgpack - +import asyncio +from unittest.mock import AsyncMock, patch, MagicMock from fastdeploy.entrypoints.openai.utils import DealerConnectionManager +class TestDealerConnectionManager(unittest.TestCase): + """Test cases for DealerConnectionManager""" -class TestDealerConnectionManager(unittest.IsolatedAsyncioTestCase): def setUp(self): - self.patchers = [mock.patch("aiozmq.create_zmq_stream"), mock.patch("fastdeploy.utils.api_server_logger")] - for p in self.patchers: - p.start() - self.addCleanup(p.stop) - - self.mock_create_stream = self.patchers[0].start() - self.mock_logger = self.patchers[1].start() - - async def test_initialize(self): - """Test initialization of connections""" - manager = DealerConnectionManager(pid=1, max_connections=5) - - # Mock the stream creation - mock_stream = mock.AsyncMock() - self.mock_create_stream.return_value = mock_stream - - await manager.initialize() - + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.manager = DealerConnectionManager(pid=1, max_connections=5) + + def tearDown(self): + self.loop.run_until_complete(self.manager.close()) + self.loop.close() + + @patch('aiozmq.create_zmq_stream') + async def test_initialization(self, mock_create): + """Test manager initialization creates connections""" + mock_stream = AsyncMock() + mock_create.return_value = mock_stream + + # Test initialization + await self.manager.initialize() + # Verify connections were created - self.assertEqual(len(manager.connections), 5) - self.mock_logger.info.assert_called_with("Started 5 connections") - - async def test_get_connection(self): + self.assertEqual(len(self.manager.connections), 10) + self.assertEqual(len(self.manager.connection_load), 10) + self.assertEqual(len(self.manager.connection_tasks), 10) + + # Verify connection tasks are running + for task in self.manager.connection_tasks: + self.assertFalse(task.done()) + + @patch('aiozmq.create_zmq_stream') + async def test_get_connection(self, mock_create): """Test getting a connection with load balancing""" - manager = DealerConnectionManager(pid=1, max_connections=2) - - # Mock the stream creation - mock_stream1 = mock.AsyncMock() - mock_stream2 = mock.AsyncMock() - self.mock_create_stream.side_effect = [mock_stream1, mock_stream2] - - await manager.initialize() - - # First request - conn1, queue1 = await manager.get_connection("req1") - self.assertIs(conn1, mock_stream1) - - # Second request should use different connection - 
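One caveat about the rewrite this hunk introduces: `unittest.TestCase` never awaits `async def test_*` methods (the default runner just calls them and discards the coroutine), so tests written this way need the event loop driven explicitly, or `IsolatedAsyncioTestCase` kept. A minimal sketch of the explicit-loop style, with a hypothetical scenario:

```python
import asyncio
import unittest


class ExplicitLoopStyle(unittest.TestCase):
    def setUp(self):
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        self.loop.close()

    def test_queue_roundtrip(self):
        async def scenario():
            queue = asyncio.Queue()
            await queue.put("ok")
            return await queue.get()

        # The coroutine only runs because we drive the loop ourselves.
        self.assertEqual(self.loop.run_until_complete(scenario()), "ok")
```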
-        conn2, queue2 = await manager.get_connection("req2")
-        self.assertIs(conn2, mock_stream2)
-
-        # Third request should go back to first connection (least loaded)
-        conn3, queue3 = await manager.get_connection("req3")
-        self.assertIs(conn3, mock_stream1)
-
-    async def test_listen_connection(self):
-        """Test message listening"""
-        manager = DealerConnectionManager(pid=1)
-        manager.running = True
-
-        # Mock connection
-        mock_stream = mock.AsyncMock()
-        mock_stream.read.return_value = [b"", msgpack.packb({"request_id": "req1", "finished": True})]
-
-        # Mock response queue
-        mock_queue = mock.AsyncMock()
-        manager.request_map["req1"] = mock_queue
-
-        await manager._listen_connection(mock_stream, 0)
-
-        # Verify message was processed
-        mock_queue.put.assert_called_once()
-        self.assertEqual(manager.connection_load[0], -1)
-
-    async def test_close(self):
-        """Test cleanup on close"""
-        manager = DealerConnectionManager(pid=1)
-        manager.running = True
-
-        # Mock connection
-        mock_stream = mock.MagicMock()
-        mock_task = mock.MagicMock()
-        manager.connections.append(mock_stream)
-        manager.connection_tasks.append(mock_task)
-        manager.request_map["req1"] = mock.AsyncMock()
-
-        await manager.close()
-
+        mock_stream = AsyncMock()
+        mock_create.return_value = mock_stream
+        await self.manager.initialize()
+
+        # Get a connection
+        dealer, queue = await self.manager.get_connection("req1")
+
+        # Verify least loaded connection is returned
+        self.assertEqual(self.manager.connection_load[0], 1)
+        self.assertIsNotNone(dealer)
+        self.assertIsNotNone(queue)
+        self.assertIn("req1", self.manager.request_map)
+
+    @patch('aiozmq.create_zmq_stream')
+    async def test_connection_listening(self, mock_create):
+        """Test connection listener handles responses"""
+        mock_stream = AsyncMock()
+        mock_create.return_value = mock_stream
+        await self.manager.initialize()
+
+        # Setup test response
+        test_response = {"request_id": "req1", "finished": True}
+        mock_stream.read.return_value = [b'', msgpack.packb(test_response)]
+
+        # Simulate response
+        dealer, queue = await self.manager.get_connection("req1")
+        response = await queue.get()
+
+        # Verify response handling
+        self.assertEqual(response[-1]["request_id"], "req1")
+        self.assertEqual(self.manager.connection_load[0], 0)  # Should be decremented after finish
+
+    @patch('aiozmq.create_zmq_stream')
+    async def test_request_cleanup(self, mock_create):
+        """Test request cleanup removes request tracking"""
+        mock_stream = AsyncMock()
+        mock_create.return_value = mock_stream
+        await self.manager.initialize()
+
+        await self.manager.get_connection("req1")
+        self.assertIn("req1", self.manager.request_map)
+
+        await self.manager.cleanup_request("req1")
+        self.assertNotIn("req1", self.manager.request_map)
+
+    @patch('aiozmq.create_zmq_stream')
+    async def test_multiple_requests(self, mock_create):
+        """Test load balancing with multiple requests"""
+        mock_stream = AsyncMock()
+        mock_create.return_value = mock_stream
+        await self.manager.initialize()
+
+        # Get multiple connections
+        connections = []
+        for i in range(1, 6):
+            dealer, queue = await self.manager.get_connection(f"req{i}")
+            connections.append((dealer, queue))
+
+        # Verify load is distributed
+        load_counts = [0] * 5
+        for i in range(5):
+            load_counts[i] = self.manager.connection_load[i]
+
+        self.assertEqual(sum(load_counts), 5)
+        self.assertTrue(all(1 <= load <= 2 for load in load_counts))
+
+    @patch('aiozmq.create_zmq_stream')
+    async def test_connection_failure(self, mock_create):
+        """Test connection failure handling"""
+        mock_create.side_effect = Exception("Connection failed")
+
+        with self.assertLogs(level='ERROR') as log:
+            await self.manager._add_connection(0)
+            self.assertTrue(any("Failed to create dealer" in msg for msg in log.output))
+
+        self.assertEqual(len(self.manager.connections), 0)
+
+    @patch('aiozmq.create_zmq_stream')
+    async def test_close_manager(self, mock_create):
+        """Test manager shutdown"""
+        mock_stream = AsyncMock()
+        mock_create.return_value = mock_stream
+        await self.manager.initialize()
+
+        # Verify connections exist
+        self.assertEqual(len(self.manager.connections), 5)
+
+        # Close manager
+        await self.manager.close()
+
         # Verify cleanup
-        self.assertFalse(manager.running)
-        mock_stream.close.assert_called_once()
-        mock_task.cancel.assert_called_once()
-        self.assertEqual(len(manager.connections), 0)
-        self.assertEqual(len(manager.request_map), 0)
-        self.mock_logger.info.assert_called_with("All connections and tasks closed")
-
+        self.assertEqual(len(self.manager.connections), 0)
+        self.assertEqual(len(self.manager.request_map), 0)
+        for task in self.manager.connection_tasks:
+            self.assertTrue(task.cancelled())
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     unittest.main()

From a306acdc47745813e6041a875b6e71f8bfbaf8b3 Mon Sep 17 00:00:00 2001
From: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Date: Wed, 20 Aug 2025 23:52:14 +0800
Subject: [PATCH 16/16] Update test_dealer_connection_manager.py

---
 .../openai/test_dealer_connection_manager.py  | 67 ++++++++++---------
 1 file changed, 36 insertions(+), 31 deletions(-)

diff --git a/tests/entrypoints/openai/test_dealer_connection_manager.py b/tests/entrypoints/openai/test_dealer_connection_manager.py
index ec7dcfdc4d..4ab1e4b99a 100644
--- a/tests/entrypoints/openai/test_dealer_connection_manager.py
+++ b/tests/entrypoints/openai/test_dealer_connection_manager.py
@@ -14,11 +14,15 @@
 # limitations under the License.
""" -import unittest import asyncio -from unittest.mock import AsyncMock, patch, MagicMock +import unittest +from unittest.mock import AsyncMock, patch + +import msgpack + from fastdeploy.entrypoints.openai.utils import DealerConnectionManager + class TestDealerConnectionManager(unittest.TestCase): """Test cases for DealerConnectionManager""" @@ -26,127 +30,128 @@ def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) self.manager = DealerConnectionManager(pid=1, max_connections=5) - + def tearDown(self): self.loop.run_until_complete(self.manager.close()) self.loop.close() - @patch('aiozmq.create_zmq_stream') + @patch("aiozmq.create_zmq_stream") async def test_initialization(self, mock_create): """Test manager initialization creates connections""" mock_stream = AsyncMock() mock_create.return_value = mock_stream - + # Test initialization await self.manager.initialize() - + # Verify connections were created self.assertEqual(len(self.manager.connections), 10) self.assertEqual(len(self.manager.connection_load), 10) self.assertEqual(len(self.manager.connection_tasks), 10) - + # Verify connection tasks are running for task in self.manager.connection_tasks: self.assertFalse(task.done()) - @patch('aiozmq.create_zmq_stream') + @patch("aiozmq.create_zmq_stream") async def test_get_connection(self, mock_create): """Test getting a connection with load balancing""" mock_stream = AsyncMock() mock_create.return_value = mock_stream await self.manager.initialize() - + # Get a connection dealer, queue = await self.manager.get_connection("req1") - + # Verify least loaded connection is returned self.assertEqual(self.manager.connection_load[0], 1) self.assertIsNotNone(dealer) self.assertIsNotNone(queue) self.assertIn("req1", self.manager.request_map) - @patch('aiozmq.create_zmq_stream') + @patch("aiozmq.create_zmq_stream") async def test_connection_listening(self, mock_create): """Test connection listener handles responses""" mock_stream = AsyncMock() mock_create.return_value = mock_stream await self.manager.initialize() - + # Setup test response test_response = {"request_id": "req1", "finished": True} - mock_stream.read.return_value = [b'', msgpack.packb(test_response)] - + mock_stream.read.return_value = [b"", msgpack.packb(test_response)] + # Simulate response dealer, queue = await self.manager.get_connection("req1") response = await queue.get() - + # Verify response handling self.assertEqual(response[-1]["request_id"], "req1") self.assertEqual(self.manager.connection_load[0], 0) # Should be decremented after finish - @patch('aiozmq.create_zmq_stream') + @patch("aiozmq.create_zmq_stream") async def test_request_cleanup(self, mock_create): """Test request cleanup removes request tracking""" mock_stream = AsyncMock() mock_create.return_value = mock_stream await self.manager.initialize() - + await self.manager.get_connection("req1") self.assertIn("req1", self.manager.request_map) - + await self.manager.cleanup_request("req1") self.assertNotIn("req1", self.manager.request_map) - @patch('aiozmq.create_zmq_stream') + @patch("aiozmq.create_zmq_stream") async def test_multiple_requests(self, mock_create): """Test load balancing with multiple requests""" mock_stream = AsyncMock() mock_create.return_value = mock_stream await self.manager.initialize() - + # Get multiple connections connections = [] for i in range(1, 6): dealer, queue = await self.manager.get_connection(f"req{i}") connections.append((dealer, queue)) - + # Verify load is distributed load_counts = [0] * 5 for i in 
range(5): load_counts[i] = self.manager.connection_load[i] - + self.assertEqual(sum(load_counts), 5) self.assertTrue(all(1 <= load <= 2 for load in load_counts)) - @patch('aiozmq.create_zmq_stream') + @patch("aiozmq.create_zmq_stream") async def test_connection_failure(self, mock_create): """Test connection failure handling""" mock_create.side_effect = Exception("Connection failed") - - with self.assertLogs(level='ERROR') as log: + + with self.assertLogs(level="ERROR") as log: await self.manager._add_connection(0) self.assertTrue(any("Failed to create dealer" in msg for msg in log.output)) - + self.assertEqual(len(self.manager.connections), 0) - @patch('aiozmq.create_zmq_stream') + @patch("aiozmq.create_zmq_stream") async def test_close_manager(self, mock_create): """Test manager shutdown""" mock_stream = AsyncMock() mock_create.return_value = mock_stream await self.manager.initialize() - + # Verify connections exist self.assertEqual(len(self.manager.connections), 5) - + # Close manager await self.manager.close() - + # Verify cleanup self.assertEqual(len(self.manager.connections), 0) self.assertEqual(len(self.manager.request_map), 0) for task in self.manager.connection_tasks: self.assertTrue(task.cancelled()) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main()
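
Note on the final state of this file: as committed, the class subclasses unittest.TestCase while declaring async def test methods. The stock unittest runner never awaits those coroutines, so their bodies (and assertions) do not actually execute; the unittest.IsolatedAsyncioTestCase base used by the deleted test/ copy is what drives async tests. Below is a minimal sketch of that pattern, reusing only the DealerConnectionManager surface exercised above; the class name TestDealerConnectionManagerAsync is invented for illustration, and the expected count of 10 simply mirrors the assertion in test_initialization, so treat both as assumptions if the manager's internals differ.

    import unittest
    from unittest.mock import AsyncMock, patch

    from fastdeploy.entrypoints.openai.utils import DealerConnectionManager


    class TestDealerConnectionManagerAsync(unittest.IsolatedAsyncioTestCase):
        """IsolatedAsyncioTestCase awaits async def tests, so assertions really run."""

        @patch("aiozmq.create_zmq_stream")
        async def test_initialization(self, mock_create):
            # Every dealer socket the manager opens resolves to the same mocked stream.
            mock_create.return_value = AsyncMock()
            manager = DealerConnectionManager(pid=1, max_connections=5)
            await manager.initialize()
            try:
                # Mirrors the expectation asserted in test_initialization above.
                self.assertEqual(len(manager.connections), 10)
            finally:
                await manager.close()

With this base class, the setUp/tearDown event-loop bookkeeping (asyncio.new_event_loop, run_until_complete) also becomes unnecessary, since the runner manages a fresh loop per test.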