Skip to content

Commit 3f87f83

Browse files
authored
Fuse q_a_proj and kv_a_proj (#5619)
1 parent ce5412b commit 3f87f83

File tree

1 file changed

+78
-25
lines changed

1 file changed

+78
-25
lines changed

python/sglang/srt/models/deepseek_v2.py

Lines changed: 78 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -443,12 +443,12 @@ def __init__(
443443

444444
# For tensor parallel attention
445445
if self.q_lora_rank is not None:
446-
self.q_a_proj = ReplicatedLinear(
446+
self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
447447
self.hidden_size,
448-
self.q_lora_rank,
448+
self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
449449
bias=False,
450450
quant_config=quant_config,
451-
prefix=add_prefix("q_a_proj", prefix),
451+
prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
452452
)
453453
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
454454
self.q_b_proj = ColumnParallelLinear(
@@ -470,6 +470,14 @@ def __init__(
470470
tp_rank=attn_tp_rank,
471471
tp_size=attn_tp_size,
472472
)
473+
self.kv_a_proj_with_mqa = ReplicatedLinear(
474+
self.hidden_size,
475+
self.kv_lora_rank + self.qk_rope_head_dim,
476+
bias=False,
477+
quant_config=quant_config,
478+
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
479+
)
480+
473481
self.kv_b_proj = ColumnParallelLinear(
474482
self.kv_lora_rank,
475483
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -490,14 +498,6 @@ def __init__(
490498
tp_rank=attn_tp_rank,
491499
tp_size=attn_tp_size,
492500
)
493-
494-
self.kv_a_proj_with_mqa = ReplicatedLinear(
495-
self.hidden_size,
496-
self.kv_lora_rank + self.qk_rope_head_dim,
497-
bias=False,
498-
quant_config=quant_config,
499-
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
500-
)
501501
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
502502

503503
if rope_scaling:
@@ -656,15 +656,18 @@ def forward_normal(
656656
forward_batch: ForwardBatch,
657657
) -> torch.Tensor:
658658
if self.q_lora_rank is not None:
659-
q = self.q_a_proj(hidden_states)[0]
659+
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
660+
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
661+
)
660662
q = self.q_a_layernorm(q)
661663
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
662664
else:
663665
q = self.q_proj(hidden_states)[0].view(
664666
-1, self.num_local_heads, self.qk_head_dim
665667
)
668+
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
669+
666670
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
667-
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
668671
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
669672
latent_cache = latent_cache.unsqueeze(1)
670673
kv_a = self.kv_a_layernorm(kv_a.contiguous())
@@ -699,13 +702,16 @@ def forward_absorb(
699702
zero_allocator: BumpAllocator,
700703
) -> torch.Tensor:
701704
if self.q_lora_rank is not None:
702-
q = self.q_a_proj(hidden_states)[0]
705+
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
706+
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
707+
)
703708
q = self.q_a_layernorm(q)
704709
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
705710
else:
706711
q = self.q_proj(hidden_states)[0].view(
707712
-1, self.num_local_heads, self.qk_head_dim
708713
)
714+
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
709715
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
710716

711717
if self.use_deep_gemm_bmm:
@@ -744,7 +750,6 @@ def forward_absorb(
744750

745751
q_nope_out = q_nope_out.transpose(0, 1)
746752

747-
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
748753
k_nope = latent_cache[..., : self.kv_lora_rank]
749754
k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
750755
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
@@ -819,13 +824,16 @@ def forward_absorb_fused_mla_rope(
819824
q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
820825
)
821826
if self.q_lora_rank is not None:
822-
q = self.q_a_proj(hidden_states)[0]
827+
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
828+
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
829+
)
823830
q = self.q_a_layernorm(q)
824831
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
825832
else:
826833
q = self.q_proj(hidden_states)[0].view(
827834
-1, self.num_local_heads, self.qk_head_dim
828835
)
836+
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
829837
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
830838

831839
if self.w_kc.dtype == torch.float8_e4m3fnuz:
@@ -846,8 +854,6 @@ def forward_absorb_fused_mla_rope(
846854
else:
847855
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
848856
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
849-
850-
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
851857
v_input = latent_cache[..., : self.kv_lora_rank]
852858
v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
853859
k_input = latent_cache.unsqueeze(1)
@@ -1018,15 +1024,17 @@ def forward_normal_chunked_kv(
10181024

10191025
# First do normal mha forward to get output for extended part
10201026
if self.q_lora_rank is not None:
1021-
q = self.q_a_proj(hidden_states)[0]
1027+
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
1028+
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
1029+
)
10221030
q = self.q_a_layernorm(q)
10231031
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
10241032
else:
10251033
q = self.q_proj(hidden_states)[0].view(
10261034
-1, self.num_local_heads, self.qk_head_dim
10271035
)
1036+
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
10281037
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
1029-
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
10301038
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
10311039
latent_cache = latent_cache.unsqueeze(1)
10321040
kv_a = self.kv_a_layernorm(kv_a.contiguous())
@@ -1668,6 +1676,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
16681676
num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
16691677
)
16701678

1679+
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
1680+
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
1681+
self.config.q_lora_rank is not None
1682+
)
1683+
cached_a_proj = {} if fuse_qkv_a_proj else None
1684+
16711685
params_dict = dict(self.named_parameters())
16721686
for name, loaded_weight in weights:
16731687
# TODO(HandH1998): Modify it when nextn is supported.
@@ -1723,11 +1737,50 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
17231737
if name.endswith(".bias") and name not in params_dict:
17241738
continue
17251739

1726-
param = params_dict[name]
1727-
weight_loader = getattr(
1728-
param, "weight_loader", default_weight_loader
1729-
)
1730-
weight_loader(param, loaded_weight)
1740+
if fuse_qkv_a_proj and (
1741+
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
1742+
):
1743+
cached_a_proj[name] = loaded_weight
1744+
q_a_proj_name = (
1745+
name
1746+
if "q_a_proj" in name
1747+
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
1748+
)
1749+
kv_a_proj_name = (
1750+
name
1751+
if "kv_a_proj_with_mqa" in name
1752+
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
1753+
)
1754+
1755+
# When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight into the parameter
1756+
if (
1757+
q_a_proj_name in cached_a_proj
1758+
and kv_a_proj_name in cached_a_proj
1759+
):
1760+
1761+
q_a_proj_weight = cached_a_proj[q_a_proj_name]
1762+
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
1763+
fused_weight = torch.cat(
1764+
[q_a_proj_weight, kv_a_proj_weight], dim=0
1765+
)
1766+
1767+
param_name = name.replace(
1768+
"q_a_proj", "fused_qkv_a_proj_with_mqa"
1769+
)
1770+
param = params_dict[param_name]
1771+
1772+
weight_loader = getattr(
1773+
param, "weight_loader", default_weight_loader
1774+
)
1775+
weight_loader(param, fused_weight)
1776+
cached_a_proj.pop(q_a_proj_name)
1777+
cached_a_proj.pop(kv_a_proj_name)
1778+
else:
1779+
param = params_dict[name]
1780+
weight_loader = getattr(
1781+
param, "weight_loader", default_weight_loader
1782+
)
1783+
weight_loader(param, loaded_weight)
17311784

17321785
self.post_load_weights()
17331786

0 commit comments

Comments
 (0)