@@ -315,6 +315,11 @@ def _process_sampling_with_logprob_batch_output(self):
315315 scores = self .output_scores [: batch * (K + 1 )].numpy ().reshape ([batch , K + 1 ])[:, : (K + 1 )]
316316 ranks = self .output_ranks [:batch ].numpy ()
317317 batch_result = list ()
318+ if envs .ENABLE_V1_KVCACHE_SCHEDULER :
319+ need_to_be_reschedule_req_ids = list (self .resource_manager .to_be_rescheduled_request_id_set )
320+ for request_id in need_to_be_reschedule_req_ids :
321+ if self .resource_manager .requests [request_id ].idx >= (batch - 1 ): # No more token generated for preempted request
322+ self .resource_manager .reschedule_preempt_task (request_id )
318323 for i in range (batch ):
319324 if self .resource_manager .stop_flags [i ]:
320325 continue
@@ -326,6 +331,9 @@ def _process_sampling_with_logprob_batch_output(self):
326331 if recovery_stop :
327332 llm_logger .info (f"recovery stop signal found at task { task_id } " )
328333 if not recovery_stop and token_id < 0 :
334+ if envs .ENABLE_V1_KVCACHE_SCHEDULER :
335+ if task_id in self .resource_manager .to_be_rescheduled_request_id_set :
336+ self .resource_manager .reschedule_preempt_task (task_id )
329337 continue
330338
331339 if task .get ("prefill_chunk_info" , None ) is not None :
@@ -382,6 +390,7 @@ def _process_sampling_with_logprob_batch_output(self):
382390 self .tokens_counter [task_id ] += 1
383391 if token_id != RECOVERY_STOP_SIGNAL :
384392 result .outputs .token_ids .append (token_id )
393+ task .output_token_ids .append (token_id )
385394 result .outputs .logprob = float (scores [i , 0 ])
386395 # Construct top_logprobs
387396 topk_token_ids = tokens [i , :].tolist ()
0 commit comments