python/sglang/srt/layers/moe/fused_moe_triton: 1 file changed, 7 additions (+), 2 deletions (-)

@@ -750,9 +750,11 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
-        max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
+    sorted_ids.fill_(topk_ids.numel())
+
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -768,6 +770,9 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
+        cumsum_buffer = torch.empty(
+            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+        )
         token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
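For context, the change drops the init_sorted_ids_and_cumsum_buffer helper: sorted_ids is now allocated directly and filled with topk_ids.numel() (presumably as the padding value for slots no real token occupies), and cumsum_buffer is only allocated in the else branch that actually uses it. Below is a minimal standalone sketch of that allocation pattern with made-up toy values for topk_ids, num_experts, and block_size; the real function receives these from the MoE router and typically runs on the GPU.

import torch

# Toy inputs (assumptions for illustration only): 4 tokens routed to top-2 of 8 experts.
topk_ids = torch.randint(0, 8, (4, 2), dtype=torch.int32)
num_experts = 8
block_size = 16

# Worst case after padding: each expert may add up to (block_size - 1) filler slots.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)

# sorted_ids is allocated uninitialized, then filled with topk_ids.numel(),
# an out-of-range token index that marks padding slots.
sorted_ids = torch.empty(
    (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
)
sorted_ids.fill_(topk_ids.numel())

# cumsum_buffer mirrors the allocation the diff adds in the else branch,
# replacing the unconditional allocation done by the removed helper.
cumsum_buffer = torch.empty(
    (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
)

print(sorted_ids.shape, cumsum_buffer.shape)  # torch.Size([128]) torch.Size([9])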