@@ -170,7 +170,69 @@ def test_consistency():
         actual = triggered_step_triton[i]
         assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
 
-    print("Triton vs Normal: All tokens, states, and trigger timings match.")
+    print("[consistency]Triton vs Normal: All tokens, states, and trigger timings match.")
+
+
+def test_consistency_with_real_batch_size():
+    batch_size = 20
+    real_batch_size = 15
+    vocab_size = 103424
+    window_size = 3000
+    threshold = 0.9
+    eos_token_id = vocab_size
+    max_steps = 10
+
+    fixed_token_id = np.random.randint(0, vocab_size)
+    early_stop_batch_id = np.random.randint(0, real_batch_size)
+
+    trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
+    trigger_step_flags = dict(trigger_step_flags)
+    cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
+    stopper_normal = RepetitionEarlyStopper()
+    stopper_normal.initialize(batch_size, cfg)
+    stopper_triton = RepetitionEarlyStopper()
+    stopper_triton.initialize(batch_size, cfg)
+
+    next_tokens_normal = paddle.randint(0, vocab_size, shape=[real_batch_size, 1], dtype="int64")
+    next_tokens_triton = next_tokens_normal.clone()
+
+    next_tokens_normal[early_stop_batch_id, 0] = fixed_token_id
+    next_tokens_triton[early_stop_batch_id, 0] = fixed_token_id
+
+    stop_flags_normal = paddle.zeros_like(next_tokens_normal)
+    stop_flags_triton = stop_flags_normal.clone()
+
+    triggered_step_normal = [None] * batch_size
+    triggered_step_triton = [None] * batch_size
+
+    for step in range(max_steps):
+
+        flags = [trigger_step_flags[i] for i in range(real_batch_size)]
+        probs_np = simulate_step_probs(real_batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
+        probs = paddle.to_tensor(probs_np)
+
+        stopper_normal.process_normal(probs, next_tokens_normal, stop_flags_normal)
+        stopper_triton.process_triton(probs, next_tokens_triton, stop_flags_triton)
+
+        assert np.allclose(stop_flags_normal.numpy(), stop_flags_triton.numpy()), f"stop flags mismatch at step {step}"
+
+        trunc_scores_diff = paddle.abs(stopper_normal.trunc_scores - stopper_triton.trunc_scores)
+        assert paddle.all(trunc_scores_diff < 1e-5), f"trunc_scores mismatch at step {step}"
+
+        out_normal = stop_flags_normal.numpy()
+        out_triton = stop_flags_triton.numpy()
+        for i in range(real_batch_size):
+            if out_normal[i, 0] == eos_token_id and triggered_step_normal[i] is None:
+                triggered_step_normal[i] = step
+            if out_triton[i, 0] == eos_token_id and triggered_step_triton[i] is None:
+                triggered_step_triton[i] = step
+
+    for i in range(batch_size):
+        expected = triggered_step_normal[i]
+        actual = triggered_step_triton[i]
+        assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
+
+    print("[consistency_with_real_batch_size]Triton vs Normal: All tokens, states, and trigger timings match.")
 
 
 def test_performance():
@@ -232,4 +294,5 @@ def test_performance():
 if __name__ == "__main__":
     test_repetition_early_stopper()
     test_consistency()
+    test_consistency_with_real_batch_size()
     test_performance()