Skip to content

Commit 837451d

Browse files
code style
1 parent 2141963 commit 837451d

File tree

11 files changed

+124
-136
lines changed

11 files changed

+124
-136
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,4 @@ custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
172172
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h
173173

174174
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_*.cu
175-
/custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_template.h
175+
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_template.h

custom_ops/gpu_ops/wfp8afp8_sparse_gemm/kernel_traits.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ struct SharedStorage {
3535
};
3636
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
3737
};
38-
39-
struct {
38+
39+
struct {
4040
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
4141
};
4242
};
@@ -46,7 +46,7 @@ template<int kBlockM_, int kBlockN_, int kBlockK_,
4646
int kTiles_, int M_,
4747
int TokenPackSize_,
4848
int TAIL_N_ = 0,
49-
int kClusterM_ = 1,
49+
int kClusterM_ = 1,
5050
typename elem_type=cutlass::float_e4m3_t,
5151
typename OutputType = cutlass::bfloat16_t>
5252
struct Kernel_traits {
@@ -78,7 +78,7 @@ struct Kernel_traits {
7878
static constexpr int kStages = kStages_;
7979
static_assert(kStages > 1);
8080

81-
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
81+
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
8282

8383
using TiledMma = decltype(cute::make_tiled_mma(
8484
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
@@ -98,7 +98,7 @@ struct Kernel_traits {
9898

9999
using SmemLayoutAtomB = decltype(
100100
cutlass::gemm::collective::detail::ss_smem_selector<
101-
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
101+
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
102102
decltype(cute::get<2>(TileShape_MNK{}))>());
103103

104104
using SmemLayoutB = decltype(
@@ -107,20 +107,20 @@ struct Kernel_traits {
107107

108108
using SmemLayoutAtomB_TAIL = decltype(
109109
cutlass::gemm::collective::detail::rs_smem_selector<
110-
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
110+
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
111111
decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());
112-
112+
113113
using SmemLayoutB_TAIL = decltype(
114114
tile_to_shape(SmemLayoutAtomB_TAIL{},
115115
make_shape(
116-
shape<1>(TileShape_MNK_TAIL{}),
117-
shape<2>(TileShape_MNK_TAIL{}),
116+
shape<1>(TileShape_MNK_TAIL{}),
117+
shape<2>(TileShape_MNK_TAIL{}),
118118
Int<kStages>{})
119119
));
120120
using SmemLayoutAtomC = decltype(
121121
cutlass::gemm::collective::detail::ss_smem_selector<
122122
GMMA::Major::K, ElementOutput,
123-
decltype(cute::get<0>(TileShape_MNK{})),
123+
decltype(cute::get<0>(TileShape_MNK{})),
124124
decltype(cute::get<1>(TileShape_MNK{}))>());
125125

126126
using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
@@ -132,7 +132,7 @@ struct Kernel_traits {
132132

133133
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
134134
using PipelineState = typename cutlass::PipelineState<kStages>;
135-
135+
136136
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
137137
static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
138138
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
@@ -148,4 +148,4 @@ struct Kernel_traits {
148148
TiledCopyCThrLayout{}, // Thr layout
149149
TiledCopyCValLayout{} // Val layout
150150
));
151-
};
151+
};

custom_ops/gpu_ops/wfp8afp8_sparse_gemm/mainloop_fwd.h

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ struct CollectiveMainloopFwd {
4040
static constexpr int kBlockM = Ktraits::kBlockM;
4141
static constexpr int kBlockN = Ktraits::kBlockN;
4242
static constexpr int kBlockK = Ktraits::kBlockK;
43-
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
43+
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
4444
static constexpr int kTiles = Ktraits::kTiles;
4545
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
4646
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
@@ -71,8 +71,8 @@ struct CollectiveMainloopFwd {
7171
using TMA_A = decltype(make_tma_copy(
7272
GmemTiledCopy{},
7373
make_tensor(
74-
make_gmem_ptr(static_cast<Element const*>(nullptr)),
75-
WShapeT{},
74+
make_gmem_ptr(static_cast<Element const*>(nullptr)),
75+
WShapeT{},
7676
WStrideT{}
7777
),
7878
SmemLayoutA{}(_, _, _0{}),
@@ -82,8 +82,8 @@ struct CollectiveMainloopFwd {
8282
using TMA_B = decltype(make_tma_copy(
8383
GmemTiledCopy{},
8484
make_tensor(
85-
make_gmem_ptr(static_cast<Element const*>(nullptr)),
86-
ShapeT{},
85+
make_gmem_ptr(static_cast<Element const*>(nullptr)),
86+
ShapeT{},
8787
StrideT{}
8888
),
8989
take<0, 2>(SmemLayoutB{}),
@@ -93,8 +93,8 @@ struct CollectiveMainloopFwd {
9393
using TMA_E = decltype(make_tma_copy(
9494
GmemTiledCopy{},
9595
make_tensor(
96-
make_gmem_ptr(static_cast<uint32_t const*>(nullptr)),
97-
EShapeT{},
96+
make_gmem_ptr(static_cast<uint32_t const*>(nullptr)),
97+
EShapeT{},
9898
EStrideT{}
9999
),
100100
SmemLayoutE{}(_, _, _0{}),
@@ -108,7 +108,7 @@ struct CollectiveMainloopFwd {
108108
static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
109109
static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
110110
static constexpr uint32_t TmaTransactionBytesE = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutE{})) * cutlass::sizeof_bits_v<int> / 8);
111-
111+
112112
struct Arguments {
113113
Element const* ptr_A;
114114
WLayoutT layout_A;
@@ -126,8 +126,8 @@ struct CollectiveMainloopFwd {
126126
WLayoutT layout_A;
127127
ELayoutT layout_E;
128128
LayoutT layout_B;
129-
TMA_A tma_load_A;
130-
TMA_E tma_load_E;
129+
TMA_A tma_load_A;
130+
TMA_E tma_load_E;
131131
TMA_B tma_load_B;
132132
const int *tokens;
133133
const float *weight_scale;
@@ -160,7 +160,7 @@ struct CollectiveMainloopFwd {
160160
size<0>(ClusterShape{}));
161161

162162
return {args.layout_A, args.layout_E, args.layout_B,
163-
tma_load_A, tma_load_E, tma_load_B,
163+
tma_load_A, tma_load_E, tma_load_B,
164164
args.tokens, args.weight_scale, args.ptr_C};
165165
}
166166

@@ -200,7 +200,7 @@ struct CollectiveMainloopFwd {
200200
uint16_t *smem_c = reinterpret_cast<uint16_t *>(shared_storage.smem_c.data());
201201

202202
uint32_t * reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());
203-
203+
204204
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
205205

206206
constexpr int k_copy_times = CUR_N / 16;
@@ -210,13 +210,13 @@ struct CollectiveMainloopFwd {
210210
uint32_t smem_ptr = cast_smem_ptr_to_uint(reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
211211
asm volatile (
212212
"stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
213-
:: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
213+
:: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
214214
}
215215

216216
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
217217
const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
218218
ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM;
219-
219+
220220
const int reamin_tokens = tokens - bidn * kBlockN;
221221

222222
const int col = tidx % 2;
@@ -241,35 +241,35 @@ struct CollectiveMainloopFwd {
241241

242242
template <typename MTensor>
243243
CUTLASS_DEVICE auto get_local_packed_tensor(
244-
const MTensor &mB,
244+
const MTensor &mB,
245245
const int tokens,
246246
const int bidn) const {
247247

248248
auto mB_this_batch = make_tensor(
249-
mB.data(),
249+
mB.data(),
250250
make_layout(
251-
cute::make_shape(tokens, size<1>(mB)),
251+
cute::make_shape(tokens, size<1>(mB)),
252252
mB.stride()
253253
));
254254
return local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
255255
}
256256

257257
template <typename MTensor>
258258
CUTLASS_DEVICE auto get_local_no_packed_tensor(
259-
const MTensor &mB,
259+
const MTensor &mB,
260260
const int pre_fix_token,
261261
const int actual_token,
262262
const int bidn) const {
263263

264264
auto g_offset = local_tile(
265-
mB(_, _, 0),
266-
cute::make_shape(1, size<1>(mB)),
265+
mB(_, _, 0),
266+
cute::make_shape(1, size<1>(mB)),
267267
make_coord(pre_fix_token, _0{}));
268268

269269
auto g_tensor = make_tensor(
270-
g_offset.data(),
270+
g_offset.data(),
271271
make_layout(
272-
cute::make_shape(actual_token, size<1>(mB)),
272+
cute::make_shape(actual_token, size<1>(mB)),
273273
g_offset.stride()
274274
));
275275

@@ -291,15 +291,15 @@ struct CollectiveMainloopFwd {
291291
const int bidn,
292292
const int bidb,
293293
const int tidx) {
294-
294+
295295
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
296296
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
297297
Tensor sE = make_tensor(make_smem_ptr(shared_storage.smem_e.data()), SmemLayoutE{});
298-
298+
299299
Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape());
300300
Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape());
301301
Tensor mE = mainloop_params.tma_load_E.get_tma_tensor(mainloop_params.layout_E.shape());
302-
302+
303303
Tensor gA = local_tile(mA(_, _, _, bidm, bidb), select<0, 1>(Shape<Int<kBlockM / 2>, Int<kBlockK>>{}), make_coord(0,0,_));
304304

305305
Tensor gE = local_tile(mE(_, _, _, bidm, bidb), select<0, 1>(Shape<Int<NumMmaThreads>, Int<kBlockK / 64>>{}), make_coord(0, 0));
@@ -313,7 +313,7 @@ struct CollectiveMainloopFwd {
313313

314314
if constexpr (TokenPackSize == 0) {
315315
Tensor gB = get_local_no_packed_tensor(
316-
mB,
316+
mB,
317317
pre_fix_tokens,
318318
tokens,
319319
bidn);
@@ -351,9 +351,9 @@ struct CollectiveMainloopFwd {
351351
}
352352
} else {
353353
auto mB_this_batch = make_tensor(
354-
mB(_, _, bidb).data(),
354+
mB(_, _, bidb).data(),
355355
make_layout(
356-
cute::make_shape(tokens, size<1>(mB)),
356+
cute::make_shape(tokens, size<1>(mB)),
357357
mB.stride()
358358
));
359359
Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
@@ -396,11 +396,11 @@ struct CollectiveMainloopFwd {
396396
CUTLASS_DEVICE void
397397
mma(Params const& mainloop_params,
398398
MainloopPipeline pipeline,
399-
PipelineState& smem_pipe_read,
399+
PipelineState& smem_pipe_read,
400400
SharedStorage& shared_storage,
401401
float *acc_s,
402402
const int tidx) {
403-
403+
404404
using sMemBLayout = std::conditional_t<
405405
CUR_N == kBlockN,
406406
SmemLayoutB,
@@ -462,4 +462,3 @@ struct CollectiveMainloopFwd {
462462
}
463463

464464
};
465-

custom_ops/gpu_ops/wfp8afp8_sparse_gemm/utils.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct PackedHalf<cutlass::bfloat16_t> {
5151

5252
template <class PointerType>
5353
__device__ GmmaDescriptor make_smem_desc(
54-
PointerType smem_ptr,
54+
PointerType smem_ptr,
5555
int layout_type,
5656
int leading_byte_offset = 0,
5757
int stride_byte_offset = 1024) {
@@ -73,7 +73,7 @@ __forceinline__ __device__ static void gemm(uint64_t const& desc_a, uint64_t con
7373

7474
template <typename Mma, int kBlockK, int NumMmaThreads, typename T>
7575
__forceinline__ __device__ void gemm(
76-
const T * sA,
76+
const T * sA,
7777
const T * sB,
7878
float * acc_c,
7979
const uint32_t *E) {
@@ -97,4 +97,4 @@ __forceinline__ __device__ void gemm(
9797

9898
warpgroup_commit_batch();
9999
warpgroup_wait<0>();
100-
}
100+
}

0 commit comments

Comments
 (0)