Commit 7151194

Remove cumsum_buffer initialization (sgl-project#7439)
1 parent 2ed68d7 commit 7151194

File tree: 1 file changed (+7, −2 lines)


python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

Lines changed: 7 additions & 2 deletions
@@ -750,9 +750,11 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
-        max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
+    sorted_ids.fill_(topk_ids.numel())
+
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -768,6 +770,9 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
+        cumsum_buffer = torch.empty(
+            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+        )
         token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
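
For context, the allocation pattern this commit switches to can be sketched as a standalone snippet. This is a minimal illustration, not the project's code: the helper name allocate_moe_align_buffers is hypothetical, the plain-Python ceiling division stands in for the triton.cdiv call used in the file, and the comment about sentinel semantics is an assumption based on the fill value in the diff above.

import torch

def allocate_moe_align_buffers(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    """Hypothetical helper mirroring the buffer setup in the diff above."""
    # Worst case after padding: each expert's slice is rounded up to a multiple of block_size.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)

    # Uninitialized allocation, then fill with numel(); slots still holding this
    # value after alignment point past the last real token (assumed padding sentinel).
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    sorted_ids.fill_(topk_ids.numel())

    # Ceiling division; the file itself uses triton.cdiv here.
    max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    return sorted_ids, expert_ids

# Example usage (CPU tensors for illustration):
# topk_ids = torch.randint(0, 8, (16, 2), dtype=torch.int32)
# sorted_ids, expert_ids = allocate_moe_align_buffers(topk_ids, num_experts=8, block_size=64)

After this change, cumsum_buffer (shape (num_experts + 1,), int32) is created only inside the else branch shown in the second hunk, so the other path presumably no longer pays for an allocation it does not use.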
