Skip to content

Commit faef052

Browse files
committed
fix
1 parent 91d8a5f commit faef052

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ async def completion_full_generator(
173173
request_ids = [f"{request_id}-{i}" for i in range(num_choices)]
174174
# create dealer
175175
await self._ensure_connection_manager()
176-
dealer, response_queue = await self.engine.connection_manager.get_connection(request_id)
176+
dealer, response_queue = await self.engine.connection_manager.get_connection(request_id, num_choices)
177177

178178
for rid in request_ids:
179179
dealer.write([b"", rid.encode("utf-8")])
@@ -277,7 +277,9 @@ async def completion_stream_generator(
277277
"""
278278
try:
279279
await self._ensure_connection_manager()
280-
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
280+
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
281+
request_id, num_choices
282+
)
281283

282284
for i in range(num_choices):
283285
req_id = f"{request_id}-{i}"

fastdeploy/entrypoints/openai/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(self, pid, max_connections=10):
3737
self.connection_load = []
3838
self.connection_heap = []
3939
self.request_map = {} # request_id -> response_queue
40+
self.request_num = {} # request_id -> num_choices
4041
self.lock = asyncio.Lock()
4142
self.connection_tasks = []
4243
self.running = False
@@ -77,11 +78,15 @@ async def _listen_connection(self, dealer, conn_index):
7778
raw_data = await dealer.read()
7879
response = msgpack.unpackb(raw_data[-1])
7980
request_id = response[-1]["request_id"]
81+
if "cmpl" in request_id:
82+
request_id = request_id.rsplit("-", 1)[0]
8083
async with self.lock:
8184
if request_id in self.request_map:
8285
await self.request_map[request_id].put(response)
8386
if response[-1]["finished"]:
84-
self._update_load(conn_index, -1)
87+
self.request_num[request_id] -= 1
88+
if self.request_num[request_id] == 0:
89+
self._update_load(conn_index, -1)
8590
except Exception as e:
8691
api_server_logger.error(f"Listener error: {str(e)}")
8792
break
@@ -109,13 +114,14 @@ def _get_least_loaded_connection(self):
109114

110115
return self.connections[conn_index]
111116

112-
async def get_connection(self, request_id):
117+
async def get_connection(self, request_id, num_choices=1):
113118
"""get a connection for the request"""
114119

115120
response_queue = asyncio.Queue()
116121

117122
async with self.lock:
118123
self.request_map[request_id] = response_queue
124+
self.request_num[request_id] = num_choices
119125
dealer = self._get_least_loaded_connection()
120126
if not dealer:
121127
raise RuntimeError("No available connections")
@@ -129,6 +135,7 @@ async def cleanup_request(self, request_id):
129135
async with self.lock:
130136
if request_id in self.request_map:
131137
del self.request_map[request_id]
138+
del self.request_num[request_id]
132139

133140
async def close(self):
134141
"""

0 commit comments

Comments
 (0)