@@ -316,30 +316,23 @@ def forward(
         mask_encoder_batch: paddle.Tensor,
     ):
         """ """
-        layernorm_out = hidden_states
-        fmha_out = paddle.zeros(
-            shape=[
-                layernorm_out.shape[0],
-                self.num_attention_heads_tp * self.v_head_dim,
-            ],
-            dtype=layernorm_out.dtype,
-        )
-
-        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
-            query = self.q_a_proj(layernorm_out)
-            query = self.q_a_layernorm(query)
-            query = self.q_b_proj(query)

-            query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-            query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
+        # NOTE: (changwenbin) Hoist the computation shared by the prefill and decode paths in PD MIX mode to avoid repeating it.
+        fmha_out = None
+        query = self.q_a_proj(hidden_states)
+        query = self.q_a_layernorm(query)
+        query = self.q_b_proj(query)
+        query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
+        query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

-            compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
-            compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
-            key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
+        key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
+        compressed_kv = self.kv_a_layernorm(compressed_kv)

-            query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)

+        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
             key_value = self.kv_b_proj(compressed_kv)
             key_value = key_value.reshape(
                 [
@@ -371,23 +364,9 @@ def forward(
             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)

-            fmha_out = fmha_out + fmha_out_prefill
-        if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
-            query = self.q_a_proj(layernorm_out)
-            query = self.q_a_layernorm(query)
-            ln_out_or_q_c = query
-
-            compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
-            compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
-            key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
-
-            query = self.q_b_proj(ln_out_or_q_c)
-            query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-
-            query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
-            query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+            fmha_out = fmha_out_prefill

+        if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
             q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])

             q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -416,7 +395,10 @@ def forward(
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             )
-            fmha_out = fmha_out + fmha_out_decode
+            if fmha_out is None:
+                fmha_out = fmha_out_decode
+            else:
+                fmha_out = fmha_out + fmha_out_decode

         output = self.o_proj(fmha_out)
         return output
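In short, the patch hoists the projections that both branches need (the q down/up projections, the compressed-KV projection and layernorm, and the rotary embedding) out of the prefill and decode blocks, and replaces the preallocated zeros tensor with a `None` sentinel that only the branches that actually run fill in or add to. The following is a minimal, self-contained sketch of that control-flow pattern only, not the FastDeploy implementation; the helper names (`shared_projections`, `prefill_attention`, `decode_attention`) and the toy shapes are hypothetical.

```python
# Minimal sketch of the "hoist shared work, accumulate lazily" pattern.
# Helper names and shapes are stand-ins, not part of this PR.
import paddle


def shared_projections(hidden_states):
    # Stand-in for the hoisted q_a/q_b projections, kv_a projection + layernorm,
    # and rotary embedding that both branches consume.
    return hidden_states * 2.0


def prefill_attention(shared):
    # Stand-in for the prefill path (runs when max_enc_len_this_time > 0).
    return shared + 1.0


def decode_attention(shared):
    # Stand-in for the decode path (runs when max_dec_len_this_time > 0).
    return shared - 1.0


def forward(hidden_states, has_prefill_tokens, has_decode_tokens):
    # Shared work is computed exactly once, before either branch.
    shared = shared_projections(hidden_states)

    # fmha_out starts as None rather than a preallocated zeros tensor;
    # only the branches that run contribute to it.
    fmha_out = None

    if has_prefill_tokens:
        fmha_out = prefill_attention(shared)

    if has_decode_tokens:
        out_decode = decode_attention(shared)
        fmha_out = out_decode if fmha_out is None else fmha_out + out_decode

    return fmha_out


if __name__ == "__main__":
    x = paddle.ones([2, 4])
    print(forward(x, has_prefill_tokens=True, has_decode_tokens=True))
```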