Skip to content

[Feature] add dealer manager to reuse the connection #3471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions fastdeploy/entrypoints/engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
# limitations under the License.
"""

import os
import time
import uuid

import numpy as np

from fastdeploy import envs
from fastdeploy.engine.config import ModelConfig
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
Expand Down Expand Up @@ -90,6 +92,11 @@ def __init__(
suffix=pid,
create=False,
)
self.semaphore = StatefulSemaphore((envs.FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers)
Copy link
Preview

Copilot AI Aug 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The workers parameter is not defined in the init method. This will cause a NameError when initializing the EngineClient.

Copilot uses AI. Check for mistakes.

self.connection_manager = DealerConnectionManager(
pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
)
self.connection_initialized = False

def create_zmq_client(self, model, mode):
"""
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ async def lifespan(app: FastAPI):
yield
# close zmq
try:
await engine_client.connection_manager.close()
engine_client.zmq_client.close()
from prometheus_client import multiprocess

Expand Down
28 changes: 16 additions & 12 deletions fastdeploy/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
import uuid
from typing import List, Optional

import aiozmq
import msgpack
import numpy as np
from aiozmq import zmq

from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
Expand Down Expand Up @@ -62,6 +59,12 @@ def __init__(self, engine_client, pid, ips, max_waiting_time, chat_template):
else:
self.master_ip = self.master_ip.split(",")[0]

async def _ensure_connection_manager(self):
    """Lazily initialize the engine client's shared DealerConnectionManager.

    The first request triggers ``connection_manager.initialize()``; later
    calls are no-ops once ``connection_initialized`` is set.
    NOTE(review): not guarded by a lock — two concurrent first requests may
    both await initialize(); confirm initialize() is safe to call twice.
    """
    if not self.engine_client.connection_initialized:
        await self.engine_client.connection_manager.initialize()
        self.engine_client.connection_initialized = True

def _check_master(self):
if self.master_ip is None:
return True
Expand Down Expand Up @@ -170,14 +173,16 @@ async def chat_completion_stream_generator(
choices=[],
model=model_name,
)

try:
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
await self._ensure_connection_manager()
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
dealer.write([b"", request_id.encode("utf-8")])
choices = []
current_waiting_time = 0
while num_choices > 0:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0
except asyncio.TimeoutError:
current_waiting_time += 10
Expand All @@ -192,7 +197,6 @@ async def chat_completion_stream_generator(
current_waiting_time = 0
await asyncio.sleep(0.01)
continue
response = msgpack.unpackb(raw_data[-1])
for res in response:
if res.get("error_code", 200) != 200:
raise ValueError("{}".format(res["error_msg"]))
Expand Down Expand Up @@ -339,9 +343,9 @@ async def chat_completion_stream_generator(
error_data = self._create_streaming_error_response(str(e))
yield f"data: {error_data}\n\n"
finally:
dealer.close()
await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release()
api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
api_server_logger.info(f"release {request_id} {self.engine_client.semaphore.status()}")
yield "data: [DONE]\n\n"

async def chat_completion_full_generator(
Expand All @@ -364,7 +368,8 @@ async def chat_completion_full_generator(
include_stop_str_in_output = request.include_stop_str_in_output

try:
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
await self._ensure_connection_manager()
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
dealer.write([b"", request_id.encode("utf-8")])
final_res = None
previous_num_tokens = 0
Expand All @@ -373,7 +378,7 @@ async def chat_completion_full_generator(
completion_token_ids = []
while True:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0
except asyncio.TimeoutError:
current_waiting_time += 10
Expand All @@ -386,7 +391,6 @@ async def chat_completion_full_generator(
await asyncio.sleep(0.1)
continue

response = msgpack.unpackb(raw_data[-1])
task_is_finished = False
for data in response:
if data.get("error_code", 200) != 200:
Expand Down Expand Up @@ -416,7 +420,7 @@ async def chat_completion_full_generator(
if task_is_finished:
break
finally:
dealer.close()
await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release()
api_server_logger.info(f"release {self.engine_client.semaphore.status()}")

Expand Down
29 changes: 18 additions & 11 deletions fastdeploy/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
import uuid
from typing import List, Optional

import aiozmq
import msgpack
import numpy as np
from aiozmq import zmq

from fastdeploy.engine.request import RequestOutput
from fastdeploy.entrypoints.openai.protocol import (
Expand Down Expand Up @@ -52,6 +49,12 @@ def __init__(self, engine_client, pid, ips, max_waiting_time):
else:
self.master_ip = self.master_ip.split(",")[0]

async def _ensure_connection_manager(self):
    """Lazily initialize the engine client's shared DealerConnectionManager.

    The first request triggers ``connection_manager.initialize()``; later
    calls are no-ops once ``connection_initialized`` is set.
    NOTE(review): not guarded by a lock — two concurrent first requests may
    both await initialize(); confirm initialize() is safe to call twice.
    """
    if not self.engine_client.connection_initialized:
        await self.engine_client.connection_manager.initialize()
        self.engine_client.connection_initialized = True

def _check_master(self):
if self.master_ip is None:
return True
Expand Down Expand Up @@ -169,7 +172,8 @@ async def completion_full_generator(
try:
request_ids = [f"{request_id}-{i}" for i in range(num_choices)]
# create dealer
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
await self._ensure_connection_manager()
dealer, response_queue = await self.engine.connection_manager.get_connection(request_id)

for rid in request_ids:
dealer.write([b"", rid.encode("utf-8")])
Expand All @@ -182,7 +186,7 @@ async def completion_full_generator(
current_waiting_time = 0
while num_choices > 0:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0
except asyncio.TimeoutError:
current_waiting_time += 10
Expand All @@ -194,7 +198,7 @@ async def completion_full_generator(
current_waiting_time = 0
await asyncio.sleep(0.1)
continue
response = msgpack.unpackb(raw_data[-1])

for data in response:
rid = int(data["request_id"].split("-")[-1])
if data.get("error_code", 200) != 200:
Expand Down Expand Up @@ -239,7 +243,8 @@ async def completion_full_generator(
finally:
self.engine_client.semaphore.release()
if dealer is not None:
dealer.close()
await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release()

async def _echo_back_prompt(self, request, res, idx):
if res["outputs"].get("send_idx", -1) == 0 and request.echo:
Expand Down Expand Up @@ -272,7 +277,9 @@ async def completion_stream_generator(
Process the stream completion request.
"""
try:
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
await self._ensure_connection_manager()
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
dealer.write([b"", request_id.encode("utf-8")])
Copy link
Preview

Copilot AI Aug 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This write operation appears to be misplaced. The write should happen after the loop that creates multiple request IDs (lines 284-290), not before it.

Suggested change
dealer.write([b"", request_id.encode("utf-8")])

Copilot uses AI. Check for mistakes.


for i in range(num_choices):
req_id = f"{request_id}-{i}"
Expand All @@ -296,7 +303,7 @@ async def completion_stream_generator(
current_waiting_time = 0
while num_choices > 0:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
response = await asyncio.wait_for(response_queue.get(), timeout=10)
current_waiting_time = 0
except asyncio.TimeoutError:
current_waiting_time += 10
Expand All @@ -309,7 +316,6 @@ async def completion_stream_generator(
await asyncio.sleep(0.1)
continue

response = msgpack.unpackb(raw_data[-1])
for res in response:
idx = int(res["request_id"].split("-")[-1])
if res.get("error_code", 200) != 200:
Expand Down Expand Up @@ -436,7 +442,8 @@ async def completion_stream_generator(
del request
self.engine_client.semaphore.release()
if dealer is not None:
dealer.close()
await self.engine_client.connection_manager.cleanup_request(request_id)
self.engine_client.semaphore.release()
yield "data: [DONE]\n\n"

def request_output_to_completion_response(
Expand Down
152 changes: 152 additions & 0 deletions fastdeploy/entrypoints/openai/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import asyncio
import heapq
import random

import aiozmq
import msgpack
import zmq

from fastdeploy.utils import api_server_logger


class DealerConnectionManager:
    """
    Manager for dealer connections, supporting multiplexing and connection reuse.

    A fixed pool of ZMQ DEALER streams is created by ``initialize()``; each
    request is routed to the least-loaded connection and its responses are
    demultiplexed into a per-request ``asyncio.Queue`` keyed by request_id.
    """

    def __init__(self, pid, max_connections=10):
        # pid selects the router endpoint: ipc:///dev/shm/router_{pid}.ipc
        self.pid = pid
        # Enforce a floor of 10 connections regardless of the caller's value.
        self.max_connections = max(max_connections, 10)
        self.connections = []  # dealer streams, indexed by connection index
        self.connection_load = []  # in-flight request count per connection
        self.connection_heap = []  # (load, index) min-heap for least-loaded pick
        self.request_map = {}  # request_id -> response queue
        self.lock = asyncio.Lock()
        self.connection_tasks = []  # background listener tasks, one per connection
        self.running = False

    async def initialize(self):
        """Create all connections and start their listener tasks."""
        self.running = True
        for index in range(self.max_connections):
            await self._add_connection(index)
        api_server_logger.info(f"Started {self.max_connections} connections")

    async def _add_connection(self, index):
        """Create one dealer connection and spawn its listener task.

        Returns:
            bool: True on success, False if the ZMQ stream could not be created.
        """
        try:
            dealer = await aiozmq.create_zmq_stream(
                zmq.DEALER,
                connect=f"ipc:///dev/shm/router_{self.pid}.ipc",
            )
        except Exception as e:
            api_server_logger.error(f"Failed to create dealer: {str(e)}")
            return False

        async with self.lock:
            self.connections.append(dealer)
            self.connection_load.append(0)
            heapq.heappush(self.connection_heap, (0, index))

        # Start listening for responses on this connection.
        task = asyncio.create_task(self._listen_connection(dealer, index))
        self.connection_tasks.append(task)
        return True

    async def _listen_connection(self, dealer, conn_index):
        """Forward every response frame from ``dealer`` to its request's queue.

        Runs until ``close()`` clears ``self.running``, the task is cancelled,
        or the stream errors out.
        """
        while self.running:
            try:
                raw_data = await dealer.read()
                response = msgpack.unpackb(raw_data[-1])
                request_id = response[-1]["request_id"]
                async with self.lock:
                    if request_id in self.request_map:
                        await self.request_map[request_id].put(response)
                    # A finished request no longer loads this connection;
                    # .get() guards against frames without a "finished" key.
                    if response[-1].get("finished"):
                        self._update_load(conn_index, -1)
            except asyncio.CancelledError:
                # Normal shutdown path triggered by close().
                break
            except Exception as e:
                api_server_logger.error(f"Listener error: {str(e)}")
                break

    def _update_load(self, conn_index, delta):
        """Apply a load delta and rebuild the least-loaded heap.

        The heap stores (load, index) tuples by value; mutating
        ``connection_load`` does not update them, so the heap must be rebuilt
        from the authoritative load list — heapify alone on the stale tuples
        would never reflect the new loads.
        """
        self.connection_load[conn_index] += delta
        self.connection_heap = [(load, i) for i, load in enumerate(self.connection_load)]
        heapq.heapify(self.connection_heap)

        # Debug-only: sample ~1% of updates to log the load spread.
        if random.random() < 0.01:
            min_load = self.connection_heap[0][0] if self.connection_heap else 0
            max_load = max(self.connection_load) if self.connection_load else 0
            api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}")

    def _get_least_loaded_connection(self):
        """Pick the connection with the smallest in-flight load and charge it.

        Returns:
            The chosen dealer stream, or None when no connections exist.
        """
        if not self.connection_heap:
            return None

        _, conn_index = self.connection_heap[0]
        self._update_load(conn_index, 1)

        return self.connections[conn_index]

    async def get_connection(self, request_id):
        """Register ``request_id`` and return ``(dealer, response_queue)`` for it.

        Raises:
            RuntimeError: when the pool is empty; the request is not left
                registered in ``request_map`` in that case.
        """
        response_queue = asyncio.Queue()

        async with self.lock:
            dealer = self._get_least_loaded_connection()
            if dealer is None:
                raise RuntimeError("No available connections")
            # Register only after a connection is secured so a failure
            # cannot leak a dangling request_map entry.
            self.request_map[request_id] = response_queue

        return dealer, response_queue

    async def cleanup_request(self, request_id):
        """Drop the response queue registered for ``request_id``, if any."""
        async with self.lock:
            self.request_map.pop(request_id, None)

    async def close(self):
        """Close all connections and stop all listener tasks."""
        self.running = False

        for task in self.connection_tasks:
            task.cancel()
        # Wait for listeners to actually finish before tearing sockets down.
        await asyncio.gather(*self.connection_tasks, return_exceptions=True)
        self.connection_tasks.clear()

        async with self.lock:
            for dealer in self.connections:
                try:
                    dealer.close()
                except Exception:
                    # Best-effort close; never fail shutdown on one socket.
                    pass
            self.connections.clear()
            self.connection_load.clear()
            self.connection_heap.clear()
            self.request_map.clear()

        api_server_logger.info("All connections and tasks closed")
2 changes: 1 addition & 1 deletion fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
# set trace attribute job_id.
"FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"),
# support max connections
"FD_SUPPORT_MAX_CONNECTIONS": lambda: 768,
"FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")),
}


Expand Down
Loading
Loading