Skip to content

Commit 50f1b6d

Browse files
authored
Remove copy after bmm (sgl-project#7441)
1 parent 5962e70 commit 50f1b6d

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

python/sglang/srt/models/deepseek_v2.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1084,13 +1084,16 @@ def forward_absorb_core(
                     masked_m,
                     expected_m,
                 )
-                attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+                attn_bmm_output = (
+                    attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
+                )
             elif _is_hip:
                 # TODO(haishaw): add bmm_fp8 to ROCm
                 attn_bmm_output = torch.bmm(
                     attn_output.to(torch.bfloat16).transpose(0, 1),
                     self.w_vc.to(torch.bfloat16) * self.w_scale,
                 )
+                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
             elif self.w_vc.dtype == torch.float8_e4m3fn:
                 attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                     attn_output.transpose(0, 1),
@@ -1103,10 +1106,21 @@ def forward_absorb_core(
                 self.w_scale,
                 torch.bfloat16,
             )
+            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         else:
-            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
-        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
-        output, _ = self.o_proj(attn_output)
+            attn_bmm_output = torch.empty(
+                (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
+                dtype=attn_output.dtype,
+                device=attn_output.device,
+            )
+            torch.bmm(
+                attn_output.transpose(0, 1),
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+        output, _ = self.o_proj(attn_bmm_output)

         return output

0 commit comments

Comments (0)