Skip to content

Commit a303325

Browse files
fzyzcjy and ch-wan authored
Fix DeepSeek bug causing 2.2% MMLU drop when TP!=DP (#4883)
Co-authored-by: ch-wan <cwan39@gatech.edu>
1 parent 42873ea commit a303325

File tree

1 file changed

+10
-25
lines changed

1 file changed

+10
-25
lines changed

python/sglang/srt/models/deepseek_v2.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,29 +1102,17 @@ def forward_normal(
11021102
else:
11031103
hidden_states, residual = self.input_layernorm(hidden_states, residual)
11041104

1105+
assert not (
1106+
self.attn_tp_size != 1 and self.input_is_scattered
1107+
), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
1108+
11051109
# Self Attention
11061110
hidden_states = self.self_attn(
11071111
positions=positions,
11081112
hidden_states=hidden_states,
11091113
forward_batch=forward_batch,
11101114
)
11111115

1112-
if self.attn_tp_size != 1 and self.input_is_scattered:
1113-
hidden_states, local_hidden_states = (
1114-
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
1115-
hidden_states,
1116-
)
1117-
tp_all_gather(
1118-
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
1119-
)
1120-
residual, local_residual = (
1121-
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
1122-
residual,
1123-
)
1124-
tp_all_gather(
1125-
list(residual.tensor_split(self.attn_tp_size)), local_residual
1126-
)
1127-
11281116
# Gather
11291117
if get_tensor_model_parallel_world_size() > 1:
11301118
# all gather and all reduce
@@ -1223,26 +1211,20 @@ def forward_deepep(
12231211
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
12241212

12251213
if self.is_last_layer and self.attn_tp_size != 1:
1214+
hidden_states += residual
1215+
residual = None
12261216
hidden_states, local_hidden_states = (
12271217
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
12281218
hidden_states,
12291219
)
12301220
tp_all_gather(
12311221
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
12321222
)
1233-
residual, local_residual = (
1234-
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
1235-
residual,
1236-
)
1237-
tp_all_gather(
1238-
list(residual.tensor_split(self.attn_tp_size)), local_residual
1239-
)
12401223

12411224
return hidden_states, residual
12421225

12431226

12441227
class DeepseekV2Model(nn.Module):
1245-
12461228
fall_back_to_pt_during_load = False
12471229

12481230
def __init__(
@@ -1296,7 +1278,10 @@ def forward(
12961278
positions, hidden_states, forward_batch, residual
12971279
)
12981280
if not forward_batch.forward_mode.is_idle():
1299-
hidden_states, _ = self.norm(hidden_states, residual)
1281+
if residual is None:
1282+
hidden_states = self.norm(hidden_states)
1283+
else:
1284+
hidden_states, _ = self.norm(hidden_states, residual)
13001285
return hidden_states
13011286

13021287

0 commit comments

Comments (0)