diff --git a/fastdeploy/model_executor/layers/sample/early_stopper.py b/fastdeploy/model_executor/layers/sample/early_stopper.py index 3ac0daf2fb..5f0a248881 100644 --- a/fastdeploy/model_executor/layers/sample/early_stopper.py +++ b/fastdeploy/model_executor/layers/sample/early_stopper.py @@ -67,16 +67,17 @@ def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor): # Get the probability score corresponding to next_tokens in this step next_scores = paddle.index_sample(probs, next_tokens) + real_bsz = probs.shape[0] # Sliding window: Move left one grid and insert new score - self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:] - self.trunc_scores[:, -1:] = next_scores + self.trunc_scores[:real_bsz, :-1] = self.trunc_scores[:real_bsz, 1:] + self.trunc_scores[:real_bsz, -1:] = next_scores # Determine which samples need to be terminated: all trunc_scores are greater than threshold need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1) # Add the stop flags - stop_flags[need_trunc_all] = True + stop_flags[need_trunc_all[:real_bsz]] = True # Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step reset_mask = need_trunc_all.tile([1, self.window_size]) diff --git a/scripts/coverage_run.sh b/scripts/coverage_run.sh index 66ea16225e..8f0e149c49 100644 --- a/scripts/coverage_run.sh +++ b/scripts/coverage_run.sh @@ -26,7 +26,6 @@ done failed_tests_file="failed_tests.log" > "$failed_tests_file" disabled_tests=( - layers/test_sampler.py layers/test_append_attention.py layers/test_attention.py operators/test_rejection_top_p_sampling.py @@ -36,7 +35,6 @@ disabled_tests=( operators/test_stop_generation.py operators/test_air_topp_sampling.py operators/test_fused_moe.py - layers/test_repetition_early_stopper.py operators/test_stop_generation_multi_ends.py graph_optimization/test_cuda_graph.py ) 
diff --git a/test/layers/test_repetition_early_stopper.py b/test/layers/test_repetition_early_stopper.py index 8dd59d7973..490331b4a4 100644 --- a/test/layers/test_repetition_early_stopper.py +++ b/test/layers/test_repetition_early_stopper.py @@ -170,7 +170,69 @@ def test_consistency(): actual = triggered_step_triton[i] assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}" - print("Triton vs Normal: All tokens, states, and trigger timings match.") + print("[consistency]Triton vs Normal: All tokens, states, and trigger timings match.") + + +def test_consistency_with_real_batch_size(): + batch_size = 20 + real_batch_size = 15 + vocab_size = 103424 + window_size = 3000 + threshold = 0.9 + eos_token_id = vocab_size + max_steps = 10 + + fixed_token_id = np.random.randint(0, vocab_size) + early_stop_batch_id = np.random.randint(0, real_batch_size) + + trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)] + trigger_step_flags = dict(trigger_step_flags) + cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold}) + stopper_normal = RepetitionEarlyStopper() + stopper_normal.initialize(batch_size, cfg) + stopper_triton = RepetitionEarlyStopper() + stopper_triton.initialize(batch_size, cfg) + + next_tokens_normal = paddle.randint(0, vocab_size, shape=[real_batch_size, 1], dtype="int64") + next_tokens_triton = next_tokens_normal.clone() + + next_tokens_normal[early_stop_batch_id, 0] = fixed_token_id + next_tokens_triton[early_stop_batch_id, 0] = fixed_token_id + + stop_flags_normal = paddle.zeros_like(next_tokens_normal) + stop_flags_triton = stop_flags_normal.clone() + + triggered_step_normal = [None] * batch_size + triggered_step_triton = [None] * batch_size + + for step in range(max_steps): + + flags = [trigger_step_flags[i] for i in range(real_batch_size)] + probs_np = simulate_step_probs(real_batch_size, early_stop_batch_id, fixed_token_id, vocab_size, 
step, flags) + probs = paddle.to_tensor(probs_np) + + stopper_normal.process_normal(probs, next_tokens_normal, stop_flags_normal) + stopper_triton.process_triton(probs, next_tokens_triton, stop_flags_triton) + + assert np.allclose(stop_flags_normal.numpy(), stop_flags_triton.numpy()), f"stop flags mismatch at step {step}" + + trunc_scores_diff = paddle.abs(stopper_normal.trunc_scores - stopper_triton.trunc_scores) + assert paddle.all(trunc_scores_diff < 1e-5), f"trunc_scores mismatch at step {step}" + + out_normal = stop_flags_normal.numpy() + out_triton = stop_flags_triton.numpy() + for i in range(real_batch_size): + if out_normal[i, 0] != 0 and triggered_step_normal[i] is None: + triggered_step_normal[i] = step + if out_triton[i, 0] != 0 and triggered_step_triton[i] is None: + triggered_step_triton[i] = step + + for i in range(batch_size): + expected = triggered_step_normal[i] + actual = triggered_step_triton[i] + assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}" + + print("[consistency_with_real_batch_size]Triton vs Normal: All tokens, states, and trigger timings match.") def test_performance(): @@ -232,4 +294,5 @@ def test_performance(): if __name__ == "__main__": test_repetition_early_stopper() test_consistency() + test_consistency_with_real_batch_size() test_performance() diff --git a/test/layers/test_sampler.py b/test/layers/test_sampler.py index c2fb690187..7b0954c22c 100644 --- a/test/layers/test_sampler.py +++ b/test/layers/test_sampler.py @@ -57,6 +57,7 @@ def _create_default_sampling_metadata( bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"), eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"), min_p=paddle.randn([batch_size]), + seed=paddle.to_tensor([[2025]]), ) return fake_sampling_metadata