1 file changed: +3 -3 lines changed
@@ -437,21 +437,21 @@ def forward(
     def forward_normal_dual_stream(
         self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
     ) -> torch.Tensor:
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)
 
         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
         shared_output = self._forward_shared_experts(hidden_states)
 
         with torch.cuda.stream(self.alt_stream):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
             final_hidden_states = self.experts(
                 hidden_states=hidden_states, router_logits=router_logits
             )
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
-        final_hidden_states = final_hidden_states + shared_output
+        final_hidden_states += shared_output
         if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
You can’t perform that action at this time.
0 commit comments