Commit 8b12c80

[FixBug] compute early stopping with real batch size (#3418)
* [FixBug] compute early stopping with real batch size
* update
* fix test_sampler
1 parent 3a7a20d commit 8b12c80

File tree: 4 files changed (+69, −6 lines)

fastdeploy/model_executor/layers/sample/early_stopper.py (4 additions, 3 deletions)

@@ -67,16 +67,17 @@ def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags:
     def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
         # Get the probability score corresponding to next_tokens in this step
         next_scores = paddle.index_sample(probs, next_tokens)
+        real_bsz = probs.shape[0]

         # Sliding window: Move left one grid and insert new score
-        self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
-        self.trunc_scores[:, -1:] = next_scores
+        self.trunc_scores[:real_bsz, :-1] = self.trunc_scores[:real_bsz, 1:]
+        self.trunc_scores[:real_bsz, -1:] = next_scores

         # Determine which samples need to be terminated: all trunc_scores are greater than threshold
         need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)

         # Add the stop flags
-        stop_flags[need_trunc_all] = True
+        stop_flags[need_trunc_all[:real_bsz]] = True

         # Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
         reset_mask = need_trunc_all.tile([1, self.window_size])
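Why the slicing matters: trunc_scores is allocated once for the maximum batch size, while probs and next_tokens only cover the requests alive this step, so real_bsz = probs.shape[0] can be smaller than the first dimension of trunc_scores. The old full-width writes no longer line up with next_scores of shape [real_bsz, 1], and need_trunc_all has one row per allocated slot, hence the [:real_bsz] slice before indexing stop_flags. Below is a minimal NumPy sketch of the fixed update; it is an illustrative translation, not the FastDeploy code, and it is slightly simplified in that it masks the live rows directly where the diff computes need_trunc_all full-width and slices it afterwards.

import numpy as np

MAX_BSZ, WINDOW, THRESHOLD = 4, 3, 0.9
trunc_scores = np.zeros((MAX_BSZ, WINDOW), dtype="float32")  # preallocated for the max batch

def step(next_scores):
    """next_scores: [real_bsz, 1] probabilities of the tokens just sampled."""
    real_bsz = next_scores.shape[0]
    # Slide the window left and append the new scores -- only for live rows.
    trunc_scores[:real_bsz, :-1] = trunc_scores[:real_bsz, 1:]
    trunc_scores[:real_bsz, -1:] = next_scores
    # A sample stops once every score in its window exceeds the threshold.
    need_trunc = np.all(trunc_scores[:real_bsz] > THRESHOLD, axis=-1)
    # Reset triggered rows so they do not re-fire on the next step.
    trunc_scores[:real_bsz][need_trunc] = 0.0
    return need_trunc

# real_bsz = 2 < MAX_BSZ = 4: rows 2 and 3 are padding and stay untouched.
for _ in range(WINDOW):
    stop = step(np.full((2, 1), 0.95, dtype="float32"))
print(stop)                    # [ True  True ] after WINDOW confident steps
print(trunc_scores[2:].sum())  # 0.0 -- padding rows never written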

scripts/coverage_run.sh (0 additions, 2 deletions)

@@ -26,7 +26,6 @@ done
 failed_tests_file="failed_tests.log"
 > "$failed_tests_file"
 disabled_tests=(
-    layers/test_sampler.py
     layers/test_append_attention.py
     layers/test_attention.py
     operators/test_rejection_top_p_sampling.py
@@ -36,7 +35,6 @@ disabled_tests=(
     operators/test_stop_generation.py
     operators/test_air_topp_sampling.py
     operators/test_fused_moe.py
-    layers/test_repetition_early_stopper.py
     operators/test_stop_generation_multi_ends.py
     graph_optimization/test_cuda_graph.py
 )

test/layers/test_repetition_early_stopper.py (64 additions, 1 deletion)

@@ -170,7 +170,69 @@ def test_consistency():
         actual = triggered_step_triton[i]
         assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"

-    print("Triton vs Normal: All tokens, states, and trigger timings match.")
+    print("[consistency]Triton vs Normal: All tokens, states, and trigger timings match.")
+
+
+def test_consistency_with_real_batch_size():
+    batch_size = 20
+    real_batch_size = 15
+    vocab_size = 103424
+    window_size = 3000
+    threshold = 0.9
+    eos_token_id = vocab_size
+    max_steps = 10
+
+    fixed_token_id = np.random.randint(0, vocab_size)
+    early_stop_batch_id = np.random.randint(0, real_batch_size)
+
+    trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
+    trigger_step_flags = dict(trigger_step_flags)
+    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
+    stopper_normal = RepetitionEarlyStopper()
+    stopper_normal.initialize(batch_size, cfg)
+    stopper_triton = RepetitionEarlyStopper()
+    stopper_triton.initialize(batch_size, cfg)
+
+    next_tokens_normal = paddle.randint(0, vocab_size, shape=[real_batch_size, 1], dtype="int64")
+    next_tokens_triton = next_tokens_normal.clone()
+
+    next_tokens_normal[early_stop_batch_id, 0] = fixed_token_id
+    next_tokens_triton[early_stop_batch_id, 0] = fixed_token_id
+
+    stop_flags_normal = paddle.zeros_like(next_tokens_normal)
+    stop_flags_triton = stop_flags_normal.clone()
+
+    triggered_step_normal = [None] * batch_size
+    triggered_step_triton = [None] * batch_size
+
+    for step in range(max_steps):
+
+        flags = [trigger_step_flags[i] for i in range(real_batch_size)]
+        probs_np = simulate_step_probs(real_batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
+        probs = paddle.to_tensor(probs_np)
+
+        stopper_normal.process_normal(probs, next_tokens_normal, stop_flags_normal)
+        stopper_triton.process_triton(probs, next_tokens_triton, stop_flags_triton)
+
+        assert np.allclose(stop_flags_normal.numpy(), stop_flags_triton.numpy()), f"stop flags mismatch at step {step}"
+
+        trunc_scores_diff = paddle.abs(stopper_normal.trunc_scores - stopper_triton.trunc_scores)
+        assert paddle.all(trunc_scores_diff < 1e-5), f"trunc_scores mismatch at step {step}"
+
+        out_normal = stop_flags_normal.numpy()
+        out_triton = stop_flags_triton.numpy()
+        for i in range(real_batch_size):
+            if out_normal[i, 0] == eos_token_id and triggered_step_normal[i] is None:
+                triggered_step_normal[i] = step
+            if out_triton[i, 0] == eos_token_id and triggered_step_triton[i] is None:
+                triggered_step_triton[i] = step
+
+    for i in range(batch_size):
+        expected = triggered_step_normal[i]
+        actual = triggered_step_triton[i]
+        assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
+
+    print("[consistency_with_real_batch_size]Triton vs Normal: All tokens, states, and trigger timings match.")


 def test_performance():

@@ -232,4 +294,5 @@ def test_performance():
 if __name__ == "__main__":
     test_repetition_early_stopper()
     test_consistency()
+    test_consistency_with_real_batch_size()
     test_performance()
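The stoppers are initialized for batch_size = 20 but driven with tensors of real_batch_size = 15 rows, so the last five rows of trunc_scores act as padding. A sketch of an extra assertion this setup would support, hypothetical and not part of the diff; it assumes initialize() zero-fills trunc_scores, as the reset logic in early_stopper.py suggests:

# Hypothetical check inside the step loop: rows beyond the live batch must
# never be written, so they keep their initial zeros for the whole run.
padding = stopper_normal.trunc_scores[real_batch_size:].numpy()
assert np.all(padding == 0.0), f"padding rows modified at step {step}"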

test/layers/test_sampler.py (1 addition, 0 deletions)

@@ -57,6 +57,7 @@ def _create_default_sampling_metadata(
         bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"),
         eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
         min_p=paddle.randn([batch_size]),
+        seed=paddle.to_tensor([[2025]]),
     )
     return fake_sampling_metadata
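The new seed field gives the sampling metadata a fixed RNG seed, making the sampler's output reproducible across runs; that determinism is what lets layers/test_sampler.py come off the disabled list in scripts/coverage_run.sh above. A minimal sketch of the idea using Paddle's global RNG, illustrative only, since the sampler consumes the per-request seed tensor rather than paddle.seed:

import paddle

# Seeding the RNG pins every subsequent random draw, so two runs of a
# sampling test see identical tokens instead of flaky, run-dependent ones.
probs = paddle.to_tensor([[0.1, 0.2, 0.3, 0.4]])

paddle.seed(2025)
first = paddle.multinomial(probs, num_samples=1)
paddle.seed(2025)
second = paddle.multinomial(probs, num_samples=1)

assert bool((first == second).all())  # identical draws under the same seed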
