@@ -501,6 +501,33 @@ void dispatch_gemm_config(const InType* A,
        occupancy); \
    break;

+ #define dispatch_gemm_config_with_k_macro(AA, BB, CC, DD, EE, FF, GG) \
+   case CutlassTileConfig:: \
+       CtaShape##AA##x##BB##x##CC##_WarpShape##DD##x##EE##x##FF: \
+     dispatch_gemm_config<InType, \
+                          OutType, \
+                          WeightQuantTraits, \
+                          arch, \
+                          EpilogueTag, \
+                          cutlass::gemm::GemmShape<AA, BB, GG>, \
+                          cutlass::gemm::GemmShape<DD, EE, GG>>( \
+         A, \
+         B, \
+         weight_scales, \
+         biases, \
+         C, \
+         total_rows_before_expert, \
+         total_rows, \
+         gemm_n, \
+         gemm_k, \
+         num_experts, \
+         quant_args_B, \
+         gemm_config, \
+         multi_processor_count, \
+         stream, \
+         occupancy); \
+     break;
+
 // This overload will handle tensorop gemms. It is disabled via SFINAE for fp32.
 // This overload is only enabled when T == WeightType.
 template <typename InType,
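
Note: the new dispatch_gemm_config_with_k_macro mirrors the existing dispatch_gemm_config_macro, except that its seventh parameter GG overrides the K extent of both GemmShape template arguments, while CC and FF now only feed the token-pasted case label. As a minimal sketch (hand-expanded for illustration, not compiler output), an invocation such as

    dispatch_gemm_config_with_k_macro(16, 128, 64, 16, 32, 64, tile_shape_k);

expands to roughly

    case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
      dispatch_gemm_config<InType, OutType, WeightQuantTraits, arch, EpilogueTag,
                           cutlass::gemm::GemmShape<16, 128, tile_shape_k>,
                           cutlass::gemm::GemmShape<16, 32, tile_shape_k>>(
          A, B, weight_scales, biases, C, total_rows_before_expert, total_rows,
          gemm_n, gemm_k, num_experts, quant_args_B, gemm_config,
          multi_processor_count, stream, occupancy);
      break;

so the case label keeps the literal 64 while the instantiated tile shapes pick up the type-dependent K.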
@@ -574,11 +601,12 @@ void dispatch_moe_gemm_to_cutlass(const InType* A,
                                   int multi_processor_count,
                                   cudaStream_t stream,
                                   int* occupancy = nullptr) {
+   constexpr int tile_shape_k = 128 * 8 / cutlass::sizeof_bits<InType>::value;
   if constexpr (std::is_same<arch, cutlass::arch::Sm70>::value) {
     if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) {
       switch (gemm_config.tile_config) {
-       dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
-       dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
+       dispatch_gemm_config_with_k_macro(32, 128, 64, 32, 32, 64, tile_shape_k);
+       dispatch_gemm_config_with_k_macro(64, 128, 64, 64, 64, 64, tile_shape_k);
        case CutlassTileConfig::Undefined:
          throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
          break;
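
The hoisted tile_shape_k constant sizes the tile's K extent so that one K-slice spans 128 bytes of InType elements (128 * 8 bits divided by the element width): 64 for 16-bit inputs such as cutlass::half_t, 128 for 8-bit inputs. A standalone check of that arithmetic (a sketch assuming cutlass::sizeof_bits comes from cutlass/numeric_types.h; nothing else from the surrounding kernel code is needed):

    #include <cstdint>
    #include "cutlass/numeric_types.h"

    // 16-bit inputs: 128 * 8 / 16 == 64
    static_assert(128 * 8 / cutlass::sizeof_bits<cutlass::half_t>::value == 64, "");
    // 8-bit inputs: 128 * 8 / 8 == 128
    static_assert(128 * 8 / cutlass::sizeof_bits<uint8_t>::value == 128, "");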
@@ -598,31 +626,26 @@ void dispatch_moe_gemm_to_cutlass(const InType* A,
           "[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support sm70.");
     }
   } else {
-   constexpr int tile_shape_k = 128 * 8 / cutlass::sizeof_bits<InType>::value;
-   CUTLASS_TRACE_HOST("tile_shape_k = " << tile_shape_k);
+   // CUTLASS_TRACE_HOST("tile_shape_k = " << tile_shape_k);
    CUTLASS_TRACE_HOST("Current tile_config value = " << static_cast<int>(gemm_config.tile_config));


    switch (gemm_config.tile_config) {
      // dispatch_gemm_config_macro(16, 128, 128, 16, 32, 128);
-     dispatch_gemm_config_macro(16, 256, 128, 16, 64, 128);
-
-     // if (tile_shape_k == 64) {
-     //   dispatch_gemm_config_macro(16, 128, 64, 16, 32, 64);
-     // } else if (tile_shape_k == 128) {
-     //   dispatch_gemm_config_macro(16, 128, 128, 16, 32, 128);
-     // }
-     // dispatch_gemm_config_macro(16, 128, tile_shape_k, 16, 32, tile_shape_k);
-     // dispatch_gemm_config_macro(16, 256, 64, 16, 64, 64);
-     // dispatch_gemm_config_macro(64, 64, 64, 32, 32, 64);
-     // dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
-     // dispatch_gemm_config_macro(128, 64, 64, 64, 32, 64);
-     // dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
-     // dispatch_gemm_config_macro(128, 128, 64, 64, 64, 64);
-     // dispatch_gemm_config_macro(128, 128, 64, 128, 32, 64);
-     // dispatch_gemm_config_macro(128, 256, 64, 64, 64, 64);
-     // dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64);
-     // dispatch_gemm_config_macro(256, 128, 64, 64, 64, 64);
+     // dispatch_gemm_config_macro(16, 256, 128, 16, 64, 128);
+     // dispatch_gemm_config_macro(16, 128, 64, 16, 32, 64);
+
+     dispatch_gemm_config_with_k_macro(16, 128, 64, 16, 32, 64, tile_shape_k);
+     // dispatch_gemm_config_with_k_macro(16, 256, 64, 16, 64, 64, tile_shape_k);
+     // dispatch_gemm_config_with_k_macro(64, 64, 64, 32, 32, 64, tile_shape_k);
+     // dispatch_gemm_config_with_k_macro(32, 128, 64, 32, 32, 64, tile_shape_k);
+     // dispatch_gemm_config_with_k_macro(128, 64, 64, 64, 32, 64, tile_shape_k);
+     // dispatch_gemm_config_with_k_macro(64, 128, 64, 64, 64, 64, tile_shape_k);
+     // dispatch_gemm_config_with_k_macro(128, 128, 64, 64, 64, 64, tile_shape_k);
+     // dispatch_gemm_config_with_k_macro(128, 128, 64, 128, 32, 64, tile_shape_k);
+     // dispatch_gemm_config_with_k_macro(128, 256, 64, 64, 64, 64, tile_shape_k);
+     // dispatch_gemm_config_with_k_macro(64, 128, 64, 64, 32, 64, tile_shape_k);
+     // dispatch_gemm_config_with_k_macro(256, 128, 64, 64, 64, 64, tile_shape_k);
      case CutlassTileConfig::Undefined:
        throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
        break;
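
Each macro invocation above expands to a complete `case ...: ...; break;`, which is why a bare macro call can sit directly alongside the explicit `case CutlassTileConfig::Undefined:` label inside the same switch. A self-contained, hypothetical sketch of that pattern (TileConfig, run_gemm, and DISPATCH_CASE are stand-ins for illustration, not the real API):

    #include <stdexcept>

    // Hypothetical stand-ins for the real CutlassTileConfig enum and kernel launcher.
    enum class TileConfig { Undefined, ChooseWithHeuristic, CtaShape16x128x64_WarpShape16x32x64 };

    template <int CtaM, int CtaN, int WarpM, int WarpN, int K>
    void run_gemm() { /* instantiate and launch the matching kernel */ }

    // Each invocation expands to a full `case ...: run_gemm<...>(); break;`.
    #define DISPATCH_CASE(AA, BB, CC, DD, EE, FF, GG)                            \
      case TileConfig::CtaShape##AA##x##BB##x##CC##_WarpShape##DD##x##EE##x##FF: \
        run_gemm<AA, BB, DD, EE, GG>();                                          \
        break;

    void dispatch(TileConfig cfg) {
      constexpr int tile_shape_k = 64;  // 128 * 8 / sizeof_bits<InType> in the real code
      switch (cfg) {
        DISPATCH_CASE(16, 128, 64, 16, 32, 64, tile_shape_k)
        case TileConfig::Undefined:
          throw std::runtime_error("gemm config undefined.");
        default:
          throw std::runtime_error("config is invalid for mixed type tensorop GEMM.");
      }
    }

Enabling another tile configuration is then a one-line change: uncomment the corresponding macro invocation in the switch.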
@@ -637,46 +660,6 @@ void dispatch_moe_gemm_to_cutlass(const InType* A,
            "mixed type tensorop GEMM.");
        break;
    }
-
-   // if (tile_shape_k == 64) {
-   //   switch (gemm_config.tile_config) {
-   //     // dispatch_gemm_config_macro(16, 128, 128, 16, 32, 128);
-   //     dispatch_gemm_config_macro(16, 128, 64, 16, 32, 64);
-
-   //     case CutlassTileConfig::Undefined:
-   //       throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
-   //       break;
-   //     case CutlassTileConfig::ChooseWithHeuristic:
-   //       throw std::runtime_error(
-   //           "[dispatch_moe_gemm_to_cutlass] gemm config should have "
-   //           "already been set by heuristic.");
-   //       break;
-   //     default:
-   //       throw std::runtime_error(
-   //           "[dispatch_moe_gemm_to_cutlass] Config is invalid for "
-   //           "mixed type tensorop GEMM.");
-   //       break;
-   //   }
-   // } else if (tile_shape_k == 128) {
-   //   switch (gemm_config.tile_config) {
-   //     dispatch_gemm_config_macro(16, 128, 128, 16, 32, 128);
-
-   //     case CutlassTileConfig::Undefined:
-   //       throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
-   //       break;
-   //     case CutlassTileConfig::ChooseWithHeuristic:
-   //       throw std::runtime_error(
-   //           "[dispatch_moe_gemm_to_cutlass] gemm config should have "
-   //           "already been set by heuristic.");
-   //       break;
-   //     default:
-   //       throw std::runtime_error(
-   //           "[dispatch_moe_gemm_to_cutlass] Config is invalid for "
-   //           "mixed type tensorop GEMM.");
-   //       break;
-   //   }
-   // }
-
  }
}