Skip to content

Commit 98aa836

Browse files
authored
Overlap the gating function with shared experts in DeepSeek (sgl-project#7978)
1 parent 22bd857 commit 98aa836

File tree

1 file changed: 3 additions (+3), 3 deletions (-3)

python/sglang/srt/models/deepseek_v2.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -437,21 +437,21 @@ def forward(
437437
def forward_normal_dual_stream(
    self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
) -> torch.Tensor:
    """Run the MoE layer using two CUDA streams to overlap independent work.

    The shared-expert MLP executes on the current stream while the gating
    function and the routed experts run concurrently on ``self.alt_stream``.
    The streams are synchronized before the two partial results are summed.

    Args:
        hidden_states: Input activations; presumably (num_tokens, hidden_dim)
            — TODO confirm against caller.
        can_fuse_mlp_allreduce: When True, skip the tensor-parallel all-reduce
            here because a fused op downstream performs it.

    Returns:
        The combined routed-expert + shared-expert output, all-reduced across
        tensor-parallel ranks unless fused elsewhere.
    """
    current_stream = torch.cuda.current_stream()
    # The alternate stream must wait for all work already queued on the
    # current stream so it observes an up-to-date `hidden_states`.
    self.alt_stream.wait_stream(current_stream)
    shared_output = self._forward_shared_experts(hidden_states)
    with torch.cuda.stream(self.alt_stream):
        # Gating is launched on the alternate stream so it overlaps with the
        # shared-expert computation above.
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        # NOTE(review): on CUDA the scaling is presumably fused into the
        # experts kernel — only the non-CUDA path scales here; confirm.
        if not _is_cuda:
            final_hidden_states *= self.routed_scaling_factor
    # Re-join: the current stream waits for the routed-expert result before
    # the in-place add of the shared-expert output.
    current_stream.wait_stream(self.alt_stream)
    final_hidden_states += shared_output
    if self.tp_size > 1 and not can_fuse_mlp_allreduce:
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states

Commit comments (0)