Skip to content

[FixBug] compute early stopping with real batch size #3418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 19, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions fastdeploy/model_executor/layers/sample/early_stopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,17 @@ def process(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags:
def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
# Get the probability score corresponding to next_tokens in this step
next_scores = paddle.index_sample(probs, next_tokens)
real_bsz = probs.shape[0]

# Sliding window: Move left one grid and insert new score
self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
self.trunc_scores[:, -1:] = next_scores
self.trunc_scores[:real_bsz, :-1] = self.trunc_scores[:real_bsz, 1:]
self.trunc_scores[:real_bsz, -1:] = next_scores

# Determine which samples need to be terminated: all trunc_scores are greater than threshold
need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)

# Add the stop flags
stop_flags[need_trunc_all] = True
stop_flags[need_trunc_all[:real_bsz]] = True

# Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
reset_mask = need_trunc_all.tile([1, self.window_size])
Expand Down
65 changes: 64 additions & 1 deletion test/layers/test_repetition_early_stopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,69 @@ def test_consistency():
actual = triggered_step_triton[i]
assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"

print("Triton vs Normal: All tokens, states, and trigger timings match.")
print("[consistency]Triton vs Normal: All tokens, states, and trigger timings match.")


def test_consistency_with_real_batch_size():
    """Verify the normal and Triton early-stopper paths agree on a padded batch.

    Both stoppers are initialized with the padded ``batch_size`` (20), but each
    step only the first ``real_batch_size`` (15) rows are occupied — ``probs``,
    ``next_tokens`` and ``stop_flags`` all cover just those rows. Per-step stop
    flags, the sliding-window ``trunc_scores`` state, and the step at which each
    sample first triggers must all match between the two implementations.
    """
    batch_size = 20  # padded (allocated) batch size
    real_batch_size = 15  # rows actually occupied this step
    vocab_size = 103424
    window_size = 3000
    threshold = 0.9
    max_steps = 10

    # One sample is forced to emit a fixed token every step so it can trigger
    # the repetition early stop.
    fixed_token_id = np.random.randint(0, vocab_size)
    early_stop_batch_id = np.random.randint(0, real_batch_size)

    # Step index from which each sample's simulated probability exceeds the
    # threshold (consumed by simulate_step_probs).
    trigger_step_flags = {i: np.random.randint(0, max_steps + 1) for i in range(batch_size)}

    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
    stopper_normal = RepetitionEarlyStopper()
    stopper_normal.initialize(batch_size, cfg)
    stopper_triton = RepetitionEarlyStopper()
    stopper_triton.initialize(batch_size, cfg)

    next_tokens_normal = paddle.randint(0, vocab_size, shape=[real_batch_size, 1], dtype="int64")
    next_tokens_triton = next_tokens_normal.clone()

    next_tokens_normal[early_stop_batch_id, 0] = fixed_token_id
    next_tokens_triton[early_stop_batch_id, 0] = fixed_token_id

    stop_flags_normal = paddle.zeros_like(next_tokens_normal)
    stop_flags_triton = stop_flags_normal.clone()

    triggered_step_normal = [None] * batch_size
    triggered_step_triton = [None] * batch_size

    # Loop-invariant: the per-sample trigger schedule never changes per step,
    # so build it once instead of on every iteration.
    flags = [trigger_step_flags[i] for i in range(real_batch_size)]

    for step in range(max_steps):

        probs_np = simulate_step_probs(real_batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
        probs = paddle.to_tensor(probs_np)

        stopper_normal.process_normal(probs, next_tokens_normal, stop_flags_normal)
        stopper_triton.process_triton(probs, next_tokens_triton, stop_flags_triton)

        assert np.allclose(stop_flags_normal.numpy(), stop_flags_triton.numpy()), f"stop flags mismatch at step {step}"

        trunc_scores_diff = paddle.abs(stopper_normal.trunc_scores - stopper_triton.trunc_scores)
        assert paddle.all(trunc_scores_diff < 1e-5), f"trunc_scores mismatch at step {step}"

        out_normal = stop_flags_normal.numpy()
        out_triton = stop_flags_triton.numpy()
        for i in range(real_batch_size):
            # BUGFIX: stop_flags holds booleans (0/1 after being set True), not
            # token ids, so the old comparison against eos_token_id (== vocab_size)
            # could never be true and no trigger step was ever recorded — the
            # final timing check below was vacuous. Test the flag itself.
            if out_normal[i, 0] and triggered_step_normal[i] is None:
                triggered_step_normal[i] = step
            if out_triton[i, 0] and triggered_step_triton[i] is None:
                triggered_step_triton[i] = step

    # Both implementations must trigger each sample at the same step, or not
    # at all (None == None for the padded / never-triggered rows).
    for i in range(batch_size):
        expected = triggered_step_normal[i]
        actual = triggered_step_triton[i]
        assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"

    print("[consistency_with_real_batch_size]Triton vs Normal: All tokens, states, and trigger timings match.")


def test_performance():
Expand Down Expand Up @@ -232,4 +294,5 @@ def test_performance():
if __name__ == "__main__":
    # Run the full early-stopper test suite when executed as a script.
    for case in (
        test_repetition_early_stopper,
        test_consistency,
        test_consistency_with_real_batch_size,
        test_performance,
    ):
        case()
Loading