Skip to content

Commit d4d53e9

Browse files
committed
Add MixedMoeGemmRunner struct to support MoE GEMM with mixed input/output data types.
1 parent 1c9b26f commit d4d53e9

11 files changed

+1723
-694
lines changed

custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -784,13 +784,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
784784
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
785785
cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2};
786786

787-
// static_assert(platform::is_same<ElementScale, cutlass::float_e4m3_t>::value,
788-
// "ElementScale must be float_e4m3_t");
787+
// static_assert(platform::is_same<ElementScale, cutlass::float_e4m3_t>::value,
788+
// "ElementScale must be float_e4m3_t");
789789

790790
using ElementSuperScale = typename Mma::QuantParamsAccessor::ElementSuperScale;
791791

792-
static_assert(platform::is_same<ElementSuperScale, cutlass::bfloat16_t>::value,
793-
"ElementSuperScale must be bfloat16_t");
792+
// static_assert(platform::is_same<ElementSuperScale, cutlass::bfloat16_t>::value,
793+
// "ElementSuperScale must be bfloat16_t");
794794

795795
// TODO(baoqiwen): reinterpret_cast
796796
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n;

custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424

2525
namespace phi {
2626

27-
template <typename InType,
28-
typename OutType,
27+
template <typename T, /*The type used for activations/scales/compute*/
2928
typename WeightQuantTraits /* The quant traits for the MoE weights */>
3029
class MoeGemmRunner {
3130
public:
@@ -34,11 +33,11 @@ class MoeGemmRunner {
3433

3534
MoeGemmRunner();
3635

37-
void moe_gemm_bias_act(const InType* A,
36+
void moe_gemm_bias_act(const T* A,
3837
const WeightType* B,
39-
const OutType* weight_scales,
40-
const OutType* biases,
41-
OutType* C,
38+
const T* weight_scales,
39+
const T* biases,
40+
T* C,
4241
int64_t* total_rows_before_expert,
4342
int64_t total_rows,
4443
int64_t tune_total_rows,
@@ -49,10 +48,10 @@ class MoeGemmRunner {
4948
std::string activation_type,
5049
cudaStream_t stream);
5150

52-
void moe_gemm(const InType* A,
51+
void moe_gemm(const T* A,
5352
const WeightType* B,
54-
const OutType* weight_scales,
55-
OutType* C,
53+
const T* weight_scales,
54+
T* C,
5655
int64_t* total_rows_before_expert,
5756
int64_t total_rows,
5857
int64_t tune_total_rows,
@@ -64,11 +63,11 @@ class MoeGemmRunner {
6463

6564
private:
6665
template <typename EpilogueTag>
67-
void dispatch_to_arch(const InType* A,
66+
void dispatch_to_arch(const T* A,
6867
const WeightType* B,
69-
const OutType* weight_scales,
70-
const OutType* biases,
71-
OutType* C,
68+
const T* weight_scales,
69+
const T* biases,
70+
T* C,
7271
int64_t* total_rows_before_expert,
7372
int64_t total_rows,
7473
int64_t gemm_n,
@@ -80,11 +79,11 @@ class MoeGemmRunner {
8079
int* occupancy = nullptr);
8180

8281
template <typename EpilogueTag>
83-
void run_gemm(const InType* A,
82+
void run_gemm(const T* A,
8483
const WeightType* B,
85-
const OutType* weight_scales,
86-
const OutType* biases,
87-
OutType* C,
84+
const T* weight_scales,
85+
const T* biases,
86+
T* C,
8887
int64_t* total_rows_before_expert,
8988
int64_t total_rows,
9089
int64_t tune_total_rows,
@@ -99,4 +98,4 @@ class MoeGemmRunner {
9998
int multi_processor_count_;
10099
};
101100

102-
} // namespace phi
101+
} // namespace phi

custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ namespace phi {
2323

2424
#ifdef PADDLE_CUDA_BF16
2525
template class MoeGemmRunner<
26-
__nv_bfloat16,
2726
__nv_bfloat16,
2827
cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt2>>;
2928
#endif

custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@
2222
namespace phi {
2323

2424
template class MoeGemmRunner<
25-
half, half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt2>>;
25+
half, cutlass::WintQuantTraits<half, cutlass::WintQuantMethod::kWeightOnlyInt2>>;
2626

2727
} // namespace phi

custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp8_int2_bf16.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
*/
1616

1717
#pragma once
18-
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
19-
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
18+
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_mixed_io_kernels.h"
19+
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_mixed_io_template.h"
2020
#include "helper.h"
2121

2222
namespace phi {
2323

2424
#ifdef PADDLE_CUDA_BF16
25-
template class MoeGemmRunner<
25+
template class MixedMoeGemmRunner<
2626
cutlass::float_e4m3_t,
2727
__nv_bfloat16,
2828
cutlass::WintQuantTraits<cutlass::float_e4m3_t, cutlass::WintQuantMethod::kWeightOnlyInt2>>;

custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp8_int2_fp16.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
*/
1616

1717
#pragma once
18-
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
19-
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h"
18+
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_mixed_io_kernels.h"
19+
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_mixed_io_template.h"
2020
#include "helper.h"
2121

2222
namespace phi {
2323

24-
template class MoeGemmRunner<
24+
template class MixedMoeGemmRunner<
2525
cutlass::float_e4m3_t,
2626
half,
2727
cutlass::WintQuantTraits<cutlass::float_e4m3_t, cutlass::WintQuantMethod::kWeightOnlyInt2>>;

0 commit comments

Comments
 (0)