@@ -443,12 +443,12 @@ def __init__(
443
443
444
444
# For tensor parallel attention
445
445
if self .q_lora_rank is not None :
446
- self .q_a_proj = ReplicatedLinear (
446
+ self .fused_qkv_a_proj_with_mqa = ReplicatedLinear (
447
447
self .hidden_size ,
448
- self .q_lora_rank ,
448
+ self .q_lora_rank + self . kv_lora_rank + self . qk_rope_head_dim ,
449
449
bias = False ,
450
450
quant_config = quant_config ,
451
- prefix = add_prefix ("q_a_proj " , prefix ),
451
+ prefix = add_prefix ("fused_qkv_a_proj_with_mqa " , prefix ),
452
452
)
453
453
self .q_a_layernorm = RMSNorm (self .q_lora_rank , eps = config .rms_norm_eps )
454
454
self .q_b_proj = ColumnParallelLinear (
@@ -470,6 +470,14 @@ def __init__(
470
470
tp_rank = attn_tp_rank ,
471
471
tp_size = attn_tp_size ,
472
472
)
473
+ self .kv_a_proj_with_mqa = ReplicatedLinear (
474
+ self .hidden_size ,
475
+ self .kv_lora_rank + self .qk_rope_head_dim ,
476
+ bias = False ,
477
+ quant_config = quant_config ,
478
+ prefix = add_prefix ("kv_a_proj_with_mqa" , prefix ),
479
+ )
480
+
473
481
self .kv_b_proj = ColumnParallelLinear (
474
482
self .kv_lora_rank ,
475
483
self .num_heads * (self .qk_nope_head_dim + self .v_head_dim ),
@@ -490,14 +498,6 @@ def __init__(
490
498
tp_rank = attn_tp_rank ,
491
499
tp_size = attn_tp_size ,
492
500
)
493
-
494
- self .kv_a_proj_with_mqa = ReplicatedLinear (
495
- self .hidden_size ,
496
- self .kv_lora_rank + self .qk_rope_head_dim ,
497
- bias = False ,
498
- quant_config = quant_config ,
499
- prefix = add_prefix ("kv_a_proj_with_mqa" , prefix ),
500
- )
501
501
self .kv_a_layernorm = RMSNorm (self .kv_lora_rank , eps = config .rms_norm_eps )
502
502
503
503
if rope_scaling :
@@ -656,15 +656,18 @@ def forward_normal(
656
656
forward_batch : ForwardBatch ,
657
657
) -> torch .Tensor :
658
658
if self .q_lora_rank is not None :
659
- q = self .q_a_proj (hidden_states )[0 ]
659
+ q , latent_cache = self .fused_qkv_a_proj_with_mqa (hidden_states )[0 ].split (
660
+ [self .q_lora_rank , self .kv_lora_rank + self .qk_rope_head_dim ], dim = - 1
661
+ )
660
662
q = self .q_a_layernorm (q )
661
663
q = self .q_b_proj (q )[0 ].view (- 1 , self .num_local_heads , self .qk_head_dim )
662
664
else :
663
665
q = self .q_proj (hidden_states )[0 ].view (
664
666
- 1 , self .num_local_heads , self .qk_head_dim
665
667
)
668
+ latent_cache = self .kv_a_proj_with_mqa (hidden_states )[0 ]
669
+
666
670
_ , q_pe = q .split ([self .qk_nope_head_dim , self .qk_rope_head_dim ], dim = - 1 )
667
- latent_cache = self .kv_a_proj_with_mqa (hidden_states )[0 ]
668
671
kv_a , _ = latent_cache .split ([self .kv_lora_rank , self .qk_rope_head_dim ], dim = - 1 )
669
672
latent_cache = latent_cache .unsqueeze (1 )
670
673
kv_a = self .kv_a_layernorm (kv_a .contiguous ())
@@ -699,13 +702,16 @@ def forward_absorb(
699
702
zero_allocator : BumpAllocator ,
700
703
) -> torch .Tensor :
701
704
if self .q_lora_rank is not None :
702
- q = self .q_a_proj (hidden_states )[0 ]
705
+ q , latent_cache = self .fused_qkv_a_proj_with_mqa (hidden_states )[0 ].split (
706
+ [self .q_lora_rank , self .kv_lora_rank + self .qk_rope_head_dim ], dim = - 1
707
+ )
703
708
q = self .q_a_layernorm (q )
704
709
q = self .q_b_proj (q )[0 ].view (- 1 , self .num_local_heads , self .qk_head_dim )
705
710
else :
706
711
q = self .q_proj (hidden_states )[0 ].view (
707
712
- 1 , self .num_local_heads , self .qk_head_dim
708
713
)
714
+ latent_cache = self .kv_a_proj_with_mqa (hidden_states )[0 ]
709
715
q_nope , q_pe = q .split ([self .qk_nope_head_dim , self .qk_rope_head_dim ], dim = - 1 )
710
716
711
717
if self .use_deep_gemm_bmm :
@@ -744,7 +750,6 @@ def forward_absorb(
744
750
745
751
q_nope_out = q_nope_out .transpose (0 , 1 )
746
752
747
- latent_cache = self .kv_a_proj_with_mqa (hidden_states )[0 ]
748
753
k_nope = latent_cache [..., : self .kv_lora_rank ]
749
754
k_nope = self .kv_a_layernorm (k_nope ).unsqueeze (1 )
750
755
k_pe = latent_cache [..., self .kv_lora_rank :].unsqueeze (1 )
@@ -819,13 +824,16 @@ def forward_absorb_fused_mla_rope(
819
824
q_len , self .num_local_heads , self .kv_lora_rank + self .qk_rope_head_dim
820
825
)
821
826
if self .q_lora_rank is not None :
822
- q = self .q_a_proj (hidden_states )[0 ]
827
+ q , latent_cache = self .fused_qkv_a_proj_with_mqa (hidden_states )[0 ].split (
828
+ [self .q_lora_rank , self .kv_lora_rank + self .qk_rope_head_dim ], dim = - 1
829
+ )
823
830
q = self .q_a_layernorm (q )
824
831
q = self .q_b_proj (q )[0 ].view (- 1 , self .num_local_heads , self .qk_head_dim )
825
832
else :
826
833
q = self .q_proj (hidden_states )[0 ].view (
827
834
- 1 , self .num_local_heads , self .qk_head_dim
828
835
)
836
+ latent_cache = self .kv_a_proj_with_mqa (hidden_states )[0 ]
829
837
q_nope , q_pe = q .split ([self .qk_nope_head_dim , self .qk_rope_head_dim ], dim = - 1 )
830
838
831
839
if self .w_kc .dtype == torch .float8_e4m3fnuz :
@@ -846,8 +854,6 @@ def forward_absorb_fused_mla_rope(
846
854
else :
847
855
q_nope_out = torch .bmm (q_nope .transpose (0 , 1 ), self .w_kc )
848
856
q_input [..., : self .kv_lora_rank ] = q_nope_out .transpose (0 , 1 )
849
-
850
- latent_cache = self .kv_a_proj_with_mqa (hidden_states )[0 ]
851
857
v_input = latent_cache [..., : self .kv_lora_rank ]
852
858
v_input = self .kv_a_layernorm (v_input .contiguous ()).unsqueeze (1 )
853
859
k_input = latent_cache .unsqueeze (1 )
@@ -1018,15 +1024,17 @@ def forward_normal_chunked_kv(
1018
1024
1019
1025
# First do normal mha forward to get output for extended part
1020
1026
if self .q_lora_rank is not None :
1021
- q = self .q_a_proj (hidden_states )[0 ]
1027
+ q , latent_cache = self .fused_qkv_a_proj_with_mqa (hidden_states )[0 ].split (
1028
+ [self .q_lora_rank , self .kv_lora_rank + self .qk_rope_head_dim ], dim = - 1
1029
+ )
1022
1030
q = self .q_a_layernorm (q )
1023
1031
q = self .q_b_proj (q )[0 ].view (- 1 , self .num_local_heads , self .qk_head_dim )
1024
1032
else :
1025
1033
q = self .q_proj (hidden_states )[0 ].view (
1026
1034
- 1 , self .num_local_heads , self .qk_head_dim
1027
1035
)
1036
+ latent_cache = self .kv_a_proj_with_mqa (hidden_states )[0 ]
1028
1037
_ , q_pe = q .split ([self .qk_nope_head_dim , self .qk_rope_head_dim ], dim = - 1 )
1029
- latent_cache = self .kv_a_proj_with_mqa (hidden_states )[0 ]
1030
1038
kv_a , _ = latent_cache .split ([self .kv_lora_rank , self .qk_rope_head_dim ], dim = - 1 )
1031
1039
latent_cache = latent_cache .unsqueeze (1 )
1032
1040
kv_a = self .kv_a_layernorm (kv_a .contiguous ())
@@ -1668,6 +1676,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
1668
1676
num_experts = self .config .n_routed_experts + self .n_share_experts_fusion ,
1669
1677
)
1670
1678
1679
+ # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
1680
+ fuse_qkv_a_proj = hasattr (self .config , "q_lora_rank" ) and (
1681
+ self .config .q_lora_rank is not None
1682
+ )
1683
+ cached_a_proj = {} if fuse_qkv_a_proj else None
1684
+
1671
1685
params_dict = dict (self .named_parameters ())
1672
1686
for name , loaded_weight in weights :
1673
1687
# TODO(HandH1998): Modify it when nextn is supported.
@@ -1723,11 +1737,50 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
1723
1737
if name .endswith (".bias" ) and name not in params_dict :
1724
1738
continue
1725
1739
1726
- param = params_dict [name ]
1727
- weight_loader = getattr (
1728
- param , "weight_loader" , default_weight_loader
1729
- )
1730
- weight_loader (param , loaded_weight )
1740
+ if fuse_qkv_a_proj and (
1741
+ "q_a_proj" in name or "kv_a_proj_with_mqa" in name
1742
+ ):
1743
+ cached_a_proj [name ] = loaded_weight
1744
+ q_a_proj_name = (
1745
+ name
1746
+ if "q_a_proj" in name
1747
+ else name .replace ("kv_a_proj_with_mqa" , "q_a_proj" )
1748
+ )
1749
+ kv_a_proj_name = (
1750
+ name
1751
+ if "kv_a_proj_with_mqa" in name
1752
+ else name .replace ("q_a_proj" , "kv_a_proj_with_mqa" )
1753
+ )
1754
+
1755
+ # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
1756
+ if (
1757
+ q_a_proj_name in cached_a_proj
1758
+ and kv_a_proj_name in cached_a_proj
1759
+ ):
1760
+
1761
+ q_a_proj_weight = cached_a_proj [q_a_proj_name ]
1762
+ kv_a_proj_weight = cached_a_proj [kv_a_proj_name ]
1763
+ fused_weight = torch .cat (
1764
+ [q_a_proj_weight , kv_a_proj_weight ], dim = 0
1765
+ )
1766
+
1767
+ param_name = name .replace (
1768
+ "q_a_proj" , "fused_qkv_a_proj_with_mqa"
1769
+ )
1770
+ param = params_dict [param_name ]
1771
+
1772
+ weight_loader = getattr (
1773
+ param , "weight_loader" , default_weight_loader
1774
+ )
1775
+ weight_loader (param , fused_weight )
1776
+ cached_a_proj .pop (q_a_proj_name )
1777
+ cached_a_proj .pop (kv_a_proj_name )
1778
+ else :
1779
+ param = params_dict [name ]
1780
+ weight_loader = getattr (
1781
+ param , "weight_loader" , default_weight_loader
1782
+ )
1783
+ weight_loader (param , loaded_weight )
1731
1784
1732
1785
self .post_load_weights ()
1733
1786
0 commit comments