@@ -349,6 +349,7 @@ def setUp(self):
        self.rope_theta = 10000
        self.dtype = "float16"
        self.use_qk_norm = True
+       self.use_mask_offset = False
        self.init_tensor()

    def init_tensor(self):
@@ -404,6 +405,12 @@ def init_tensor(self):
            self.cu_seqlens_k,
        ) = get_padding_offset(self.batch_size, self.seq_len, self.seq_lens_this_time)
        self.token_num = self.padding_offset.shape[0]
+       self.mask_offset = None
+       if self.use_mask_offset:
+           self.mask_offset = paddle.full(self.seq_len * self.batch_size, 0, "int32")
+           for i in range(self.batch_size):
+               for j in range(self.seq_len):
+                   self.mask_offset[i * self.seq_len + j] = j

    def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
        paddle.disable_static()
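In the prefill (encoder) path, `mask_offset` holds one int32 entry per token, set to the token's position within its own sequence; that is what the nested loops added in `init_tensor` produce. A vectorized sketch of the same layout (the sizes below are illustrative, not taken from the test):

```python
import paddle

# Illustrative sizes; the test derives these from its own config.
batch_size, seq_len = 2, 4

# Position-within-sequence for every token, flattened batch-major:
# equivalent to the nested loops in init_tensor above.
mask_offset = paddle.tile(paddle.arange(seq_len, dtype="int32"), [batch_size])
print(mask_offset.numpy())  # [0 1 2 3 0 1 2 3]
```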
@@ -505,6 +512,7 @@ def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask
            None,  # cache_v_zp
            None,  # linear_shift
            None,  # linear_smooth
+           self.mask_offset,  # mask_offset
            None,  # kv_signal_data
            q_norm_weight,  # q_norm_weight
            k_norm_weight,  # k_norm_weight
@@ -560,6 +568,8 @@ def test_all(self):
        # encoder
        # self.seq_lens_encoder,self.seq_lens_decoder,self.max_enc_len_this_time,self.max_dec_len_this_time=get_encoder_decoder_len(self.batch_size,self.seq_len)
        self.seq_lens_this_time = self.seq_lens_encoder
+       if self.use_mask_offset:
+           print("encoder mask_offset: ", self.mask_offset)
        self.cmp_append_attention(attn_mask=self.attention_mask)
        naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
            self.cache_k,
@@ -590,6 +600,11 @@ def test_all(self):
            self.cu_seqlens_q,
            self.cu_seqlens_k,
        ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
+       if self.use_mask_offset:
+           self.mask_offset = paddle.full(self.batch_size, 0, "int32")
+           for i in range(self.batch_size):
+               self.mask_offset[i] = self.seq_lens_dec[i]
+           print("decoder mask_offset: ", self.mask_offset)
        self.cmp_append_attention(naive_cache_k, naive_cache_v, None)


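In the decode path, `mask_offset` shrinks to one entry per batch element, carrying that sequence's already-decoded length (`seq_lens_dec[i]`) rather than a per-token position. A minimal sketch of that layout (the lengths below are made up, standing in for `self.seq_lens_dec`):

```python
import paddle

# Hypothetical per-batch decoded lengths standing in for self.seq_lens_dec.
seq_lens_dec = [5, 3]

# One int32 slot per batch element, as in the loop above.
mask_offset = paddle.to_tensor(seq_lens_dec, dtype="int32")
print(mask_offset.numpy())  # [5 3]
```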
@@ -614,6 +629,7 @@ def setUp(self):
        self.rope_theta = 10000
        self.dtype = "float16"
        self.use_qk_norm = False
+       self.use_mask_offset = True
        self.init_tensor()

