@@ -35,10 +35,11 @@ def setUp(self):
35
35
def test_top_p_sampling_reject_case1 (self ):
36
36
"""Test with fixed top_p=0.8 and different random seeds"""
37
37
top_p_paddle = paddle .full ((self .batch_size ,), 0.8 )
38
+ top_k_paddle = paddle .full ((self .batch_size ,), 20 ).cast ("int64" )
38
39
39
40
# Test with different seeds
40
41
for seed in [1024 , 2033 , 2033 ]:
41
- samples = rejection_top_p_sampling (self .paddle_norm_prob , top_p_paddle , seed )
42
+ samples = rejection_top_p_sampling (self .paddle_norm_prob , top_p_paddle , top_k_paddle , seed )
42
43
self ._validate_samples (samples )
43
44
44
45
# Basic validation
@@ -48,13 +49,12 @@ def test_top_p_sampling_reject_case1(self):
48
49
def test_top_p_sampling_reject_case2 (self ):
49
50
"""Test with varying top_p values across batch"""
50
51
top_p_paddle = paddle .uniform (shape = [self .batch_size ], min = 0.1 , max = 1.0 )
51
- samples = rejection_top_p_sampling ( self .paddle_norm_prob , top_p_paddle , - 1 )
52
-
52
+ top_k_paddle = paddle . full (( self .batch_size ,), 20 ). cast ( "int64" )
53
+ samples = rejection_top_p_sampling ( self . paddle_norm_prob , top_p_paddle , top_k_paddle , - 1 )
53
54
self ._validate_samples (samples )
54
55
55
56
# Additional check that we're getting different results for different top_p
56
57
unique_samples = len (paddle .unique (samples ))
57
- print (f"Unique samples: { unique_samples } " )
58
58
self .assertGreater (unique_samples , 1 ) # Should have some diversity
59
59
60
60
def _validate_samples (self , samples ):
0 commit comments