Skip to content

Commit 7ccd241

Browse files
committed
[Feature] add dealer manager to reuse the connection
1 parent 314ad93 commit 7ccd241

File tree

6 files changed

+195
-25
lines changed

6 files changed

+195
-25
lines changed

fastdeploy/entrypoints/engine_client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
# limitations under the License.
1515
"""
1616

17+
import os
1718
import time
1819
import uuid
1920

2021
import numpy as np
2122

2223
from fastdeploy import envs
2324
from fastdeploy.engine.config import ModelConfig
25+
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
2426
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
2527
from fastdeploy.input.preprocess import InputPreprocessor
2628
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
@@ -90,6 +92,11 @@ def __init__(
9092
suffix=pid,
9193
create=False,
9294
)
95+
self.semaphore = StatefulSemaphore((envs.FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers)
96+
self.connection_manager = DealerConnectionManager(
97+
pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
98+
)
99+
self.connection_initialized = False
93100

94101
def create_zmq_client(self, model, mode):
95102
"""

fastdeploy/entrypoints/openai/api_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,9 @@ async def lifespan(app: FastAPI):
153153
yield
154154
# close zmq
155155
try:
156+
await engine_client.connection_manager.close()
156157
engine_client.zmq_client.close()
157158
from prometheus_client import multiprocess
158-
159159
multiprocess.mark_process_dead(os.getpid())
160160
api_server_logger.info(f"Closing metrics client pid: {pid}")
161161
except Exception as e:

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@
2020
import uuid
2121
from typing import List, Optional
2222

23-
import aiozmq
24-
import msgpack
2523
import numpy as np
26-
from aiozmq import zmq
2724

2825
from fastdeploy.entrypoints.openai.protocol import (
2926
ChatCompletionRequest,
@@ -62,6 +59,12 @@ def __init__(self, engine_client, pid, ips, max_waiting_time, chat_template):
6259
else:
6360
self.master_ip = self.master_ip.split(",")[0]
6461

62+
async def _ensure_connection_manager(self):
    """
    Ensure the shared dealer connection manager is initialized exactly once.

    ``initialize()`` awaits while it opens connections, so several first
    requests can arrive before ``connection_initialized`` is set. The check
    is therefore repeated under a lock stored on the engine client, so that
    every serving object sharing that client also shares the same guard.
    """
    client = self.engine_client
    if client.connection_initialized:
        return

    # Created lazily and without awaiting, so two coroutines on the same
    # event loop cannot both create a lock.
    init_lock = getattr(client, "_connection_init_lock", None)
    if init_lock is None:
        init_lock = asyncio.Lock()
        client._connection_init_lock = init_lock

    async with init_lock:
        # Re-check: another request may have finished initializing while
        # this one was waiting for the lock.
        if not client.connection_initialized:
            await client.connection_manager.initialize()
            client.connection_initialized = True
67+
6568
def _check_master(self):
6669
if self.master_ip is None:
6770
return True
@@ -170,14 +173,16 @@ async def chat_completion_stream_generator(
170173
choices=[],
171174
model=model_name,
172175
)
176+
173177
try:
174-
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
178+
await self._ensure_connection_manager()
179+
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
175180
dealer.write([b"", request_id.encode("utf-8")])
176181
choices = []
177182
current_waiting_time = 0
178183
while num_choices > 0:
179184
try:
180-
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
185+
response = await asyncio.wait_for(response_queue.get(), timeout=10)
181186
current_waiting_time = 0
182187
except asyncio.TimeoutError:
183188
current_waiting_time += 10
@@ -192,7 +197,6 @@ async def chat_completion_stream_generator(
192197
current_waiting_time = 0
193198
await asyncio.sleep(0.01)
194199
continue
195-
response = msgpack.unpackb(raw_data[-1])
196200
for res in response:
197201
if res.get("error_code", 200) != 200:
198202
raise ValueError("{}".format(res["error_msg"]))
@@ -339,9 +343,9 @@ async def chat_completion_stream_generator(
339343
error_data = self._create_streaming_error_response(str(e))
340344
yield f"data: {error_data}\n\n"
341345
finally:
342-
dealer.close()
346+
await self.engine_client.connection_manager.cleanup_request(request_id)
343347
self.engine_client.semaphore.release()
344-
api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
348+
api_server_logger.info(f"release {request_id} {self.engine_client.semaphore.status()}")
345349
yield "data: [DONE]\n\n"
346350

347351
async def chat_completion_full_generator(
@@ -364,7 +368,8 @@ async def chat_completion_full_generator(
364368
include_stop_str_in_output = request.include_stop_str_in_output
365369

366370
try:
367-
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
371+
await self._ensure_connection_manager()
372+
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
368373
dealer.write([b"", request_id.encode("utf-8")])
369374
final_res = None
370375
previous_num_tokens = 0
@@ -373,7 +378,7 @@ async def chat_completion_full_generator(
373378
completion_token_ids = []
374379
while True:
375380
try:
376-
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
381+
response = await asyncio.wait_for(response_queue.get(), timeout=10)
377382
current_waiting_time = 0
378383
except asyncio.TimeoutError:
379384
current_waiting_time += 10
@@ -386,7 +391,6 @@ async def chat_completion_full_generator(
386391
await asyncio.sleep(0.1)
387392
continue
388393

389-
response = msgpack.unpackb(raw_data[-1])
390394
task_is_finished = False
391395
for data in response:
392396
if data.get("error_code", 200) != 200:
@@ -416,7 +420,7 @@ async def chat_completion_full_generator(
416420
if task_is_finished:
417421
break
418422
finally:
419-
dealer.close()
423+
await self.engine_client.connection_manager.cleanup_request(request_id)
420424
self.engine_client.semaphore.release()
421425
api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
422426

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,7 @@
1919
import uuid
2020
from typing import List, Optional
2121

22-
import aiozmq
23-
import msgpack
2422
import numpy as np
25-
from aiozmq import zmq
2623

2724
from fastdeploy.engine.request import RequestOutput
2825
from fastdeploy.entrypoints.openai.protocol import (
@@ -52,6 +49,12 @@ def __init__(self, engine_client, pid, ips, max_waiting_time):
5249
else:
5350
self.master_ip = self.master_ip.split(",")[0]
5451

52+
async def _ensure_connection_manager(self):
    """
    Ensure the shared dealer connection manager is initialized exactly once.

    ``initialize()`` awaits while it opens connections, so several first
    requests can arrive before ``connection_initialized`` is set. The check
    is therefore repeated under a lock stored on the engine client, so that
    every serving object sharing that client also shares the same guard.
    """
    client = self.engine_client
    if client.connection_initialized:
        return

    # Created lazily and without awaiting, so two coroutines on the same
    # event loop cannot both create a lock.
    init_lock = getattr(client, "_connection_init_lock", None)
    if init_lock is None:
        init_lock = asyncio.Lock()
        client._connection_init_lock = init_lock

    async with init_lock:
        # Re-check: another request may have finished initializing while
        # this one was waiting for the lock.
        if not client.connection_initialized:
            await client.connection_manager.initialize()
            client.connection_initialized = True
57+
5558
def _check_master(self):
5659
if self.master_ip is None:
5760
return True
@@ -169,7 +172,8 @@ async def completion_full_generator(
169172
try:
170173
request_ids = [f"{request_id}-{i}" for i in range(num_choices)]
171174
# create dealer
172-
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
175+
await self._ensure_connection_manager()
176+
dealer, response_queue = await self.engine.connection_manager.get_connection(request_id)
173177

174178
for rid in request_ids:
175179
dealer.write([b"", rid.encode("utf-8")])
@@ -182,7 +186,7 @@ async def completion_full_generator(
182186
current_waiting_time = 0
183187
while num_choices > 0:
184188
try:
185-
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
189+
response = await asyncio.wait_for(response_queue.get(), timeout=10)
186190
current_waiting_time = 0
187191
except asyncio.TimeoutError:
188192
current_waiting_time += 10
@@ -194,7 +198,7 @@ async def completion_full_generator(
194198
current_waiting_time = 0
195199
await asyncio.sleep(0.1)
196200
continue
197-
response = msgpack.unpackb(raw_data[-1])
201+
198202
for data in response:
199203
rid = int(data["request_id"].split("-")[-1])
200204
if data.get("error_code", 200) != 200:
@@ -239,7 +243,8 @@ async def completion_full_generator(
239243
finally:
240244
self.engine_client.semaphore.release()
241245
if dealer is not None:
242-
dealer.close()
246+
await self.engine_client.connection_manager.cleanup_request(request_id)
247+
self.engine_client.semaphore.release()
243248

244249
async def _echo_back_prompt(self, request, res, idx):
245250
if res["outputs"].get("send_idx", -1) == 0 and request.echo:
@@ -272,7 +277,9 @@ async def completion_stream_generator(
272277
Process the stream completion request.
273278
"""
274279
try:
275-
dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
280+
await self._ensure_connection_manager()
281+
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
282+
dealer.write([b"", request_id.encode("utf-8")])
276283

277284
for i in range(num_choices):
278285
req_id = f"{request_id}-{i}"
@@ -296,7 +303,7 @@ async def completion_stream_generator(
296303
current_waiting_time = 0
297304
while num_choices > 0:
298305
try:
299-
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
306+
response = await asyncio.wait_for(response_queue.get(), timeout=10)
300307
current_waiting_time = 0
301308
except asyncio.TimeoutError:
302309
current_waiting_time += 10
@@ -309,7 +316,6 @@ async def completion_stream_generator(
309316
await asyncio.sleep(0.1)
310317
continue
311318

312-
response = msgpack.unpackb(raw_data[-1])
313319
for res in response:
314320
idx = int(res["request_id"].split("-")[-1])
315321
if res.get("error_code", 200) != 200:
@@ -436,7 +442,8 @@ async def completion_stream_generator(
436442
del request
437443
self.engine_client.semaphore.release()
438444
if dealer is not None:
439-
dealer.close()
445+
await self.engine_client.connection_manager.cleanup_request(request_id)
446+
self.engine_client.semaphore.release()
440447
yield "data: [DONE]\n\n"
441448

442449
def request_output_to_completion_response(
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
import asyncio
18+
import heapq
19+
import random
20+
21+
import aiozmq
22+
import msgpack
23+
import zmq
24+
25+
from fastdeploy.utils import api_server_logger
26+
27+
28+
class DealerConnectionManager:
    """
    Manager for dealer connections, supporting multiplexing and connection reuse.

    A pool of ZMQ DEALER streams is opened against
    ``ipc:///dev/shm/router_{pid}.ipc``. Each request is assigned to the
    connection with the fewest in-flight requests, and one listener task per
    connection forwards decoded responses to the per-request queue.
    """

    def __init__(self, pid, max_connections=10):
        """
        Args:
            pid: suffix used to build the router IPC endpoint.
            max_connections: requested pool size; values below 10 are raised to 10.
        """
        self.pid = pid
        self.max_connections = max(max_connections, 10)
        self.connections = []
        self.connection_load = []
        self.connection_heap = []
        self.request_map = {}  # request_id -> response_queue
        self.request_connection = {}  # request_id -> index of the connection charged for it
        self.lock = asyncio.Lock()
        self.connection_tasks = []
        self.running = False
        # Guard so that concurrent or repeated initialize() calls open the pool once.
        self._initialized = False
        self._init_lock = asyncio.Lock()

    async def initialize(self):
        """
        initialize all connections

        Safe to call more than once and from concurrent coroutines: only the
        first call opens connections, later calls wait for it and return.
        """
        async with self._init_lock:
            if self._initialized:
                return
            self.running = True
            started = 0
            for index in range(self.max_connections):
                if await self._add_connection(index):
                    started += 1
            self._initialized = True
            # Report what was actually opened, not what was requested.
            api_server_logger.info(f"Started {started} connections")

    async def _add_connection(self, index):
        """
        create a new connection and start listening task

        Args:
            index: requested slot number; kept for interface compatibility.
                The position actually used is the connection's place in
                ``self.connections``, so a failed earlier attempt cannot
                leave the heap pointing at the wrong (or a missing) entry.

        Returns:
            True when the connection was created, False otherwise.
        """
        try:
            dealer = await aiozmq.create_zmq_stream(
                zmq.DEALER,
                connect=f"ipc:///dev/shm/router_{self.pid}.ipc",
            )
            async with self.lock:
                conn_index = len(self.connections)
                self.connections.append(dealer)
                self.connection_load.append(0)
                heapq.heappush(self.connection_heap, (0, conn_index))

            # start listening
            task = asyncio.create_task(self._listen_connection(dealer, conn_index))
            self.connection_tasks.append(task)
            return True
        except Exception as e:
            api_server_logger.error(f"Failed to create dealer: {str(e)}")
            return False

    async def _listen_connection(self, dealer, conn_index):
        """
        listen for messages from the dealer connection

        Each message is assumed to decode to a non-empty list of dicts whose
        last element carries ``request_id`` (TODO confirm against the router).
        A failed read ends the listener; a message that cannot be decoded or
        routed is logged and skipped so one bad payload does not stall every
        request multiplexed on this connection.
        """
        while self.running:
            try:
                raw_data = await dealer.read()
            except Exception as e:
                api_server_logger.error(f"Listener error: {str(e)}")
                break

            try:
                response = msgpack.unpackb(raw_data[-1])
                last = response[-1]
                async with self.lock:
                    key = self._resolve_request_key(last["request_id"])
                    if key is not None:
                        await self.request_map[key].put(response)
                        if last.get("finished"):
                            # Idempotent: cleanup_request() will not release again.
                            self._release_load(key)
            except Exception as e:
                api_server_logger.error(f"Listener error: {str(e)}")
                continue

    def _resolve_request_key(self, request_id):
        """
        Map a response's request id to the key registered in ``request_map``.

        An exact match wins. Otherwise a trailing ``-<digits>`` choice suffix
        is stripped, because callers may register one base id and then send
        ``"{request_id}-{i}"`` per choice on the same connection.

        Returns:
            The registered key, or None when the request is unknown.
        """
        if request_id in self.request_map:
            return request_id
        base, sep, suffix = request_id.rpartition("-")
        if sep and suffix.isdigit() and base in self.request_map:
            return base
        return None

    def _update_load(self, conn_index, delta):
        """Update connection load and maintain the heap"""
        self.connection_load[conn_index] += delta
        # Heap entries are (load, index) snapshots; heapify alone would keep
        # the stale loads, so rebuild the entries from the live counters.
        self.connection_heap = [(load, idx) for idx, load in enumerate(self.connection_load)]
        heapq.heapify(self.connection_heap)

        # For Debugging purposes
        if random.random() < 0.01:
            min_load = self.connection_heap[0][0] if self.connection_heap else 0
            max_load = max(self.connection_load) if self.connection_load else 0
            api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}")

    def _acquire_least_loaded_index(self):
        """Charge one request to the least loaded connection and return its index, or None."""
        if not self.connection_heap:
            return None
        _, conn_index = self.connection_heap[0]
        self._update_load(conn_index, 1)
        return conn_index

    def _release_load(self, request_id):
        """Give back the load charged for request_id; does nothing if already released."""
        conn_index = self.request_connection.pop(request_id, None)
        if conn_index is not None and conn_index < len(self.connection_load):
            self._update_load(conn_index, -1)

    def _get_least_loaded_connection(self):
        """
        Get the least loaded connection

        Returns:
            The connection object (its load is incremented), or None when the
            pool is empty.
        """
        conn_index = self._acquire_least_loaded_index()
        if conn_index is None:
            return None
        return self.connections[conn_index]

    async def get_connection(self, request_id):
        """
        get a connection for the request

        Returns:
            (dealer, response_queue): the shared connection to write on and
            the queue on which decoded responses for request_id arrive.

        Raises:
            RuntimeError: when no connection is available. Nothing is
                registered for the request in that case.
        """
        response_queue = asyncio.Queue()

        async with self.lock:
            # A re-registered id must not be charged twice.
            self._release_load(request_id)
            conn_index = self._acquire_least_loaded_index()
            if conn_index is None:
                raise RuntimeError("No available connections")
            self.request_map[request_id] = response_queue
            self.request_connection[request_id] = conn_index
            dealer = self.connections[conn_index]

        return dealer, response_queue

    async def cleanup_request(self, request_id):
        """
        clean up the request after it is finished

        Also releases the connection load if no "finished" response did so
        (timeout, error or client disconnect). Safe to call repeatedly.
        """
        async with self.lock:
            self.request_map.pop(request_id, None)
            self._release_load(request_id)

    async def close(self):
        """
        close all connections and tasks
        """
        self.running = False

        pending = list(self.connection_tasks)
        for task in pending:
            task.cancel()
        if pending:
            # Let the listeners unwind before their streams are closed.
            await asyncio.gather(*pending, return_exceptions=True)

        async with self.lock:
            for dealer in self.connections:
                try:
                    dealer.close()
                except Exception as e:
                    # Best effort: keep closing the remaining connections.
                    api_server_logger.error(f"Failed to close dealer: {str(e)}")
            self.connections.clear()
            self.connection_load.clear()
            self.connection_heap.clear()
            self.connection_tasks.clear()
            self.request_map.clear()
            self.request_connection.clear()

        self._initialized = False
        api_server_logger.info("All connections and tasks closed")

fastdeploy/envs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
# set trace attribute job_id.
8686
"FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"),
8787
# support max connections
88-
"FD_SUPPORT_MAX_CONNECTIONS": lambda: 768,
88+
"FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")),
8989
}
9090

9191

0 commit comments

Comments
 (0)