1
+ // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include " cute/algorithm/copy.hpp"
16
+ #include " cute/atom/mma_atom.hpp"
17
+ #include " cutlass/gemm/collective/collective_builder.hpp"
18
+
19
+ #include " cutlass/cutlass.h"
20
+ #include " cutlass/layout/layout.h"
21
+ #include " cutlass/numeric_types.h"
22
+ #include " cutlass/pipeline/pipeline.hpp"
23
+
24
+ using namespace cute ;
25
+
26
+ template <int kStages , class GemmType , class OutputType , class SmemLayoutA ,
27
+ class SmemLayoutE ,
28
+ class SmemLayoutB , class SmemLayoutC >
29
+ struct SharedStorage {
30
+ union {
31
+ struct {
32
+ cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
33
+ cute::array_aligned<uint32_t , cute::cosize_v<SmemLayoutE>> smem_e;
34
+ cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
35
+ };
36
+ cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
37
+ };
38
+
39
+ struct {
40
+ typename cutlass::PipelineTmaAsync<kStages >::SharedStorage pipeline;
41
+ };
42
+ };
43
+
44
+ template <int kBlockM_ , int kBlockN_ , int kBlockK_ ,
45
+ int kNWarps_ , int kStages_ ,
46
+ int kTiles_ , int M_,
47
+ int TokenPackSize_,
48
+ int TAIL_N_ = 0 ,
49
+ int kClusterM_ = 1 ,
50
+ typename elem_type=cutlass::float_e4m3_t ,
51
+ typename OutputType = cutlass::bfloat16_t >
52
+ struct Kernel_traits {
53
+ using Element = elem_type;
54
+ using ElementAccum = float ;
55
+ using ElementOutput = OutputType;
56
+ static_assert (cutlass::sizeof_bits_v<Element> == 8 );
57
+
58
+ static constexpr int kNWarps = kNWarps_ ;
59
+ static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
60
+ static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
61
+ static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
62
+
63
+ static_assert (kNWarps_ == 12 );
64
+
65
+ static constexpr int kBlockM = kBlockM_ ;
66
+ static constexpr int kBlockN = kBlockN_ ;
67
+ static constexpr int kBlockK = kBlockK_ ;
68
+ static constexpr int kTiles = kTiles_ ;
69
+ static constexpr int TokenPackSize = TokenPackSize_;
70
+ static constexpr int TAIL_N = TAIL_N_;
71
+ static constexpr int M = M_;
72
+
73
+ using TileShape_MNK = Shape<Int<kBlockM >, Int<kBlockN >, Int<kBlockK >>;
74
+ using TileShape_MNK_TAIL = Shape<Int<kBlockM >, Int<TAIL_N>, Int<kBlockK >>;
75
+ static constexpr int kClusterM = kClusterM_ ;
76
+ using ClusterShape_MNK = Shape<Int<kClusterM >, _1, _1>;
77
+
78
+ static constexpr int kStages = kStages_ ;
79
+ static_assert (kStages > 1 );
80
+
81
+ using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64 >, _1, _1>>;
82
+
83
+ using TiledMma = decltype (cute::make_tiled_mma(
84
+ cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
85
+ AtomLayoutMNK{}));
86
+
87
+ using Mma = decltype (cute::GMMA::ss_op_selector_sparse<Element, Element, ElementAccum, TileShape_MNK>());
88
+
89
+ using Mma_TAIL = decltype (cute::GMMA::ss_op_selector_sparse<Element, Element, ElementAccum, TileShape_MNK_TAIL>());
90
+
91
+ using SmemLayoutAtomA = decltype (
92
+ cutlass::gemm::collective::detail::rs_smem_selector<
93
+ GMMA::Major::K, Element, Int<kBlockM / 2 >, Int<kBlockK >>());
94
+
95
+ using SmemLayoutA = decltype (
96
+ tile_to_shape (SmemLayoutAtomA{},
97
+ make_shape (Int<kBlockM / 2 >{}, Int<kBlockK >{}, Int<kStages >{})));
98
+
99
+ using SmemLayoutAtomB = decltype (
100
+ cutlass::gemm::collective::detail::ss_smem_selector<
101
+ GMMA::Major::K, Element, decltype (cute::get<1 >(TileShape_MNK{})),
102
+ decltype (cute::get<2 >(TileShape_MNK{}))>());
103
+
104
+ using SmemLayoutB = decltype (
105
+ tile_to_shape (SmemLayoutAtomB{},
106
+ make_shape (shape<1 >(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
107
+
108
+ using SmemLayoutAtomB_TAIL = decltype (
109
+ cutlass::gemm::collective::detail::rs_smem_selector<
110
+ GMMA::Major::K, Element, decltype (cute::get<1 >(TileShape_MNK_TAIL{})),
111
+ decltype (cute::get<2 >(TileShape_MNK_TAIL{}))>());
112
+
113
+ using SmemLayoutB_TAIL = decltype (
114
+ tile_to_shape (SmemLayoutAtomB_TAIL{},
115
+ make_shape (
116
+ shape<1 >(TileShape_MNK_TAIL{}),
117
+ shape<2>(TileShape_MNK_TAIL{}),
118
+ Int<kStages>{})
119
+ ));
120
+ using SmemLayoutAtomC = decltype (
121
+ cutlass::gemm::collective::detail::ss_smem_selector<
122
+ GMMA::Major::K, ElementOutput,
123
+ decltype (cute::get<0 >(TileShape_MNK{})),
124
+ decltype (cute::get<1 >(TileShape_MNK{}))>());
125
+
126
+ using SmemLayoutC = decltype (tile_to_shape(SmemLayoutAtomC{}, select<0 , 1 >(TileShape_MNK{})));
127
+
128
+ using SmemLayoutE = Layout<Shape<Int<NumMmaThreads>, Int<kBlockK / 64 >, Int<kStages >>>;
129
+
130
+ using SharedStorage = SharedStorage<
131
+ kStages , Element, ElementOutput, SmemLayoutA, SmemLayoutE, SmemLayoutB, SmemLayoutC>;
132
+
133
+ using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages >;
134
+ using PipelineState = typename cutlass::PipelineState<kStages >;
135
+
136
+ static constexpr int kNumVecElem = ceil_div(128 , sizeof_bits_v<OutputType>);
137
+ static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem ;
138
+ static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow ;
139
+ using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t >, OutputType>;
140
+ using TiledCopyCThrLayout = decltype (cute::make_layout(
141
+ cute::make_shape (Int<kNumRows >{}, Int<kNumThreadsPerRow >{}),
142
+ LayoutRight{}));
143
+ using TiledCopyCValLayout = decltype (cute::make_layout(
144
+ cute::make_shape (_1{}, Int<kNumVecElem >{}),
145
+ LayoutRight{}));
146
+ using TiledCopyC = decltype (make_tiled_copy(
147
+ TiledCopyCAtom{},
148
+ TiledCopyCThrLayout{}, // Thr layout
149
+ TiledCopyCValLayout{} // Val layout
150
+ ));
151
+ };
0 commit comments