@@ -57,6 +57,8 @@ __device__ void moe_fused_gate_impl(
57
57
int64_t num_rows,
58
58
int64_t topk_group,
59
59
int64_t topk,
60
+ int64_t n_share_experts_fusion,
61
+ double routed_scaling_factor,
60
62
Params params) {
61
63
int tidx = threadIdx .x ;
62
64
int64_t thread_row =
@@ -65,6 +67,9 @@ __device__ void moe_fused_gate_impl(
65
67
return ;
66
68
}
67
69
70
+ // Calculate topk_excluding_share_expert_fusion from topk
71
+ int64_t topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 1 : 0 );
72
+
68
73
// Cast pointers to type T:
69
74
auto * input_ptr = reinterpret_cast <T*>(input);
70
75
auto * bias_ptr = reinterpret_cast <T*>(bias);
@@ -163,7 +168,7 @@ __device__ void moe_fused_gate_impl(
163
168
164
169
// //////////////////// Topk //////////////////////
165
170
float output_sum = 0 .0f ;
166
- for (int k_idx = 0 ; k_idx < topk ; ++k_idx) {
171
+ for (int k_idx = 0 ; k_idx < topk_excluding_share_expert_fusion ; ++k_idx) {
167
172
// local argmax
168
173
T max_val = bias_chunk[0 ];
169
174
int expert = first_elt_read_by_thread;
@@ -181,7 +186,7 @@ __device__ void moe_fused_gate_impl(
181
186
max_val = static_cast <T>(-FLT_MAX);
182
187
}
183
188
184
- // argmax reduce
189
+ // argmax reduce
185
190
#pragma unroll
186
191
for (int mask = params.THREADS_PER_ROW / 2 ; mask > 0 ; mask /= 2 ) {
187
192
T other_max =
@@ -195,36 +200,46 @@ __device__ void moe_fused_gate_impl(
195
200
}
196
201
}
197
202
198
- if (k_idx < topk) {
199
- int thread_to_clear_in_group = expert / params.VPT ;
200
- int64_t idx = topk * thread_row + k_idx;
203
+ int thread_to_clear_in_group = expert / params.VPT ;
204
+ int64_t idx = topk * thread_row + k_idx;
201
205
202
- if (thread_group_idx == thread_to_clear_in_group) {
203
- int expert_to_clear_in_thread = expert % params.VPT ;
206
+ if (thread_group_idx == thread_to_clear_in_group) {
207
+ int expert_to_clear_in_thread = expert % params.VPT ;
204
208
205
- // clear the max value in the thread
206
- bias_chunk[expert_to_clear_in_thread] = static_cast <T>(-FLT_MAX);
209
+ // clear the max value in the thread
210
+ bias_chunk[expert_to_clear_in_thread] = static_cast <T>(-FLT_MAX);
207
211
208
- // store output
209
- output_ptr[idx] = static_cast <float >(row_chunk[expert_to_clear_in_thread]);
210
- indices_ptr[idx] = static_cast <int32_t >(expert);
211
- }
212
+ // store output
213
+ output_ptr[idx] = static_cast <float >(row_chunk[expert_to_clear_in_thread]);
214
+ indices_ptr[idx] = static_cast <int32_t >(expert);
215
+ }
212
216
213
- // accumulate sum
214
- if (thread_group_idx == 0 ) {
215
- output_sum += output_ptr[idx];
216
- }
217
+ // accumulate sum for all elements
218
+ if (thread_group_idx == 0 ) {
219
+ output_sum += output_ptr[idx];
217
220
}
218
221
219
222
__syncthreads ();
220
223
}
221
224
225
+ if (thread_group_idx == 0 && n_share_experts_fusion > 0 ) {
226
+ int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
227
+
228
+ // Use round-robin to select expert
229
+ int64_t expert_offset = thread_row % n_share_experts_fusion;
230
+ indices_ptr[last_idx] = static_cast <int32_t >(params.NUM_EXPERTS + expert_offset);
231
+
232
+ // Set the weight to the sum of all weights divided by routed_scaling_factor
233
+ output_ptr[last_idx] = output_sum / routed_scaling_factor;
234
+ }
235
+ __syncthreads ();
236
+
222
237
// //////////////////// Rescale Output //////////////////////
223
238
if (thread_group_idx == 0 ) {
224
239
#pragma unroll
225
240
for (int ii = 0 ; ii < topk; ++ii) {
226
241
int64_t const idx = topk * thread_row + ii;
227
- output_ptr[idx] = static_cast < float >( static_cast <T>( output_ptr[idx]) / static_cast <T>( output_sum)) ;
242
+ output_ptr[idx] = output_ptr[idx] / output_sum;
228
243
}
229
244
}
230
245
}
@@ -257,9 +272,21 @@ __global__ void moe_fused_gate_kernel(
257
272
int32_t * indices_ptr,
258
273
int64_t num_rows,
259
274
int64_t topk_group,
260
- int64_t topk) {
275
+ int64_t topk,
276
+ int64_t n_share_experts_fusion,
277
+ double routed_scaling_factor) {
261
278
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
262
- moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
279
+ moe_fused_gate_impl<T>(
280
+ input,
281
+ bias,
282
+ output_ptr,
283
+ indices_ptr,
284
+ num_rows,
285
+ topk_group,
286
+ topk,
287
+ n_share_experts_fusion,
288
+ routed_scaling_factor,
289
+ params);
263
290
}
264
291
265
292
// Macro to compute compile-time constants and launch the kernel.
@@ -277,7 +304,9 @@ __global__ void moe_fused_gate_kernel(
277
304
indices.data_ptr <int32_t >(), \
278
305
num_rows, \
279
306
topk_group, \
280
- topk); \
307
+ topk, \
308
+ n_share_experts_fusion, \
309
+ routed_scaling_factor); \
281
310
dispatched = true ; \
282
311
} while (0 )
283
312
@@ -303,7 +332,9 @@ __global__ void moe_fused_gate_kernel_dynamic(
303
332
int64_t num_experts,
304
333
int64_t num_expert_group,
305
334
int64_t topk_group,
306
- int64_t topk) {
335
+ int64_t topk,
336
+ int64_t n_share_experts_fusion,
337
+ double routed_scaling_factor) {
307
338
KernelParamsDynamic params;
308
339
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
309
340
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
@@ -312,14 +343,30 @@ __global__ void moe_fused_gate_kernel_dynamic(
312
343
params.ROWS_PER_WARP = std::max<int64_t >(1 , WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32
313
344
params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP ;
314
345
315
- moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
346
+ moe_fused_gate_impl<T>(
347
+ input,
348
+ bias,
349
+ output_ptr,
350
+ indices_ptr,
351
+ num_rows,
352
+ topk_group,
353
+ topk,
354
+ n_share_experts_fusion,
355
+ routed_scaling_factor,
356
+ params);
316
357
}
317
358
318
359
// ------------------------------------------------------------------------------
319
360
// Host Launcher Function
320
361
// ------------------------------------------------------------------------------
321
- std::vector<at::Tensor>
322
- moe_fused_gate (at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk) {
362
+ std::vector<at::Tensor> moe_fused_gate (
363
+ at::Tensor& input,
364
+ at::Tensor& bias,
365
+ int64_t num_expert_group,
366
+ int64_t topk_group,
367
+ int64_t topk,
368
+ int64_t n_share_experts_fusion,
369
+ double routed_scaling_factor) {
323
370
int64_t num_rows = input.size (0 );
324
371
int32_t num_experts = input.size (1 );
325
372
auto options = torch::TensorOptions ().dtype (torch::kFloat32 ).device (torch::kCUDA );
@@ -416,7 +463,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
416
463
num_experts,
417
464
num_expert_group,
418
465
topk_group,
419
- topk);
466
+ topk,
467
+ n_share_experts_fusion,
468
+ routed_scaling_factor);
420
469
} else if (input.scalar_type () == at::kHalf ) {
421
470
moe_fused_gate_kernel_dynamic<float16_t ><<<num_blocks, block_dim, 0 , stream>>> (
422
471
input.data_ptr (),
@@ -427,7 +476,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
427
476
num_experts,
428
477
num_expert_group,
429
478
topk_group,
430
- topk);
479
+ topk,
480
+ n_share_experts_fusion,
481
+ routed_scaling_factor);
431
482
} else if (input.scalar_type () == at::kFloat ) {
432
483
moe_fused_gate_kernel_dynamic<float32_t ><<<num_blocks, block_dim, 0 , stream>>> (
433
484
input.data_ptr (),
@@ -438,7 +489,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
438
489
num_experts,
439
490
num_expert_group,
440
491
topk_group,
441
- topk);
492
+ topk,
493
+ n_share_experts_fusion,
494
+ routed_scaling_factor);
442
495
} else {
443
496
TORCH_CHECK (false , " Unsupported data type for moe_fused_gate" );
444
497
}
0 commit comments