Skip to content

Commit 8e09b37

Browse files
authored
Sgl kernel fused_moe_gate support n_shared_experts (#5440)
1 parent 53dcf38 commit 8e09b37

File tree

5 files changed

+140
-38
lines changed

5 files changed

+140
-38
lines changed

sgl-kernel/csrc/common_extension.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
146146
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
147147

148148
m.def(
149-
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
149+
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
150+
"n_share_experts_fusion, float routed_scaling_factor) -> "
150151
"(Tensor[])");
151152
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
152153

sgl-kernel/csrc/moe/moe_fused_gate.cu

Lines changed: 81 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ __device__ void moe_fused_gate_impl(
5757
int64_t num_rows,
5858
int64_t topk_group,
5959
int64_t topk,
60+
int64_t n_share_experts_fusion,
61+
double routed_scaling_factor,
6062
Params params) {
6163
int tidx = threadIdx.x;
6264
int64_t thread_row =
@@ -65,6 +67,9 @@ __device__ void moe_fused_gate_impl(
6567
return;
6668
}
6769

70+
// Number of routed experts to select: one of the topk slots is reserved for the shared expert when n_share_experts_fusion > 0
71+
int64_t topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 1 : 0);
72+
6873
// Cast pointers to type T:
6974
auto* input_ptr = reinterpret_cast<T*>(input);
7075
auto* bias_ptr = reinterpret_cast<T*>(bias);
@@ -163,7 +168,7 @@ __device__ void moe_fused_gate_impl(
163168

164169
////////////////////// Topk //////////////////////
165170
float output_sum = 0.0f;
166-
for (int k_idx = 0; k_idx < topk; ++k_idx) {
171+
for (int k_idx = 0; k_idx < topk_excluding_share_expert_fusion; ++k_idx) {
167172
// local argmax
168173
T max_val = bias_chunk[0];
169174
int expert = first_elt_read_by_thread;
@@ -181,7 +186,7 @@ __device__ void moe_fused_gate_impl(
181186
max_val = static_cast<T>(-FLT_MAX);
182187
}
183188

184-
// argmax reduce
189+
// argmax reduce
185190
#pragma unroll
186191
for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
187192
T other_max =
@@ -195,36 +200,46 @@ __device__ void moe_fused_gate_impl(
195200
}
196201
}
197202

198-
if (k_idx < topk) {
199-
int thread_to_clear_in_group = expert / params.VPT;
200-
int64_t idx = topk * thread_row + k_idx;
203+
int thread_to_clear_in_group = expert / params.VPT;
204+
int64_t idx = topk * thread_row + k_idx;
201205

202-
if (thread_group_idx == thread_to_clear_in_group) {
203-
int expert_to_clear_in_thread = expert % params.VPT;
206+
if (thread_group_idx == thread_to_clear_in_group) {
207+
int expert_to_clear_in_thread = expert % params.VPT;
204208

205-
// clear the max value in the thread
206-
bias_chunk[expert_to_clear_in_thread] = static_cast<T>(-FLT_MAX);
209+
// clear the max value in the thread
210+
bias_chunk[expert_to_clear_in_thread] = static_cast<T>(-FLT_MAX);
207211

208-
// store output
209-
output_ptr[idx] = static_cast<float>(row_chunk[expert_to_clear_in_thread]);
210-
indices_ptr[idx] = static_cast<int32_t>(expert);
211-
}
212+
// store output
213+
output_ptr[idx] = static_cast<float>(row_chunk[expert_to_clear_in_thread]);
214+
indices_ptr[idx] = static_cast<int32_t>(expert);
215+
}
212216

213-
// accumulate sum
214-
if (thread_group_idx == 0) {
215-
output_sum += output_ptr[idx];
216-
}
217+
// accumulate the sum of the selected routed expert weights (used for the shared-expert weight and the final rescale)
218+
if (thread_group_idx == 0) {
219+
output_sum += output_ptr[idx];
217220
}
218221

219222
__syncthreads();
220223
}
221224

225+
if (thread_group_idx == 0 && n_share_experts_fusion > 0) {
226+
int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
227+
228+
// Use round-robin to select expert
229+
int64_t expert_offset = thread_row % n_share_experts_fusion;
230+
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
231+
232+
// Set the weight to the sum of all weights divided by routed_scaling_factor
233+
output_ptr[last_idx] = output_sum / routed_scaling_factor;
234+
}
235+
__syncthreads();
236+
222237
////////////////////// Rescale Output //////////////////////
223238
if (thread_group_idx == 0) {
224239
#pragma unroll
225240
for (int ii = 0; ii < topk; ++ii) {
226241
int64_t const idx = topk * thread_row + ii;
227-
output_ptr[idx] = static_cast<float>(static_cast<T>(output_ptr[idx]) / static_cast<T>(output_sum));
242+
output_ptr[idx] = output_ptr[idx] / output_sum;
228243
}
229244
}
230245
}
@@ -257,9 +272,21 @@ __global__ void moe_fused_gate_kernel(
257272
int32_t* indices_ptr,
258273
int64_t num_rows,
259274
int64_t topk_group,
260-
int64_t topk) {
275+
int64_t topk,
276+
int64_t n_share_experts_fusion,
277+
double routed_scaling_factor) {
261278
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
262-
moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
279+
moe_fused_gate_impl<T>(
280+
input,
281+
bias,
282+
output_ptr,
283+
indices_ptr,
284+
num_rows,
285+
topk_group,
286+
topk,
287+
n_share_experts_fusion,
288+
routed_scaling_factor,
289+
params);
263290
}
264291

265292
// Macro to compute compile-time constants and launch the kernel.
@@ -277,7 +304,9 @@ __global__ void moe_fused_gate_kernel(
277304
indices.data_ptr<int32_t>(), \
278305
num_rows, \
279306
topk_group, \
280-
topk); \
307+
topk, \
308+
n_share_experts_fusion, \
309+
routed_scaling_factor); \
281310
dispatched = true; \
282311
} while (0)
283312

@@ -303,7 +332,9 @@ __global__ void moe_fused_gate_kernel_dynamic(
303332
int64_t num_experts,
304333
int64_t num_expert_group,
305334
int64_t topk_group,
306-
int64_t topk) {
335+
int64_t topk,
336+
int64_t n_share_experts_fusion,
337+
double routed_scaling_factor) {
307338
KernelParamsDynamic params;
308339
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
309340
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -312,14 +343,30 @@ __global__ void moe_fused_gate_kernel_dynamic(
312343
params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32
313344
params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;
314345

315-
moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
346+
moe_fused_gate_impl<T>(
347+
input,
348+
bias,
349+
output_ptr,
350+
indices_ptr,
351+
num_rows,
352+
topk_group,
353+
topk,
354+
n_share_experts_fusion,
355+
routed_scaling_factor,
356+
params);
316357
}
317358

318359
//------------------------------------------------------------------------------
319360
// Host Launcher Function
320361
//------------------------------------------------------------------------------
321-
std::vector<at::Tensor>
322-
moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk) {
362+
std::vector<at::Tensor> moe_fused_gate(
363+
at::Tensor& input,
364+
at::Tensor& bias,
365+
int64_t num_expert_group,
366+
int64_t topk_group,
367+
int64_t topk,
368+
int64_t n_share_experts_fusion,
369+
double routed_scaling_factor) {
323370
int64_t num_rows = input.size(0);
324371
int32_t num_experts = input.size(1);
325372
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
@@ -416,7 +463,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
416463
num_experts,
417464
num_expert_group,
418465
topk_group,
419-
topk);
466+
topk,
467+
n_share_experts_fusion,
468+
routed_scaling_factor);
420469
} else if (input.scalar_type() == at::kHalf) {
421470
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
422471
input.data_ptr(),
@@ -427,7 +476,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
427476
num_experts,
428477
num_expert_group,
429478
topk_group,
430-
topk);
479+
topk,
480+
n_share_experts_fusion,
481+
routed_scaling_factor);
431482
} else if (input.scalar_type() == at::kFloat) {
432483
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
433484
input.data_ptr(),
@@ -438,7 +489,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
438489
num_experts,
439490
num_expert_group,
440491
topk_group,
441-
topk);
492+
topk,
493+
n_share_experts_fusion,
494+
routed_scaling_factor);
442495
} else {
443496
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
444497
}

sgl-kernel/include/sgl_kernel_ops.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,14 @@ void topk_softmax(
200200
torch::Tensor& token_expert_indices,
201201
torch::Tensor& gating_output);
202202

203-
std::vector<at::Tensor>
204-
moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk);
203+
std::vector<at::Tensor> moe_fused_gate(
204+
at::Tensor& input,
205+
at::Tensor& bias,
206+
int64_t num_expert_group,
207+
int64_t topk_group,
208+
int64_t topk,
209+
int64_t n_share_experts_fusion,
210+
double routed_scaling_factor);
205211

206212
/*
207213
* From csrc/speculative

sgl-kernel/python/sgl_kernel/moe.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,29 @@ def topk_softmax(
3434
)
3535

3636

37-
def moe_fused_gate(input_tensor, bias, num_expert_group, topk_group, topk):
37+
def moe_fused_gate(
38+
input_tensor,
39+
bias,
40+
num_expert_group,
41+
topk_group,
42+
topk,
43+
n_share_experts_fusion=0,
44+
routed_scaling_factor=0,
45+
):
3846
# This fused kernel function selects the topk experts in a hierarchical 2-layer fashion:
3947
# it splits the experts into num_expert_group groups and uses the sum of the top2 expert weights in each group
4048
# as the group weight to select expert groups, and then selects the topk experts within the selected groups.
4149
# The number of experts is determined by the input tensor shape; currently only a power-of-2 #experts is supported,
4250
# and #experts must be divisible by num_expert_group. #experts / num_expert_group <= 32 is required for now.
4351
# For unsupported cases, we suggest using the biased_grouped_topk function in sglang.srt.layers.moe.topk.
52+
# n_share_experts_fusion: if > 0, the last of the topk slots is filled with a shared expert chosen round-robin per row (index = num_experts + row % n_share_experts_fusion)
53+
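# routed_scaling_factor: when n_share_experts_fusion > 0, the shared expert's weight is set to the sum of the routed weights divided by this factor (before the final normalization), so it must be non-zero in that case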
4454
return torch.ops.sgl_kernel.moe_fused_gate.default(
45-
input_tensor, bias, num_expert_group, topk_group, topk
55+
input_tensor,
56+
bias,
57+
num_expert_group,
58+
topk_group,
59+
topk,
60+
n_share_experts_fusion,
61+
routed_scaling_factor,
4662
)

sgl-kernel/tests/test_moe_fused_gate.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,24 @@
1919
(512, 16, 8, 16),
2020
],
2121
)
22-
def test_moe_fused_gate_combined(seq_length, dtype, params):
22+
@pytest.mark.parametrize("n_share_experts_fusion", [0, 1, 8, 16])
23+
def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusion):
2324
num_experts, num_expert_group, topk_group, topk = params
2425

2526
torch.manual_seed(seq_length)
2627
tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
2728
scores = tensor.clone()
2829
bias = torch.rand(num_experts).to(dtype).cuda()
30+
topk = topk + min(1, n_share_experts_fusion)
2931

3032
output, indices = moe_fused_gate(
3133
tensor,
3234
bias,
3335
num_expert_group=num_expert_group,
3436
topk_group=topk_group,
3537
topk=topk,
38+
n_share_experts_fusion=n_share_experts_fusion,
39+
routed_scaling_factor=2.5,
3640
)
3741
ref_output, ref_indices = biased_grouped_topk(
3842
scores,
@@ -43,8 +47,30 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
4347
num_expert_group=num_expert_group,
4448
topk_group=topk_group,
4549
compiled=False,
50+
n_share_experts_fusion=n_share_experts_fusion,
4651
)
4752

53+
# When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension
54+
if n_share_experts_fusion > 0:
55+
original_indices = indices.clone()
56+
original_ref_indices = ref_indices.clone()
57+
58+
indices = indices[:, :-1]
59+
ref_indices = ref_indices[:, :-1]
60+
61+
valid_min = num_experts
62+
valid_max = num_experts + n_share_experts_fusion
63+
shared_indices = original_indices[:, -1]
64+
shared_ref_indices = original_ref_indices[:, -1]
65+
if shared_indices is not None:
66+
assert torch.all(
67+
(shared_indices >= valid_min) & (shared_indices < valid_max)
68+
), f"Shared expert indices out of range: found values outside [{valid_min}, {valid_max})"
69+
if shared_ref_indices is not None:
70+
assert torch.all(
71+
(shared_ref_indices >= valid_min) & (shared_ref_indices < valid_max)
72+
), f"Shared expert reference indices out of range: found values outside [{valid_min}, {valid_max})"
73+
4874
idx_check = torch.allclose(
4975
ref_indices.sort()[0].to(torch.int32),
5076
indices.sort()[0].to(torch.int32),
@@ -54,17 +80,17 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
5480
output_check = torch.allclose(
5581
ref_output.sort()[0].to(torch.float32),
5682
output.sort()[0].to(torch.float32),
57-
rtol=1e-04,
58-
atol=1e-05,
83+
rtol=1e-02,
84+
atol=1e-03,
5985
)
6086

6187
assert idx_check, (
6288
f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
63-
f"params {params}"
89+
f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
6490
)
6591
assert output_check, (
6692
f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
67-
f"params {params}"
93+
f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
6894
)
6995

7096

0 commit comments

Comments
 (0)