Skip to content

[Bug fix] Fix bug in logprob in release 2.0.4 #3445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions fastdeploy/output/token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,11 @@ def _process_sampling_with_logprob_batch_output(self):
scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
ranks = self.output_ranks[:batch].numpy()
batch_result = list()
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
for request_id in need_to_be_reschedule_req_ids:
if self.resource_manager.requests[request_id].idx >= (batch - 1): # No more token generated for preempted request
self.resource_manager.reschedule_preempt_task(request_id)
for i in range(batch):
if self.resource_manager.stop_flags[i]:
continue
Expand All @@ -326,6 +331,9 @@ def _process_sampling_with_logprob_batch_output(self):
if recovery_stop:
llm_logger.info(f"recovery stop signal found at task {task_id}")
if not recovery_stop and token_id < 0:
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
self.resource_manager.reschedule_preempt_task(task_id)
continue

if task.get("prefill_chunk_info", None) is not None:
Expand Down Expand Up @@ -382,6 +390,7 @@ def _process_sampling_with_logprob_batch_output(self):
self.tokens_counter[task_id] += 1
if token_id != RECOVERY_STOP_SIGNAL:
result.outputs.token_ids.append(token_id)
task.output_token_ids.append(token_id)
result.outputs.logprob = float(scores[i, 0])
# Construct top_logprobs
topk_token_ids = tokens[i, :].tolist()
Expand Down
Loading