Skip to content

Commit 4d4c4cb

Browse files
committed
fix
1 parent 7ccd241 commit 4d4c4cb

File tree

2 files changed

+116
-0
lines changed

2 files changed

+116
-0
lines changed

fastdeploy/entrypoints/openai/api_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ async def lifespan(app: FastAPI):
156156
await engine_client.connection_manager.close()
157157
engine_client.zmq_client.close()
158158
from prometheus_client import multiprocess
159+
159160
multiprocess.mark_process_dead(os.getpid())
160161
api_server_logger.info(f"Closing metrics client pid: {pid}")
161162
except Exception as e:
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License"
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
import unittest
18+
from unittest import mock
19+
20+
import msgpack
21+
22+
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
23+
24+
25+
class TestDealerConnectionManager(unittest.IsolatedAsyncioTestCase):
26+
def setUp(self):
27+
self.patchers = [mock.patch("aiozmq.create_zmq_stream"), mock.patch("fastdeploy.utils.api_server_logger")]
28+
for p in self.patchers:
29+
p.start()
30+
self.addCleanup(p.stop)
31+
32+
self.mock_create_stream = self.patchers[0].start()
33+
self.mock_logger = self.patchers[1].start()
34+
35+
async def test_initialize(self):
36+
"""Test initialization of connections"""
37+
manager = DealerConnectionManager(pid=1, max_connections=5)
38+
39+
# Mock the stream creation
40+
mock_stream = mock.AsyncMock()
41+
self.mock_create_stream.return_value = mock_stream
42+
43+
await manager.initialize()
44+
45+
# Verify connections were created
46+
self.assertEqual(len(manager.connections), 5)
47+
self.mock_logger.info.assert_called_with("Started 5 connections")
48+
49+
async def test_get_connection(self):
50+
"""Test getting a connection with load balancing"""
51+
manager = DealerConnectionManager(pid=1, max_connections=2)
52+
53+
# Mock the stream creation
54+
mock_stream1 = mock.AsyncMock()
55+
mock_stream2 = mock.AsyncMock()
56+
self.mock_create_stream.side_effect = [mock_stream1, mock_stream2]
57+
58+
await manager.initialize()
59+
60+
# First request
61+
conn1, queue1 = await manager.get_connection("req1")
62+
self.assertIs(conn1, mock_stream1)
63+
64+
# Second request should use different connection
65+
conn2, queue2 = await manager.get_connection("req2")
66+
self.assertIs(conn2, mock_stream2)
67+
68+
# Third request should go back to first connection (least loaded)
69+
conn3, queue3 = await manager.get_connection("req3")
70+
self.assertIs(conn3, mock_stream1)
71+
72+
async def test_listen_connection(self):
73+
"""Test message listening"""
74+
manager = DealerConnectionManager(pid=1)
75+
manager.running = True
76+
77+
# Mock connection
78+
mock_stream = mock.AsyncMock()
79+
mock_stream.read.return_value = [b"", msgpack.packb({"request_id": "req1", "finished": True})]
80+
81+
# Mock response queue
82+
mock_queue = mock.AsyncMock()
83+
manager.request_map["req1"] = mock_queue
84+
85+
await manager._listen_connection(mock_stream, 0)
86+
87+
# Verify message was processed
88+
mock_queue.put.assert_called_once()
89+
self.assertEqual(manager.connection_load[0], -1)
90+
91+
async def test_close(self):
92+
"""Test cleanup on close"""
93+
manager = DealerConnectionManager(pid=1)
94+
manager.running = True
95+
96+
# Mock connection
97+
mock_stream = mock.MagicMock()
98+
mock_task = mock.MagicMock()
99+
manager.connections.append(mock_stream)
100+
manager.connection_tasks.append(mock_task)
101+
manager.request_map["req1"] = mock.AsyncMock()
102+
103+
await manager.close()
104+
105+
# Verify cleanup
106+
self.assertFalse(manager.running)
107+
mock_stream.close.assert_called_once()
108+
mock_task.cancel.assert_called_once()
109+
self.assertEqual(len(manager.connections), 0)
110+
self.assertEqual(len(manager.request_map), 0)
111+
self.mock_logger.info.assert_called_with("All connections and tasks closed")
112+
113+
114+
if __name__ == "__main__":
115+
unittest.main()

0 commit comments

Comments
 (0)