Commit 8dd52a6

add unittest

1 parent bab0951

File tree

1 file changed (+16, -0 lines)

test/layers/test_append_attention.py

Lines changed: 16 additions & 0 deletions
@@ -349,6 +349,7 @@ def setUp(self):
         self.rope_theta = 10000
         self.dtype = "float16"
         self.use_qk_norm = True
+        self.use_mask_offset = False
         self.init_tensor()

     def init_tensor(self):
@@ -404,6 +405,12 @@ def init_tensor(self):
             self.cu_seqlens_k,
         ) = get_padding_offset(self.batch_size, self.seq_len, self.seq_lens_this_time)
         self.token_num = self.padding_offset.shape[0]
+        self.mask_offset = None
+        if self.use_mask_offset:
+            self.mask_offset = paddle.full(self.seq_len * self.batch_size, 0, "int32")
+            for i in range(self.batch_size):
+                for j in range(self.seq_len):
+                    self.mask_offset[i * self.seq_len + j] = j

     def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
         paddle.disable_static()
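Note (not part of the commit): the encoder-phase mask_offset built in init_tensor holds one int32 entry per prefill token, namely that token's position inside its own sequence. Below is a minimal vectorized sketch that produces the same layout, assuming batch_size and seq_len are plain Python ints as in the test; build_encoder_mask_offset is a hypothetical helper, not something added by this commit.

import paddle

def build_encoder_mask_offset(batch_size: int, seq_len: int) -> paddle.Tensor:
    # [0, 1, ..., seq_len - 1] repeated once per sequence, matching the double
    # loop in init_tensor above, but without per-element tensor assignment.
    per_seq = paddle.arange(seq_len, dtype="int32")         # shape: [seq_len]
    return paddle.tile(per_seq, repeat_times=[batch_size])  # shape: [batch_size * seq_len]

# Example: batch_size=2, seq_len=3 -> [0, 1, 2, 0, 1, 2]
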
@@ -505,6 +512,7 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask
             None,  # cache_v_zp
             None,  # linear_shift
             None,  # linear_smooth
+            self.mask_offset,  # mask_offset
             None,  # kv_signal_data
             q_norm_weight,  # q_norm_weight
             k_norm_weight,  # k_norm_weight
@@ -560,6 +568,8 @@ def test_all(self):
         # encoder
         # self.seq_lens_encoder,self.seq_lens_decoder,self.max_enc_len_this_time,self.max_dec_len_this_time=get_encoder_decoder_len(self.batch_size,self.seq_len)
         self.seq_lens_this_time = self.seq_lens_encoder
+        if self.use_mask_offset:
+            print("encoder mask_offset: ", self.mask_offset)
         self.cmp_append_attention(attn_mask=self.attention_mask)
         naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
             self.cache_k,
@@ -590,6 +600,11 @@ def test_all(self):
             self.cu_seqlens_q,
             self.cu_seqlens_k,
         ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
+        if self.use_mask_offset:
+            self.mask_offset = paddle.full(self.batch_size, 0, "int32")
+            for i in range(self.batch_size):
+                self.mask_offset[i] = self.seq_lens_dec[i]
+            print("decoder mask_offset: ", self.mask_offset)
         self.cmp_append_attention(naive_cache_k, naive_cache_v, None)
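Note (not part of the commit): for the decode step the test rebuilds mask_offset with one entry per sequence, copied from that sequence's current decoded length (self.seq_lens_dec[i]). A minimal sketch of the same construction, assuming seq_lens_dec is an integer tensor of shape [batch_size] as in the test; build_decoder_mask_offset is a hypothetical helper.

import paddle

def build_decoder_mask_offset(seq_lens_dec: paddle.Tensor) -> paddle.Tensor:
    # One int32 entry per sequence, copied from seq_lens_dec, mirroring the
    # per-element loop above without indexing assignments.
    return paddle.cast(seq_lens_dec, "int32")  # shape: [batch_size]

# Example: seq_lens_dec = [5, 7] -> mask_offset = [5, 7]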

@@ -614,6 +629,7 @@ def setUp(self):
         self.rope_theta = 10000
         self.dtype = "float16"
         self.use_qk_norm = False
+        self.use_mask_offset = True
         self.init_tensor()
