
Commit beec24f

[Inference Optimize] DeepSeek-v3 model inference performance optimization (#3455)
* DSK_OPT_01
* update FA3
1 parent c95b339 commit beec24f

File tree: 2 files changed, +47 -45 lines

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 28 additions & 8 deletions
@@ -24,6 +24,11 @@
 import paddle
 from paddle.nn.functional.flash_attention import flash_attn_unpadded
 
+try:
+    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
+except:
+    flash_attention_v3_varlen = None
+
 from fastdeploy.model_executor.layers.attention.ops import (
     get_block_shape_and_split_kv_block,
     init_kv_signal_per_query,
@@ -92,6 +97,7 @@ class MLAAttentionBackend(AttentionBackend):
 
     __infer_dynamic_dims_fields__ = ["attention_metadata"]
     attention_metadata: MLAAttentionMetadata
+    flash_attn_func: callable = None
 
     def __init__(
         self,
@@ -148,6 +154,22 @@ def __init__(
 
         self.rank, self.device_id = init_rank_and_device_id(fd_config)
 
+        if self.flash_attn_func is None:
+            prop = paddle.device.cuda.get_device_properties()
+            cc = prop.major * 10 + prop.minor
+            is_current_sm_supported = cc >= 90
+            is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
+            if is_current_sm_supported and is_paddle_supported:
+                self.flash_attn_func = flash_attention_v3_varlen
+                print("The current platform supports Flash Attention V3.")
+                self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale}
+            else:
+                self.flash_attn_func = flash_attn_unpadded
+                self.flash_attn_kwargs = {"scale": self.attn_softmax_scale, "training": False}
+                print(
+                    "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
+                )
+
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
         metadata = MLAAttentionMetadata()
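Read in isolation, the selection added to __init__ amounts to the sketch below (the free-standing helper select_flash_attn and its tuple return are illustrative only, not part of the commit): FA3 is chosen only when the running GPU reports compute capability 90 or higher (Hopper-class, e.g. H100/H800) and the installed Paddle build was compiled with an SM90+ architecture; otherwise the code falls back to flash_attn_unpadded (FA2). Because the two kernels spell their scale argument differently, the matching keyword dictionary is recorded alongside the chosen function.

import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded

try:
    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except ImportError:  # older Paddle builds do not ship the FA3 varlen kernel
    flash_attention_v3_varlen = None


def select_flash_attn(attn_softmax_scale):
    """Illustrative helper mirroring the __init__ logic above; not part of the commit."""
    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor                # 90 on Hopper-class GPUs
    gpu_supported = cc >= 90
    build_supported = any(arch >= 90 for arch in paddle.version.cuda_archs())
    # The commit relies on the two capability checks alone; the extra None guard here is added caution.
    if gpu_supported and build_supported and flash_attention_v3_varlen is not None:
        # FA3 takes the scale as `softmax_scale`
        return flash_attention_v3_varlen, {"softmax_scale": attn_softmax_scale}
    # FA2 fallback takes `scale` and an explicit `training=False` for inference
    return flash_attn_unpadded, {"scale": attn_softmax_scale, "training": False}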
@@ -269,17 +291,16 @@ def forward_extend(
         )
 
         # Flash attention computation
-        fmha_out = flash_attn_unpadded(
+        fmha_out = self.flash_attn_func(
             q,
             k,
             v,
             forward_meta.cu_seqlens_q,
             forward_meta.cu_seqlens_k,
             metadata.max_enc_len_this_time,
             metadata.max_enc_len_this_time,
-            self.attn_softmax_scale,
-            causal=True,
-            training=False,
+            causal=self.causal,
+            **self.flash_attn_kwargs,
         )[0]
 
         return fmha_out
@@ -418,17 +439,16 @@ def forward_mixed(
         )
 
         # FA
-        fmha_out = flash_attn_unpadded(
+        fmha_out = self.flash_attn_func(
             q,
             k,
             v,
             forward_meta.cu_seqlens_q,
             forward_meta.cu_seqlens_k,
             metadata.max_enc_len_this_time,
             metadata.max_enc_len_this_time,
-            self.attn_softmax_scale,
-            causal=True,
-            training=False,
+            causal=self.causal,
+            **self.flash_attn_kwargs,
         )[0]
 
         return fmha_out
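With the kernel and its keyword dictionary chosen once in __init__, forward_extend and forward_mixed now share a single dispatch point. A sketch of how that call site resolves, with argument annotations added for readability (the annotations follow the usual varlen flash-attention convention and are assumptions, not taken from the diff):

# Sketch of the unified call site; comments are assumptions, not part of the diff.
fmha_out = self.flash_attn_func(
    q,                                   # packed query, [total_tokens, num_heads, head_dim]
    k,
    v,
    forward_meta.cu_seqlens_q,           # cumulative sequence lengths for q
    forward_meta.cu_seqlens_k,           # cumulative sequence lengths for k
    metadata.max_enc_len_this_time,      # max_seqlen_q
    metadata.max_enc_len_this_time,      # max_seqlen_k
    causal=self.causal,
    **self.flash_attn_kwargs,            # FA3: {"softmax_scale": ...}; FA2: {"scale": ..., "training": False}
)[0]                                     # both kernels return a tuple; index 0 is the attention output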

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 19 additions & 37 deletions
@@ -316,30 +316,23 @@ def forward(
         mask_encoder_batch: paddle.Tensor,
     ):
         """ """
-        layernorm_out = hidden_states
-        fmha_out = paddle.zeros(
-            shape=[
-                layernorm_out.shape[0],
-                self.num_attention_heads_tp * self.v_head_dim,
-            ],
-            dtype=layernorm_out.dtype,
-        )
-
-        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
-            query = self.q_a_proj(layernorm_out)
-            query = self.q_a_layernorm(query)
-            query = self.q_b_proj(query)
 
-            query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-            query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
+        # NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
+        fmha_out = None
+        query = self.q_a_proj(hidden_states)
+        query = self.q_a_layernorm(query)
+        query = self.q_b_proj(query)
+        query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
+        query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
 
-            compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
-            compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
-            key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
+        key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
 
-            query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
 
+        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
             key_value = self.kv_b_proj(compressed_kv)
             key_value = key_value.reshape(
                 [
@@ -371,23 +364,9 @@ def forward(
             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
 
-            fmha_out = fmha_out + fmha_out_prefill
-        if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
-            query = self.q_a_proj(layernorm_out)
-            query = self.q_a_layernorm(query)
-            ln_out_or_q_c = query
-
-            compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
-            compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
-            key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
-
-            query = self.q_b_proj(ln_out_or_q_c)
-            query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-
-            query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
-            query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+            fmha_out = fmha_out_prefill
 
+        if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
             q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
 
             q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -416,7 +395,10 @@ def forward(
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             )
-            fmha_out = fmha_out + fmha_out_decode
+            if fmha_out is None:
+                fmha_out = fmha_out_decode
+            else:
+                fmha_out = fmha_out + fmha_out_decode
 
         output = self.o_proj(fmha_out)
         return output
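The net effect of the deepseek_v3.py changes is a restructuring of the attention forward pass: the q/kv down-projections, reshapes, and rotary embedding that the prefill and decode branches both need now run once up front, and fmha_out starts as None instead of a zero-filled buffer, with each branch either setting it or adding to it. A schematic of the new control flow (an outline, not the literal method; the two `...` placeholders stand in for the prefill and decode attention paths shown in the hunks above):

def forward(self, forward_meta, hidden_states, position_ids, mask_encoder_batch):
    # Shared work, previously duplicated inside both branches:
    query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
    query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
    query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
    compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
    key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
    compressed_kv = self.kv_a_layernorm(compressed_kv)
    query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)

    fmha_out = None                              # replaces the paddle.zeros accumulator
    if forward_meta.max_len_tensor_cpu[1]:       # max_enc_len_this_time: prefill tokens present
        fmha_out = ...                           # prefill attention path, masked by mask_encoder_batch
    if forward_meta.max_len_tensor_cpu[2]:       # max_dec_len_this_time: decode tokens present
        fmha_out_decode = ...                    # decode attention path
        fmha_out = fmha_out_decode if fmha_out is None else fmha_out + fmha_out_decode
    return self.o_proj(fmha_out)

In a PD-mixed step both branches still execute and their outputs are summed, but the shared projections are computed only once, which is where the saving comes from.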
