Skip to content

Commit 4a8c7c4

Browse files
committed
fix mtp_rej_topp input
1 parent 3ee6053 commit 4a8c7c4

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def rejection_top_p_sampling(
141141
top_k_renorm_probs,
142142
)
143143

144-
if not any(x > 0 for x in top_k_list):
144+
if top_k_list and not any(x > 0 for x in top_k_list):
145145
ids = rejection_top_p_sampling(
146146
x,
147147
top_p,
@@ -177,7 +177,7 @@ def min_p_sampling(
177177
"""
178178
min_p_sampling
179179
"""
180-
if not any(x > 0 for x in min_p_arr_cpu):
180+
if min_p_arr_cpu and not any(x > 0 for x in min_p_arr_cpu):
181181
return probs
182182
else:
183183
if current_platform.is_cuda():

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,5 +457,7 @@ def forward_cuda(
457457
)
458458
probs = F.softmax(logits)
459459

460-
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
460+
_, next_tokens = top_k_top_p_sampling(
461+
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
462+
)
461463
return next_tokens

0 commit comments

Comments
 (0)