Skip to content

Commit 2141963

Browse files
支持24稀疏
1 parent f516421 commit 2141963

11 files changed

+1632
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,6 @@ third_party
170170

171171
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
172172
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h
173+
174+
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_*.cu
175+
/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_template.h
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "cute/algorithm/copy.hpp"
16+
#include "cute/atom/mma_atom.hpp"
17+
#include "cutlass/gemm/collective/collective_builder.hpp"
18+
19+
#include "cutlass/cutlass.h"
20+
#include "cutlass/layout/layout.h"
21+
#include "cutlass/numeric_types.h"
22+
#include "cutlass/pipeline/pipeline.hpp"
23+
24+
using namespace cute;
25+
26+
template <int kStages, class GemmType, class OutputType, class SmemLayoutA,
27+
class SmemLayoutE,
28+
class SmemLayoutB, class SmemLayoutC>
29+
struct SharedStorage {
30+
union {
31+
struct {
32+
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
33+
cute::array_aligned<uint32_t, cute::cosize_v<SmemLayoutE>> smem_e;
34+
cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
35+
};
36+
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
37+
};
38+
39+
struct {
40+
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
41+
};
42+
};
43+
44+
template<int kBlockM_, int kBlockN_, int kBlockK_,
45+
int kNWarps_, int kStages_,
46+
int kTiles_, int M_,
47+
int TokenPackSize_,
48+
int TAIL_N_ = 0,
49+
int kClusterM_ = 1,
50+
typename elem_type=cutlass::float_e4m3_t,
51+
typename OutputType = cutlass::bfloat16_t>
52+
struct Kernel_traits {
53+
using Element = elem_type;
54+
using ElementAccum = float;
55+
using ElementOutput = OutputType;
56+
static_assert(cutlass::sizeof_bits_v<Element> == 8);
57+
58+
static constexpr int kNWarps = kNWarps_;
59+
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
60+
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
61+
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
62+
63+
static_assert(kNWarps_ == 12);
64+
65+
static constexpr int kBlockM = kBlockM_;
66+
static constexpr int kBlockN = kBlockN_;
67+
static constexpr int kBlockK = kBlockK_;
68+
static constexpr int kTiles = kTiles_;
69+
static constexpr int TokenPackSize = TokenPackSize_;
70+
static constexpr int TAIL_N = TAIL_N_;
71+
static constexpr int M = M_;
72+
73+
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
74+
using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;
75+
static constexpr int kClusterM = kClusterM_;
76+
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
77+
78+
static constexpr int kStages = kStages_;
79+
static_assert(kStages > 1);
80+
81+
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
82+
83+
using TiledMma = decltype(cute::make_tiled_mma(
84+
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
85+
AtomLayoutMNK{}));
86+
87+
using Mma = decltype(cute::GMMA::ss_op_selector_sparse<Element, Element, ElementAccum, TileShape_MNK>());
88+
89+
using Mma_TAIL = decltype(cute::GMMA::ss_op_selector_sparse<Element, Element, ElementAccum, TileShape_MNK_TAIL>());
90+
91+
using SmemLayoutAtomA = decltype(
92+
cutlass::gemm::collective::detail::rs_smem_selector<
93+
GMMA::Major::K, Element, Int<kBlockM / 2>, Int<kBlockK>>());
94+
95+
using SmemLayoutA = decltype(
96+
tile_to_shape(SmemLayoutAtomA{},
97+
make_shape(Int<kBlockM / 2>{}, Int<kBlockK>{}, Int<kStages>{})));
98+
99+
using SmemLayoutAtomB = decltype(
100+
cutlass::gemm::collective::detail::ss_smem_selector<
101+
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
102+
decltype(cute::get<2>(TileShape_MNK{}))>());
103+
104+
using SmemLayoutB = decltype(
105+
tile_to_shape(SmemLayoutAtomB{},
106+
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
107+
108+
using SmemLayoutAtomB_TAIL = decltype(
109+
cutlass::gemm::collective::detail::rs_smem_selector<
110+
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
111+
decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());
112+
113+
using SmemLayoutB_TAIL = decltype(
114+
tile_to_shape(SmemLayoutAtomB_TAIL{},
115+
make_shape(
116+
shape<1>(TileShape_MNK_TAIL{}),
117+
shape<2>(TileShape_MNK_TAIL{}),
118+
Int<kStages>{})
119+
));
120+
using SmemLayoutAtomC = decltype(
121+
cutlass::gemm::collective::detail::ss_smem_selector<
122+
GMMA::Major::K, ElementOutput,
123+
decltype(cute::get<0>(TileShape_MNK{})),
124+
decltype(cute::get<1>(TileShape_MNK{}))>());
125+
126+
using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
127+
128+
using SmemLayoutE = Layout<Shape<Int<NumMmaThreads>, Int<kBlockK / 64>, Int<kStages>>>;
129+
130+
using SharedStorage = SharedStorage<
131+
kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutE, SmemLayoutB, SmemLayoutC>;
132+
133+
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
134+
using PipelineState = typename cutlass::PipelineState<kStages>;
135+
136+
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
137+
static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
138+
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
139+
using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
140+
using TiledCopyCThrLayout = decltype(cute::make_layout(
141+
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
142+
LayoutRight{}));
143+
using TiledCopyCValLayout = decltype(cute::make_layout(
144+
cute::make_shape(_1{}, Int<kNumVecElem>{}),
145+
LayoutRight{}));
146+
using TiledCopyC = decltype(make_tiled_copy(
147+
TiledCopyCAtom{},
148+
TiledCopyCThrLayout{}, // Thr layout
149+
TiledCopyCValLayout{} // Val layout
150+
));
151+
};

0 commit comments

Comments
 (0)