
Commit fde43aa

optimize DeepSeek_v3: eliminate redundant calculations & run encoder (prefill) attention with FA3
1 parent 19fda4e commit fde43aa

File tree

2 files changed: +52 lines, -50 lines

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 31 additions & 12 deletions
@@ -24,6 +24,11 @@
 import paddle
 from paddle.nn.functional.flash_attention import flash_attn_unpadded
 
+try:
+    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
+except:
+    flash_attention_v3_varlen = None
+
 from fastdeploy.model_executor.layers.attention.ops import (
     get_block_shape_and_split_kv_block,
     init_kv_signal_per_query,
@@ -91,6 +96,7 @@ class MLAAttentionBackend(AttentionBackend):
     """
 
     __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    flash_attn_func: callable = None
     attention_metadata: MLAAttentionMetadata
 
     def __init__(
@@ -147,6 +153,21 @@ def __init__(
         self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
 
         self.rank, self.device_id = init_rank_and_device_id(fd_config)
+        if self.flash_attn_func is None:
+            prop = paddle.device.cuda.get_device_properties()
+            cc = prop.major * 10 + prop.minor
+            is_current_sm_supported = cc >= 90
+            is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
+            if is_current_sm_supported and is_paddle_supported:
+                self.flash_attn_func = flash_attention_v3_varlen
+                print("The current platform supports Flash Attention V3.")
+                self.flash_attn_kwargs = {}
+            else:
+                self.flash_attn_func = flash_attn_unpadded
+                self.flash_attn_kwargs = {"scale": self.head_dim**-0.5, "training": False}
+                print(
+                    "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
+                )
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
@@ -269,17 +290,16 @@ def forward_extend(
         )
 
         # Flash attention computation
-        fmha_out = flash_attn_unpadded(
+        fmha_out = self.flash_attn_func(
             q,
             k,
             v,
             forward_meta.cu_seqlens_q,
             forward_meta.cu_seqlens_k,
-            metadata.max_enc_len_this_time,
-            metadata.max_enc_len_this_time,
-            self.attn_softmax_scale,
-            causal=True,
-            training=False,
+            max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
+            max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
+            causal=self.causal,
+            **self.flash_attn_kwargs,
         )[0]
 
         return fmha_out
@@ -418,17 +438,16 @@ def forward_mixed(
         )
 
         # FA
-        fmha_out = flash_attn_unpadded(
+        fmha_out = self.flash_attn_func(
             q,
             k,
             v,
             forward_meta.cu_seqlens_q,
             forward_meta.cu_seqlens_k,
-            metadata.max_enc_len_this_time,
-            metadata.max_enc_len_this_time,
-            self.attn_softmax_scale,
-            causal=True,
-            training=False,
+            max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
+            max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
+            causal=self.causal,
+            **self.flash_attn_kwargs,
         )[0]
 
         return fmha_out
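
Taken together, the changes in this file select the attention kernel once, when the backend is constructed: on an SM90-class GPU with a Paddle build compiled for SM90+ archs, `flash_attention_v3_varlen` (FA3) is used; otherwise the backend falls back to `flash_attn_unpadded` (FA2) and supplies the softmax scale and `training=False` through `flash_attn_kwargs`. Below is a minimal sketch of that selection pattern; `pick_flash_attn` is a hypothetical helper (not FastDeploy API), and it assumes `paddle.version.cuda_archs()` is available, as in the diff above.

```python
# Minimal sketch of the FA3/FA2 selection pattern used above.
# `pick_flash_attn` is a hypothetical helper, not part of FastDeploy.
import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded

try:
    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except ImportError:
    flash_attention_v3_varlen = None


def pick_flash_attn(head_dim: int):
    """Return (attn_func, extra_kwargs) for the current GPU / Paddle build."""
    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor  # e.g. 90 on Hopper-class GPUs
    # FA3 also needs the installed Paddle wheel to be compiled for SM90+.
    paddle_has_sm90 = any(arch >= 90 for arch in paddle.version.cuda_archs())
    if flash_attention_v3_varlen is not None and cc >= 90 and paddle_has_sm90:
        # FA3 varlen kernel: no extra kwargs needed at the call site.
        return flash_attention_v3_varlen, {}
    # FA2 fallback: the softmax scale must be passed explicitly.
    return flash_attn_unpadded, {"scale": head_dim**-0.5, "training": False}
```

Because the extra kwargs absorb the signature differences between the two kernels, `forward_extend` and `forward_mixed` can invoke `self.flash_attn_func(...)` identically, as the unified call sites in the diff show.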

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 21 additions & 38 deletions
@@ -316,30 +316,23 @@ def forward(
         mask_encoder_batch: paddle.Tensor,
     ):
         """ """
-        layernorm_out = hidden_states
-        fmha_out = paddle.zeros(
-            shape=[
-                layernorm_out.shape[0],
-                self.num_attention_heads_tp * self.v_head_dim,
-            ],
-            dtype=layernorm_out.dtype,
-        )
-
-        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
-            query = self.q_a_proj(layernorm_out)
-            query = self.q_a_layernorm(query)
-            query = self.q_b_proj(query)
+        fmha_out = None
+        # NOTE: (changwenbin) Hoist the work shared by the prefill and decode (PD mixed) branches to avoid recomputing it.
+        query = self.q_a_proj(hidden_states)
+        query = self.q_a_layernorm(query)
+        query = self.q_b_proj(query)
+        query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
+        query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
 
-            query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-            query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
+        key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
 
-            compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
-            compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
-            key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
-
-            query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
 
+        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
+            # NOTE: (changwenbin) Reuse the shared projections computed above.
             key_value = self.kv_b_proj(compressed_kv)
             key_value = key_value.reshape(
                 [
@@ -371,23 +364,10 @@ def forward(
             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
 
-            fmha_out = fmha_out + fmha_out_prefill
-        if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
-            query = self.q_a_proj(layernorm_out)
-            query = self.q_a_layernorm(query)
-            ln_out_or_q_c = query
-
-            compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
-            compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
-            key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
-
-            query = self.q_b_proj(ln_out_or_q_c)
-            query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-
-            query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
-            query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+            fmha_out = fmha_out_prefill
 
+        if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
+            # NOTE: (changwenbin) Reuse the shared projections computed above.
             q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
 
             q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -416,7 +396,10 @@ def forward(
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             )
-            fmha_out = fmha_out + fmha_out_decode
+            if fmha_out is None:
+                fmha_out = fmha_out_decode
+            else:
+                fmha_out = fmha_out + fmha_out_decode
 
         output = self.o_proj(fmha_out)
         return output
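
The model-side change hoists the projections shared by the prefill and decode paths (the q_a/q_b projections, the compressed-KV projection and layernorm, and the rotary embedding) out of the two `if` branches, and replaces the preallocated zero `fmha_out` buffer with a `None` accumulator that takes the first branch result directly. The sketch below shows the general shape of that "hoist shared work, then merge optional branch outputs" pattern, using hypothetical stand-in callables (`shared_proj`, `prefill_attn`, `decode_attn`); it illustrates the refactor, not the actual FastDeploy model code.

```python
# Illustrative sketch only: hypothetical helpers, not the FastDeploy model code.
from typing import Callable, Optional

import paddle


def mixed_attention(
    hidden_states: paddle.Tensor,
    run_prefill: bool,                # max_enc_len_this_time > 0
    run_decode: bool,                 # max_dec_len_this_time > 0
    shared_proj: Callable,            # q/kv down-projections + rotary embedding
    prefill_attn: Callable,
    decode_attn: Callable,
) -> Optional[paddle.Tensor]:
    # Work common to both branches is computed exactly once; previously it was
    # duplicated inside each `if` block.
    shared = shared_proj(hidden_states)

    fmha_out = None                   # replaces the preallocated zero buffer
    if run_prefill:
        fmha_out = prefill_attn(shared)
    if run_decode:
        out_decode = decode_attn(shared)
        # First contributor assigns, a second one accumulates: no zero-filled
        # allocation and no redundant add when only one branch runs.
        fmha_out = out_decode if fmha_out is None else fmha_out + out_decode
    return fmha_out
```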
