@@ -40,7 +40,7 @@ struct CollectiveMainloopFwd {
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;
    static constexpr int kBlockK = Ktraits::kBlockK;
-    static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
+    static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr int kTiles = Ktraits::kTiles;
    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
    static constexpr int TokenPackSize = Ktraits::TokenPackSize;
@@ -71,8 +71,8 @@ struct CollectiveMainloopFwd {
    using TMA_A = decltype(make_tma_copy(
        GmemTiledCopy{},
        make_tensor(
-            make_gmem_ptr(static_cast<Element const*>(nullptr)),
-            WShapeT{},
+            make_gmem_ptr(static_cast<Element const*>(nullptr)),
+            WShapeT{},
            WStrideT{}
        ),
        SmemLayoutA{}(_, _, _0{}),
@@ -82,8 +82,8 @@ struct CollectiveMainloopFwd {
    using TMA_B = decltype(make_tma_copy(
        GmemTiledCopy{},
        make_tensor(
-            make_gmem_ptr(static_cast<Element const*>(nullptr)),
-            ShapeT{},
+            make_gmem_ptr(static_cast<Element const*>(nullptr)),
+            ShapeT{},
            StrideT{}
        ),
        take<0, 2>(SmemLayoutB{}),
@@ -93,8 +93,8 @@ struct CollectiveMainloopFwd {
    using TMA_E = decltype(make_tma_copy(
        GmemTiledCopy{},
        make_tensor(
-            make_gmem_ptr(static_cast<uint32_t const*>(nullptr)),
-            EShapeT{},
+            make_gmem_ptr(static_cast<uint32_t const*>(nullptr)),
+            EShapeT{},
            EStrideT{}
        ),
        SmemLayoutE{}(_, _, _0{}),
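The three TMA_* aliases rely on a common CuTe trick: make_tma_copy needs a global-memory tensor to deduce the copy descriptor's type, but at class scope no real pointer exists yet, so a typed null pointer stands in purely for decltype deduction and is never dereferenced. A minimal sketch of the idiom outside CuTe, with hypothetical View/make_view stand-ins:

    template <class T, class Shape>
    struct View { T* ptr; Shape shape; };

    template <class T, class Shape>
    View<T, Shape> make_view(T* ptr, Shape shape) { return {ptr, shape}; }

    struct Rows { int n; };

    // Type computed at class scope without ever touching memory: the null
    // pointer only carries the element type into template deduction.
    using ViewT = decltype(make_view(static_cast<float const*>(nullptr), Rows{}));

    int main() {
        float data[4] = {0.f, 1.f, 2.f, 3.f};
        ViewT v = make_view(static_cast<float const*>(data), Rows{4});
        return v.shape.n == 4 ? 0 : 1;
    }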
@@ -108,7 +108,7 @@ struct CollectiveMainloopFwd {
    static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
    static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
    static constexpr uint32_t TmaTransactionBytesE = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutE{})) * cutlass::sizeof_bits_v<int> / 8);
-
+
    struct Arguments {
        Element const* ptr_A;
        WLayoutT layout_A;
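The TmaTransactionBytes* constants tell the arrival barrier how many bytes each TMA transaction deposits: the element count of one pipeline stage of the smem layout, times bits per element, divided by 8. A worked example with assumed (illustrative, not the kernel's actual Ktraits) tile extents:

    #include <cstdint>
    #include <cstdio>

    int main() {
        constexpr int kStageElems  = 128 * 64;  // elements in one smem stage (assumed)
        constexpr int kBitsPerElem = 16;        // e.g. half / bf16
        constexpr uint32_t kBytes =
            static_cast<uint32_t>(kStageElems) * kBitsPerElem / 8;
        printf("TMA transaction bytes per stage: %u\n", kBytes);  // 16384
        return 0;
    }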
@@ -126,8 +126,8 @@ struct CollectiveMainloopFwd {
        WLayoutT layout_A;
        ELayoutT layout_E;
        LayoutT layout_B;
-        TMA_A tma_load_A;
-        TMA_E tma_load_E;
+        TMA_A tma_load_A;
+        TMA_E tma_load_E;
        TMA_B tma_load_B;
        const int *tokens;
        const float *weight_scale;
@@ -160,7 +160,7 @@ struct CollectiveMainloopFwd {
            size<0>(ClusterShape{}));

        return {args.layout_A, args.layout_E, args.layout_B,
-                tma_load_A, tma_load_E, tma_load_B,
+                tma_load_A, tma_load_E, tma_load_B,
                args.tokens, args.weight_scale, args.ptr_C};
    }
@@ -200,7 +200,7 @@ struct CollectiveMainloopFwd {
        uint16_t *smem_c = reinterpret_cast<uint16_t *>(shared_storage.smem_c.data());

        uint32_t *reg_data = reinterpret_cast<uint32_t *>(tOrO_out.data());
-
+
        cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);

        constexpr int k_copy_times = CUR_N / 16;
@@ -210,13 +210,13 @@ struct CollectiveMainloopFwd {
            uint32_t smem_ptr = cast_smem_ptr_to_uint(reinterpret_cast<uint128_t *>(smem_c + i * 16 * 128) + tidx);
            asm volatile (
                "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
-                :: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
+                :: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
        }

        cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
        const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
        ElementOutput *store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM;
-
+
        const int reamin_tokens = tokens - bidn * kBlockN;

        const int col = tidx % 2;
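The epilogue spills the accumulator to shared memory with stmatrix.x4.trans, which stores four transposed 8x8 b16 tiles per warp per instruction; the operand order reg_data[4i+0], [4i+2], [4i+1], [4i+3] swaps the middle fragments, presumably so the MMA accumulator register pairing lines up with the transposed tile layout. A minimal standalone CUDA sketch (assumes sm_90 and a single-warp launch; the kernel's 16x128 tile per iteration is reduced here to one four-tile store):

    #include <cstdint>
    #include <cstdio>

    __global__ void stmatrix_demo() {
        __shared__ __align__(16) uint16_t smem[4 * 8 * 8];  // four 8x8 b16 tiles
        const int tidx = threadIdx.x;                       // one warp
        // Each thread holds four 32-bit registers (two b16 values each),
        // as it would after an MMA; values here are synthetic.
        uint32_t r[4];
        for (int i = 0; i < 4; ++i) r[i] = (uint32_t(tidx) << 16) | uint32_t(i);
        // Threads 0..31 each supply the smem address of one 8-element row.
        uint32_t addr = static_cast<uint32_t>(
            __cvta_generic_to_shared(smem + tidx * 8));
        asm volatile(
            "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
            :: "r"(addr), "r"(r[0]), "r"(r[1]), "r"(r[2]), "r"(r[3]));
        __syncwarp();
        if (tidx == 0) printf("stored %d b16 values\n", 4 * 64);
    }

    int main() {
        stmatrix_demo<<<1, 32>>>();
        cudaDeviceSynchronize();
        return 0;
    }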
@@ -241,35 +241,35 @@ struct CollectiveMainloopFwd {

    template <typename MTensor>
    CUTLASS_DEVICE auto get_local_packed_tensor(
-        const MTensor &mB,
+        const MTensor &mB,
        const int tokens,
        const int bidn) const {

        auto mB_this_batch = make_tensor(
-            mB.data(),
+            mB.data(),
            make_layout(
-                cute::make_shape(tokens, size<1>(mB)),
+                cute::make_shape(tokens, size<1>(mB)),
                mB.stride()
            ));
        return local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
    }

    template <typename MTensor>
    CUTLASS_DEVICE auto get_local_no_packed_tensor(
-        const MTensor &mB,
+        const MTensor &mB,
        const int pre_fix_token,
        const int actual_token,
        const int bidn) const {

        auto g_offset = local_tile(
-            mB(_, _, 0),
-            cute::make_shape(1, size<1>(mB)),
+            mB(_, _, 0),
+            cute::make_shape(1, size<1>(mB)),
            make_coord(pre_fix_token, _0{}));

        auto g_tensor = make_tensor(
-            g_offset.data(),
+            g_offset.data(),
            make_layout(
-                cute::make_shape(actual_token, size<1>(mB)),
+                cute::make_shape(actual_token, size<1>(mB)),
                g_offset.stride()
            ));
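These two helpers cover the kernel's two token layouts: get_local_packed_tensor assumes every batch occupies a fixed TokenPackSize slot, while get_local_no_packed_tensor slices a variable-length batch out of a flat token dimension at a prefix-sum offset (pre_fix_token). The same arithmetic drives the batch_idx computation in the epilogue above. A small sketch of both addressing schemes, with illustrative counts:

    #include <cstdio>

    int main() {
        const int tokens[3] = {5, 3, 7};  // assumed per-batch token counts
        const int TokenPackSize = 8;      // assumed pack size
        const int bidb = 2;               // batch index

        // Unpacked (TokenPackSize == 0): batch rows start at the prefix sum.
        int pre_fix_tokens = 0;
        for (int b = 0; b < bidb; ++b) pre_fix_tokens += tokens[b];
        printf("varlen batch %d starts at row %d\n", bidb, pre_fix_tokens);        // 8

        // Packed: every batch occupies a fixed TokenPackSize slot.
        printf("packed batch %d starts at row %d\n", bidb, bidb * TokenPackSize);  // 16
        return 0;
    }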
@@ -291,15 +291,15 @@ struct CollectiveMainloopFwd {
        const int bidn,
        const int bidb,
        const int tidx) {
-
+
        Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
        Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
        Tensor sE = make_tensor(make_smem_ptr(shared_storage.smem_e.data()), SmemLayoutE{});
-
+
        Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape());
        Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape());
        Tensor mE = mainloop_params.tma_load_E.get_tma_tensor(mainloop_params.layout_E.shape());
-
+
        Tensor gA = local_tile(mA(_, _, _, bidm, bidb), select<0, 1>(Shape<Int<kBlockM / 2>, Int<kBlockK>>{}), make_coord(0, 0, _));

        Tensor gE = local_tile(mE(_, _, _, bidm, bidb), select<0, 1>(Shape<Int<NumMmaThreads>, Int<kBlockK / 64>>{}), make_coord(0, 0));
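Note that the A tile is loaded at kBlockM / 2 rows while a separate metadata tensor E travels with it: this is the usual shape of 2:4 structured sparsity, where each group of four values keeps only its two nonzeros plus two 2-bit position indices. A hedged sketch of that compression for a single group (the index packing below is illustrative; the kernel's E encoding is whatever its sparse MMA expects):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // One 1x4 group: keep the 2 nonzeros, record positions as 2-bit indices.
        float group[4] = {0.f, 1.5f, 0.f, -2.f};
        float vals[2]; uint8_t meta = 0; int k = 0;
        for (int i = 0; i < 4 && k < 2; ++i) {
            if (group[i] != 0.f) {
                vals[k] = group[i];
                meta |= uint8_t(i) << (2 * k);  // position of the k-th nonzero
                ++k;
            }
        }
        printf("kept %.1f %.1f, meta=0x%x\n", vals[0], vals[1], meta);  // 1.5 -2.0, 0xd
        return 0;
    }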
@@ -313,7 +313,7 @@ struct CollectiveMainloopFwd {

        if constexpr (TokenPackSize == 0) {
            Tensor gB = get_local_no_packed_tensor(
-                mB,
+                mB,
                pre_fix_tokens,
                tokens,
                bidn);
@@ -351,9 +351,9 @@ struct CollectiveMainloopFwd {
            }
        } else {
            auto mB_this_batch = make_tensor(
-                mB(_, _, bidb).data(),
+                mB(_, _, bidb).data(),
                make_layout(
-                    cute::make_shape(tokens, size<1>(mB)),
+                    cute::make_shape(tokens, size<1>(mB)),
                    mB.stride()
                ));
            Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
@@ -396,11 +396,11 @@ struct CollectiveMainloopFwd {
    CUTLASS_DEVICE void
    mma(Params const& mainloop_params,
        MainloopPipeline pipeline,
-        PipelineState& smem_pipe_read,
+        PipelineState& smem_pipe_read,
        SharedStorage& shared_storage,
        float *acc_s,
        const int tidx) {
-
+
        using sMemBLayout = std::conditional_t<
            CUR_N == kBlockN,
            SmemLayoutB,
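mma picks the B smem layout type at compile time: when the current tile width CUR_N equals kBlockN it reuses SmemLayoutB whole, otherwise it substitutes a narrower layout for the tail tile. A minimal sketch of the std::conditional_t dispatch with stand-in layout types:

    #include <type_traits>

    template <int N> struct LayoutN { static constexpr int n = N; };

    template <int CUR_N, int kBlockN>
    using SmemBLayout = std::conditional_t<
        CUR_N == kBlockN,
        LayoutN<kBlockN>,   // full-width tile: reuse the whole layout
        LayoutN<CUR_N>>;    // tail tile: a narrower layout type

    static_assert(SmemBLayout<64, 64>::n == 64, "full tile");
    static_assert(SmemBLayout<32, 64>::n == 32, "tail tile");

    int main() { return 0; }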
@@ -462,4 +462,3 @@ struct CollectiveMainloopFwd {
    }

};
-