@@ -1084,13 +1084,16 @@ def forward_absorb_core(
1084
1084
masked_m ,
1085
1085
expected_m ,
1086
1086
)
1087
- attn_bmm_output = attn_bmm_output [:, :expected_m , :]
1087
+ attn_bmm_output = (
1088
+ attn_bmm_output [:, :expected_m , :].transpose (0 , 1 ).flatten (1 , 2 )
1089
+ )
1088
1090
elif _is_hip :
1089
1091
# TODO(haishaw): add bmm_fp8 to ROCm
1090
1092
attn_bmm_output = torch .bmm (
1091
1093
attn_output .to (torch .bfloat16 ).transpose (0 , 1 ),
1092
1094
self .w_vc .to (torch .bfloat16 ) * self .w_scale ,
1093
1095
)
1096
+ attn_bmm_output = attn_bmm_output .transpose (0 , 1 ).flatten (1 , 2 )
1094
1097
elif self .w_vc .dtype == torch .float8_e4m3fn :
1095
1098
attn_output_val , attn_output_scale = per_tensor_quant_mla_fp8 (
1096
1099
attn_output .transpose (0 , 1 ),
@@ -1103,10 +1106,21 @@ def forward_absorb_core(
1103
1106
self .w_scale ,
1104
1107
torch .bfloat16 ,
1105
1108
)
1109
+ attn_bmm_output = attn_bmm_output .transpose (0 , 1 ).flatten (1 , 2 )
1106
1110
else :
1107
- attn_bmm_output = torch .bmm (attn_output .transpose (0 , 1 ), self .w_vc )
1108
- attn_output = attn_bmm_output .transpose (0 , 1 ).flatten (1 , 2 )
1109
- output , _ = self .o_proj (attn_output )
1111
+ attn_bmm_output = torch .empty (
1112
+ (attn_output .shape [0 ], self .num_local_heads * self .v_head_dim ),
1113
+ dtype = attn_output .dtype ,
1114
+ device = attn_output .device ,
1115
+ )
1116
+ torch .bmm (
1117
+ attn_output .transpose (0 , 1 ),
1118
+ self .w_vc ,
1119
+ out = attn_bmm_output .view (
1120
+ - 1 , self .num_local_heads , self .v_head_dim
1121
+ ).transpose (0 , 1 ),
1122
+ )
1123
+ output , _ = self .o_proj (attn_bmm_output )
1110
1124
1111
1125
return output
1112
1126
0 commit comments