@@ -1102,29 +1102,17 @@ def forward_normal(
1102
1102
else :
1103
1103
hidden_states , residual = self .input_layernorm (hidden_states , residual )
1104
1104
1105
+ assert not (
1106
+ self .attn_tp_size != 1 and self .input_is_scattered
1107
+ ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
1108
+
1105
1109
# Self Attention
1106
1110
hidden_states = self .self_attn (
1107
1111
positions = positions ,
1108
1112
hidden_states = hidden_states ,
1109
1113
forward_batch = forward_batch ,
1110
1114
)
1111
1115
1112
- if self .attn_tp_size != 1 and self .input_is_scattered :
1113
- hidden_states , local_hidden_states = (
1114
- forward_batch .gathered_buffer [: forward_batch .input_ids .shape [0 ]],
1115
- hidden_states ,
1116
- )
1117
- tp_all_gather (
1118
- list (hidden_states .tensor_split (self .attn_tp_size )), local_hidden_states
1119
- )
1120
- residual , local_residual = (
1121
- forward_batch .gathered_buffer [: forward_batch .input_ids .shape [0 ]],
1122
- residual ,
1123
- )
1124
- tp_all_gather (
1125
- list (residual .tensor_split (self .attn_tp_size )), local_residual
1126
- )
1127
-
1128
1116
# Gather
1129
1117
if get_tensor_model_parallel_world_size () > 1 :
1130
1118
# all gather and all reduce
@@ -1223,26 +1211,20 @@ def forward_deepep(
1223
1211
hidden_states = self .mlp (hidden_states , forward_batch .forward_mode )
1224
1212
1225
1213
if self .is_last_layer and self .attn_tp_size != 1 :
1214
+ hidden_states += residual
1215
+ residual = None
1226
1216
hidden_states , local_hidden_states = (
1227
1217
forward_batch .gathered_buffer [: forward_batch .input_ids .shape [0 ]],
1228
1218
hidden_states ,
1229
1219
)
1230
1220
tp_all_gather (
1231
1221
list (hidden_states .tensor_split (self .attn_tp_size )), local_hidden_states
1232
1222
)
1233
- residual , local_residual = (
1234
- forward_batch .gathered_buffer [: forward_batch .input_ids .shape [0 ]],
1235
- residual ,
1236
- )
1237
- tp_all_gather (
1238
- list (residual .tensor_split (self .attn_tp_size )), local_residual
1239
- )
1240
1223
1241
1224
return hidden_states , residual
1242
1225
1243
1226
1244
1227
class DeepseekV2Model (nn .Module ):
1245
-
1246
1228
fall_back_to_pt_during_load = False
1247
1229
1248
1230
def __init__ (
@@ -1296,7 +1278,10 @@ def forward(
1296
1278
positions , hidden_states , forward_batch , residual
1297
1279
)
1298
1280
if not forward_batch .forward_mode .is_idle ():
1299
- hidden_states , _ = self .norm (hidden_states , residual )
1281
+ if residual is None :
1282
+ hidden_states = self .norm (hidden_states )
1283
+ else :
1284
+ hidden_states , _ = self .norm (hidden_states , residual )
1300
1285
return hidden_states
1301
1286
1302
1287
0 commit comments