@@ -37,6 +37,7 @@ def __init__(self, pid, max_connections=10):
3737 self .connection_load = []
3838 self .connection_heap = []
3939 self .request_map = {} # request_id -> response_queue
40+ self .request_num = {} # request_id -> num_choices
4041 self .lock = asyncio .Lock ()
4142 self .connection_tasks = []
4243 self .running = False
@@ -77,11 +78,15 @@ async def _listen_connection(self, dealer, conn_index):
7778 raw_data = await dealer .read ()
7879 response = msgpack .unpackb (raw_data [- 1 ])
7980 request_id = response [- 1 ]["request_id" ]
81+ if "cmpl" in request_id :
82+ request_id = request_id .rsplit ("-" , 1 )[0 ]
8083 async with self .lock :
8184 if request_id in self .request_map :
8285 await self .request_map [request_id ].put (response )
8386 if response [- 1 ]["finished" ]:
84- self ._update_load (conn_index , - 1 )
87+ self .request_num [request_id ] -= 1
88+ if self .request_num [request_id ] == 0 :
89+ self ._update_load (conn_index , - 1 )
8590 except Exception as e :
8691 api_server_logger .error (f"Listener error: { str (e )} " )
8792 break
@@ -109,13 +114,14 @@ def _get_least_loaded_connection(self):
109114
110115 return self .connections [conn_index ]
111116
112- async def get_connection (self , request_id ):
117+ async def get_connection (self , request_id , num_choices = 1 ):
113118 """get a connection for the request"""
114119
115120 response_queue = asyncio .Queue ()
116121
117122 async with self .lock :
118123 self .request_map [request_id ] = response_queue
124+ self .request_num [request_id ] = num_choices
119125 dealer = self ._get_least_loaded_connection ()
120126 if not dealer :
121127 raise RuntimeError ("No available connections" )
@@ -129,6 +135,7 @@ async def cleanup_request(self, request_id):
129135 async with self .lock :
130136 if request_id in self .request_map :
131137 del self .request_map [request_id ]
138+ del self .request_num [request_id ]
132139
133140 async def close (self ):
134141 """
0 commit comments