@@ -56,6 +56,7 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
56
56
self .running : list [Request ] = []
57
57
self .finish_execution_pool = ThreadPoolExecutor (max_workers = 1 )
58
58
self .lock = threading .Lock ()
59
+ self .to_be_rescheduled_request_id_set = set ()
59
60
60
61
def allocated_slots(self, request: Request):
    """Return the number of token slots currently allocated to *request*.

    Each entry in ``request.block_tables`` is one cache block, and every
    block holds ``cache_config.block_size`` token slots.
    """
    block_size = self.config.cache_config.block_size
    num_blocks = len(request.block_tables)
    return num_blocks * block_size
@@ -76,6 +77,13 @@ def _prepare_decode_task(self, request):
76
77
77
78
def _prepare_preempt_task(self, request):
    """Build the scheduler task that tells the worker to preempt *request*."""
    task = ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
    return task
80
+
81
def reschedule_preempt_task(self, request_id):
    """Move a previously preempted request back to the front of the waiting queue.

    No-op when the request was never marked for rescheduling or is no longer
    tracked in ``self.requests`` (e.g. it finished while preempted). Runs
    under ``self.lock`` so it is safe against concurrent add/finish calls.
    """
    with self.lock:
        if request_id not in self.to_be_rescheduled_request_id_set:
            return
        if request_id not in self.requests:
            return
        # Preempted work resumes ahead of fresh arrivals, hence appendleft.
        self.waiting.appendleft(self.requests[request_id])
        self.to_be_rescheduled_request_id_set.discard(request_id)
79
87
80
88
def _trigger_preempt (self , request , num_new_blocks , preempted_reqs , scheduled_reqs ):
81
89
can_schedule = True
@@ -85,7 +93,7 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
85
93
preempted_req .status = RequestStatus .PREEMPTED
86
94
preempted_req .num_computed_tokens = 0
87
95
self ._free_blocks (preempted_req )
88
- self .waiting . appendleft (preempted_req )
96
+ self .to_be_rescheduled_request_id_set . add (preempted_req . request_id )
89
97
preempted_reqs .append (preempted_req )
90
98
scheduled_reqs .append (self ._prepare_preempt_task (preempted_req ))
91
99
if preempted_req == request :
@@ -308,8 +316,9 @@ def get_real_bsz(self) -> int:
308
316
return self .real_bsz
309
317
310
318
def add_request(self, request: Request) -> None:
    """Register *request* for tracking and append it to the waiting queue.

    Both updates happen under ``self.lock`` so concurrent scheduling and
    finish paths observe a consistent view.
    """
    with self.lock:
        self.requests[request.request_id] = request
        self.waiting.append(request)
313
322
314
323
def _free_blocks (self , request : Request ):
315
324
self .cache_manager .recycle_gpu_blocks (request .block_tables )
@@ -331,9 +340,15 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]):
331
340
if request is None :
332
341
# Invalid request ID.
333
342
continue
334
- request .status = RequestStatus .FINISHED
335
- self .running .remove (request )
336
- self ._free_blocks (request )
343
+ if request in self .running : # normally run and finished
344
+ self .running .remove (request )
345
+ request .status = RequestStatus .FINISHED
346
+ self ._free_blocks (request )
347
+ if request .request_id in self .to_be_rescheduled_request_id_set : # finished after preempted, blocks have been recycled.
348
+ self .to_be_rescheduled_request_id_set .remove (request .request_id ) # just remove from to_be_rescheduled_request_id_set
349
+ if request in self .waiting : # after finished, this request still scheduled from preempted to waiting, unexpected error, should not be here
350
+ raise RuntimeError (f"request { request .request_id } scheduled into waiting list, after finished" )
351
+
337
352
self .tasks_list [request .idx ] = None
338
353
self .stop_flags [request .idx ] = True
339
354
del self .requests [req_id ]
0 commit comments