
Commit 1c9b26f

Configuration with k=128 that supports multiple scenarios
1 parent 856872e commit 1c9b26f

File tree: 4 files changed (+52 -68 lines)

  custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h
  custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h
  custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h
  custom_ops/gpu_ops/helper.h

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h
Lines changed: 2 additions & 2 deletions

@@ -469,7 +469,7 @@ struct DefaultMma<float_e4m3_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignm
         layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
 {
 private:
-    using Mma = DefaultWint2xMma<float_e4m3_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, half_t,
+    using Mma = DefaultWint2xMma<float_e4m3_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, bfloat16_t,
         ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
         WarpShape, InstructionShape, 2, Operator>;

@@ -517,7 +517,7 @@ struct DefaultMma<float_e4m3_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignm
         false, SharedMemoryClear>
 {
 private:
-    using Mma = DefaultWint2xMma<float_e4m3_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, half_t,
+    using Mma = DefaultWint2xMma<float_e4m3_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, bfloat16_t,
         ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
         WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
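Note: both hunks swap the seventh template argument of DefaultWint2xMma from half_t to bfloat16_t on the float_e4m3_t path; the static_assert added in the next file suggests this argument is the super-scale element type. A minimal sketch (not part of the commit, assuming the standard CUTLASS numeric traits) of why this is a drop-in change at the storage level:

    #include <cutlass/numeric_types.h>

    // half_t and bfloat16_t are both 16 bits wide, so the swap changes only how the
    // scale bits are interpreted, not the size or alignment of the scale buffers.
    static_assert(cutlass::sizeof_bits<cutlass::half_t>::value == 16, "half_t is 16-bit");
    static_assert(cutlass::sizeof_bits<cutlass::bfloat16_t>::value == 16, "bfloat16_t is 16-bit");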

custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h
Lines changed: 4 additions & 4 deletions

@@ -789,18 +789,18 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,

     using ElementSuperScale = typename Mma::QuantParamsAccessor::ElementSuperScale;

-    // static_assert(platform::is_same<ElementSuperScale, cutlass::half_t>::value,
-    //               "ElementSuperScale must be half_t");
+    static_assert(platform::is_same<ElementSuperScale, cutlass::bfloat16_t>::value,
+                  "ElementSuperScale must be bfloat16_t");

-    // TODO: there is an extra reinterpret_cast here
+    // TODO("baoqiwen"), reinterpret_cast
     ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n;
     typename Mma::QuantParamsAccessor::IteratorSuperScale iterator_super_scale(
         Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n),
         reinterpret_cast<ElementSuperScale*>(weight_scale_ptr),
         {1, gemm_n},
         thread_idx,
         tb_offset_scale);
-
+
     int local_scale_pointer_offset = ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2);
     int64_t offset_in_bytes = problem_idx * gemm_k * gemm_n / 128;
     uint4b_t *local_scale_ptr = reinterpret_cast<uint4b_t *>(params.local_scale + offset_in_bytes);
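The bookkeeping above divides by 128 in two places, which suggests the wint2 local scales are grouped every 128 elements along K. A small worked sketch of the ceil-division used for local_scale_pointer_offset (kGroupSize and ceil_div are illustrative names, not taken from the kernel):

    // (ThreadblockShape::kK + 127) / 128 rounds up to the number of 128-wide K groups
    // that a single threadblock tile touches.
    constexpr int kGroupSize = 128;  // assumed quantization group size along K
    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    static_assert(ceil_div(64,  kGroupSize) == 1, "a K=64 tile touches one scale group");
    static_assert(ceil_div(128, kGroupSize) == 1, "a K=128 tile touches one scale group");
    static_assert(ceil_div(256, kGroupSize) == 2, "a K=256 tile spans two scale groups");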

custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h
Lines changed: 45 additions & 62 deletions

@@ -501,6 +501,33 @@ void dispatch_gemm_config(const InType* A,
                          occupancy);                                  \
     break;

+#define dispatch_gemm_config_with_k_macro(AA, BB, CC, DD, EE, FF, GG) \
+  case CutlassTileConfig::                                            \
+      CtaShape##AA##x##BB##x##CC##_WarpShape##DD##x##EE##x##FF:       \
+    dispatch_gemm_config<InType,                                      \
+                         OutType,                                     \
+                         WeightQuantTraits,                           \
+                         arch,                                        \
+                         EpilogueTag,                                 \
+                         cutlass::gemm::GemmShape<AA, BB, GG>,        \
+                         cutlass::gemm::GemmShape<DD, EE, GG>>(       \
+        A,                                                            \
+        B,                                                            \
+        weight_scales,                                                \
+        biases,                                                       \
+        C,                                                            \
+        total_rows_before_expert,                                     \
+        total_rows,                                                   \
+        gemm_n,                                                       \
+        gemm_k,                                                       \
+        num_experts,                                                  \
+        quant_args_B,                                                 \
+        gemm_config,                                                  \
+        multi_processor_count,                                        \
+        stream,                                                       \
+        occupancy);                                                   \
+    break;
+
 // This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
 // This overload is only enabled when T == WeightType.
 template <typename InType,
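The new macro mirrors dispatch_gemm_config_macro but takes the GemmShape K dimension from the extra GG argument, while the case label continues to encode CC and FF. Roughly, dispatch_gemm_config_with_k_macro(16, 128, 64, 16, 32, 64, tile_shape_k) expands to the following (a sketch; template and call arguments exactly as listed in the macro above):

    case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
      dispatch_gemm_config<InType, OutType, WeightQuantTraits, arch, EpilogueTag,
                           cutlass::gemm::GemmShape<16, 128, tile_shape_k>,   // threadblock tile
                           cutlass::gemm::GemmShape<16, 32, tile_shape_k>>(   // warp tile
          A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
          gemm_n, gemm_k, num_experts, quant_args_B, gemm_config,
          multi_processor_count, stream, occupancy);
      break;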
@@ -574,11 +601,12 @@ void dispatch_moe_gemm_to_cutlass(const InType* A,
                                   int multi_processor_count,
                                   cudaStream_t stream,
                                   int* occupancy = nullptr) {
+  constexpr int tile_shape_k = 128 * 8 / cutlass::sizeof_bits<InType>::value;
   if constexpr (std::is_same<arch, cutlass::arch::Sm70>::value) {
     if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) {
       switch (gemm_config.tile_config) {
-        dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
-        dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
+        dispatch_gemm_config_with_k_macro(32, 128, 64, 32, 32, 64, tile_shape_k);
+        dispatch_gemm_config_with_k_macro(64, 128, 64, 64, 64, 64, tile_shape_k);
         case CutlassTileConfig::Undefined:
           throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
           break;
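Hoisting tile_shape_k to the top of dispatch_moe_gemm_to_cutlass makes it visible to every branch below. With the usual CUTLASS bit widths it evaluates to 64 for 16-bit inputs and 128 for 8-bit FP8 inputs, which is how a single CtaShape...x64 case can cover both scenarios named in the commit message. A worked sketch (assumes <cutlass/numeric_types.h> and <cutlass/float8.h> are available):

    // tile_shape_k = 128 * 8 / sizeof_bits<InType>
    static_assert(128 * 8 / cutlass::sizeof_bits<cutlass::half_t>::value == 64,
                  "16-bit activations -> threadblock K of 64");
    static_assert(128 * 8 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value == 64,
                  "16-bit activations -> threadblock K of 64");
    static_assert(128 * 8 / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value == 128,
                  "8-bit FP8 activations -> threadblock K of 128");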
@@ -598,31 +626,26 @@ void dispatch_moe_gemm_to_cutlass(const InType* A,
           "[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support sm70.");
     }
   } else {
-    constexpr int tile_shape_k = 128 * 8 / cutlass::sizeof_bits<InType>::value;
-    CUTLASS_TRACE_HOST("tile_shape_k = " << tile_shape_k);
+    // CUTLASS_TRACE_HOST("tile_shape_k = " << tile_shape_k);
     CUTLASS_TRACE_HOST("Current tile_config value = " << static_cast<int>(gemm_config.tile_config));


     switch (gemm_config.tile_config) {
       // dispatch_gemm_config_macro(16, 128, 128, 16, 32, 128);
-      dispatch_gemm_config_macro(16, 256, 128, 16, 64, 128);
-
-      // if (tile_shape_k == 64) {
-      //   dispatch_gemm_config_macro(16, 128, 64, 16, 32, 64);
-      // } else if (tile_shape_k == 128){
-      //   dispatch_gemm_config_macro(16, 128, 128, 16, 32, 128);
-      // }
-      // dispatch_gemm_config_macro(16, 128, tile_shape_k, 16, 32, tile_shape_k);
-      // dispatch_gemm_config_macro(16, 256, 64, 16, 64, 64);
-      // dispatch_gemm_config_macro(64, 64, 64, 32, 32, 64);
-      // dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
-      // dispatch_gemm_config_macro(128, 64, 64, 64, 32, 64);
-      // dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
-      // dispatch_gemm_config_macro(128, 128, 64, 64, 64, 64);
-      // dispatch_gemm_config_macro(128, 128, 64, 128, 32, 64);
-      // dispatch_gemm_config_macro(128, 256, 64, 64, 64, 64);
-      // dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64);
-      // dispatch_gemm_config_macro(256, 128, 64, 64, 64, 64);
+      // dispatch_gemm_config_macro(16, 256, 128, 16, 64, 128);
+      // dispatch_gemm_config_macro(16, 128, 64, 16, 32, 64);
+
+      dispatch_gemm_config_with_k_macro(16, 128, 64, 16, 32, 64, tile_shape_k);
+      // dispatch_gemm_config_with_k_macro(16, 256, 64, 16, 64, 64, tile_shape_k);
+      // dispatch_gemm_config_with_k_macro(64, 64, 64, 32, 32, 64, tile_shape_k);
+      // dispatch_gemm_config_with_k_macro(32, 128, 64, 32, 32, 64, tile_shape_k);
+      // dispatch_gemm_config_with_k_macro(128, 64, 64, 64, 32, 64, tile_shape_k);
+      // dispatch_gemm_config_with_k_macro(64, 128, 64, 64, 64, 64, tile_shape_k);
+      // dispatch_gemm_config_with_k_macro(128, 128, 64, 64, 64, 64, tile_shape_k);
+      // dispatch_gemm_config_with_k_macro(128, 128, 64, 128, 32, 64, tile_shape_k);
+      // dispatch_gemm_config_with_k_macro(128, 256, 64, 64, 64, 64, tile_shape_k);
+      // dispatch_gemm_config_with_k_macro(64, 128, 64, 64, 32, 64, tile_shape_k);
+      // dispatch_gemm_config_with_k_macro(256, 128, 64, 64, 64, 64, tile_shape_k);
       case CutlassTileConfig::Undefined:
         throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
         break;
@@ -637,46 +660,6 @@ void dispatch_moe_gemm_to_cutlass(const InType* A,
             "mixed type tensorop GEMM.");
         break;
     }
-
-    // if (tile_shape_k == 64) {
-    //   switch (gemm_config.tile_config) {
-    //     // dispatch_gemm_config_macro(16, 128, 128, 16, 32, 128);
-    //     dispatch_gemm_config_macro(16, 128, 64, 16, 32, 64);
-
-    //     case CutlassTileConfig::Undefined:
-    //       throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
-    //       break;
-    //     case CutlassTileConfig::ChooseWithHeuristic:
-    //       throw std::runtime_error(
-    //           "[dispatch_moe_gemm_to_cutlass] gemm config should have "
-    //           "already been set by heuristic.");
-    //       break;
-    //     default:
-    //       throw std::runtime_error(
-    //           "[dispatch_moe_gemm_to_cutlass] Config is invalid for "
-    //           "mixed type tensorop GEMM.");
-    //       break;
-    //   }
-    // } else if (tile_shape_k == 128) {
-    //   switch (gemm_config.tile_config) {
-    //     dispatch_gemm_config_macro(16, 128, 128, 16, 32, 128);
-
-    //     case CutlassTileConfig::Undefined:
-    //       throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
-    //       break;
-    //     case CutlassTileConfig::ChooseWithHeuristic:
-    //       throw std::runtime_error(
-    //           "[dispatch_moe_gemm_to_cutlass] gemm config should have "
-    //           "already been set by heuristic.");
-    //       break;
-    //     default:
-    //       throw std::runtime_error(
-    //           "[dispatch_moe_gemm_to_cutlass] Config is invalid for "
-    //           "mixed type tensorop GEMM.");
-    //       break;
-    //   }
-    // }
-
   }
 }

custom_ops/gpu_ops/helper.h
Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ namespace cub = hipcub;
 #endif
 #include <fstream>
 #include <iostream>
+#include <cutlass/numeric_types.h>

 #include "env.h"
 #include "paddle/extension.h"
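The new include makes the CUTLASS numeric types visible to everything that pulls in helper.h. A minimal usage sketch under that assumption (the aliases are illustrative only, not names from the codebase):

    #include <cutlass/numeric_types.h>

    using SuperScaleT = cutlass::bfloat16_t;  // 16-bit super-scale type asserted in the kernel
    using LocalScaleT = cutlass::uint4b_t;    // 4-bit packed local scales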
