@@ -390,64 +390,26 @@ class MmaTensorOpWin2xDequantizer<
390
390
static_cast <float >(scale_frag_[4 ]), static_cast <float >(scale_frag_[5 ]),
391
391
static_cast <float >(scale_frag_[6 ]), static_cast <float >(scale_frag_[7 ]));
392
392
393
- unsigned long long unpack_local_scale = clock64 ();
394
-
395
393
int offset = warp_k_compute_offset * ArchMmaOperator::FragmentB::kElements ;
396
394
const int kOutputColumns = FragmentOutput::kElements / kWarpIterationsAlongN ;
397
395
398
396
// After applying LOP3 optimizations for performance, the B operand requires data rearrangement.
399
397
// reorder: [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15]
400
398
int mapped_offset = (warp_k_compute_offset % 2 ) == 0 ? 0 : (-kOutputColumns + 1 );
401
-
402
- /* if constexpr (platform::is_same<ElementOperand, bfloat16_t>::value) {
403
- __nv_bfloat162* output_ptr = reinterpret_cast<__nv_bfloat162 *>(&output_frag);
404
- __nv_bfloat16 const* unpacked_ptr = reinterpret_cast<__nv_bfloat16 const*>(&unpacked_frag_);
405
- __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag_) ;
399
+
400
+ CUTLASS_TRACE_DEVICE ( " tb_offset_k = %d, scale_frag_[0] = %f " , tb_offset_k, float (scale_frag_[ 0 ]));
401
+ CUTLASS_PRAGMA_UNROLL
402
+ for ( int mma_n_iter = 0 ; mma_n_iter < kWarpIterationsAlongN ; ++mma_n_iter) {
403
+ int mapped_idx_base = mma_n_iter * kExpansionFactor * kOutputColumns + offset + mapped_offset ;
406
404
407
405
CUTLASS_PRAGMA_UNROLL
408
- for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
409
- int mapped_idx_base = mma_n_iter * kExpansionFactor * kOutputColumns + offset + mapped_offset;
410
-
411
- __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
406
+ for (int j = 0 ; j < kOutputColumns ; ++j) {
407
+ ElementOperand scaled_value =
408
+ static_cast <ElementOperand>(unpacked_frag_[mapped_idx_base + 2 * j]) * scale_frag_[mma_n_iter];
412
409
413
- CUTLASS_PRAGMA_UNROLL
414
- for (int j = 0; j < kOutputColumns / 2; ++j) {
415
- __nv_bfloat162 unpacked_valuex2 = make_bfloat162(
416
- unpacked_ptr[mapped_idx_base + 4 * j], unpacked_ptr[mapped_idx_base + 4 * j + 2]);
417
- output_ptr[mma_n_iter * kOutputColumns / 2 + j] = __hmul2(unpacked_valuex2, scalex2);
418
- }
410
+ output_frag[mma_n_iter * kOutputColumns + j] = static_cast <ElementOperand>(scaled_value);
419
411
}
420
- } else {*/
421
- // CUTLASS_TRACE_DEVICE(" kWarpIterationsAlongN = %d, kOutputColumns = %d", kWarpIterationsAlongN, kOutputColumns);
422
-
423
- CUTLASS_TRACE_DEVICE (" tb_offset_k = %d, scale_frag_[0] = %f" , tb_offset_k, float (scale_frag_[0 ]));
424
- CUTLASS_PRAGMA_UNROLL
425
- for (int mma_n_iter = 0 ; mma_n_iter < kWarpIterationsAlongN ; ++mma_n_iter) {
426
- int mapped_idx_base = mma_n_iter * kExpansionFactor * kOutputColumns + offset + mapped_offset;
427
-
428
- CUTLASS_PRAGMA_UNROLL
429
- for (int j = 0 ; j < kOutputColumns ; ++j) {
430
- ElementOperand scaled_value =
431
- static_cast <ElementOperand>(unpacked_frag_[mapped_idx_base + 2 * j]) * scale_frag_[mma_n_iter];
432
-
433
- // ElementOperand scaled_value =
434
- // static_cast<ElementOperand>(unpacked_frag_[mapped_idx_base + 2 * j]) * scale_frag_[mma_n_iter];
435
- // CUTLASS_TRACE_DEVICE(" unpacked_frag_[%d] = %f, scale_frag_[%d] = %f",
436
- // mapped_idx_base + 2 * j, static_cast<float>(unpacked_frag_[mapped_idx_base + 2 * j]),
437
- // mma_n_iter, static_cast<float>(scale_frag_[mma_n_iter]));
438
- output_frag[mma_n_iter * kOutputColumns + j] = static_cast <ElementOperand>(scaled_value);
439
- // CUTLASS_TRACE_DEVICE("scale_frag_[0] = %f, scaled_value[%d] = %f", float(scale_frag_[0]), mma_n_iter * kOutputColumns + j, static_cast<float>(output_frag[mma_n_iter * kOutputColumns + j]));
440
-
441
-
442
- // CUTLASS_TRACE_DEVICE(" ref_idx = %d, output_frag[%d] = %f", mapped_idx_base + 2 * j, mma_n_iter * kOutputColumns + j, static_cast<float>(output_frag[mma_n_iter * kOutputColumns + j]));
443
- // output_frag[mma_n_iter * kOutputColumns + j] = static_cast<ElementOperand>(1.0);
444
- }
445
- }
446
- // }
447
-
448
- unsigned long long end = clock64 ();
449
- // CUTLASS_TRACE_DEVICE(" unpack_B: %llu, dequant_local_scale: %llu, unscale: %llu, dequantize: %llu",
450
- // unpack_b - start, unpack_local_scale - unpack_b, end - unpack_local_scale, end - start);
412
+ }
451
413
}
452
414
453
415
/// Add an offset to pointer in units of elements.
0 commit comments