Skip to content

Commit 856872e

Browse files
committed
support dequantizer
1 parent 1c1065e commit 856872e

File tree

1 file changed

+10
-48
lines changed

1 file changed

+10
-48
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h

Lines changed: 10 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -390,64 +390,26 @@ class MmaTensorOpWin2xDequantizer<
390390
static_cast<float>(scale_frag_[4]), static_cast<float>(scale_frag_[5]),
391391
static_cast<float>(scale_frag_[6]), static_cast<float>(scale_frag_[7]));
392392

393-
unsigned long long unpack_local_scale = clock64();
394-
395393
int offset = warp_k_compute_offset * ArchMmaOperator::FragmentB::kElements;
396394
const int kOutputColumns = FragmentOutput::kElements / kWarpIterationsAlongN;
397395

398396
// After applying LOP3 optimizations for performance, the B operand requires data rearrangement.
399397
// reorder: [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15]
400398
int mapped_offset = (warp_k_compute_offset % 2) == 0 ? 0 : (-kOutputColumns + 1);
401-
402-
/*if constexpr (platform::is_same<ElementOperand, bfloat16_t>::value) {
403-
__nv_bfloat162* output_ptr = reinterpret_cast<__nv_bfloat162 *>(&output_frag);
404-
__nv_bfloat16 const* unpacked_ptr = reinterpret_cast<__nv_bfloat16 const*>(&unpacked_frag_);
405-
__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag_);
399+
400+
CUTLASS_TRACE_DEVICE("tb_offset_k = %d, scale_frag_[0] = %f", tb_offset_k, float(scale_frag_[0]));
401+
CUTLASS_PRAGMA_UNROLL
402+
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
403+
int mapped_idx_base = mma_n_iter * kExpansionFactor * kOutputColumns + offset + mapped_offset;
406404

407405
CUTLASS_PRAGMA_UNROLL
408-
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
409-
int mapped_idx_base = mma_n_iter * kExpansionFactor * kOutputColumns + offset + mapped_offset;
410-
411-
__nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
406+
for (int j = 0; j < kOutputColumns; ++j) {
407+
ElementOperand scaled_value =
408+
static_cast<ElementOperand>(unpacked_frag_[mapped_idx_base + 2 * j]) * scale_frag_[mma_n_iter];
412409

413-
CUTLASS_PRAGMA_UNROLL
414-
for (int j = 0; j < kOutputColumns / 2; ++j) {
415-
__nv_bfloat162 unpacked_valuex2 = make_bfloat162(
416-
unpacked_ptr[mapped_idx_base + 4 * j], unpacked_ptr[mapped_idx_base + 4 * j + 2]);
417-
output_ptr[mma_n_iter * kOutputColumns / 2 + j] = __hmul2(unpacked_valuex2, scalex2);
418-
}
410+
output_frag[mma_n_iter * kOutputColumns + j] = static_cast<ElementOperand>(scaled_value);
419411
}
420-
} else {*/
421-
// CUTLASS_TRACE_DEVICE(" kWarpIterationsAlongN = %d, kOutputColumns = %d", kWarpIterationsAlongN, kOutputColumns);
422-
423-
CUTLASS_TRACE_DEVICE("tb_offset_k = %d, scale_frag_[0] = %f", tb_offset_k, float(scale_frag_[0]));
424-
CUTLASS_PRAGMA_UNROLL
425-
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
426-
int mapped_idx_base = mma_n_iter * kExpansionFactor * kOutputColumns + offset + mapped_offset;
427-
428-
CUTLASS_PRAGMA_UNROLL
429-
for (int j = 0; j < kOutputColumns; ++j) {
430-
ElementOperand scaled_value =
431-
static_cast<ElementOperand>(unpacked_frag_[mapped_idx_base + 2 * j]) * scale_frag_[mma_n_iter];
432-
433-
// ElementOperand scaled_value =
434-
// static_cast<ElementOperand>(unpacked_frag_[mapped_idx_base + 2 * j]) * scale_frag_[mma_n_iter];
435-
// CUTLASS_TRACE_DEVICE(" unpacked_frag_[%d] = %f, scale_frag_[%d] = %f",
436-
// mapped_idx_base + 2 * j, static_cast<float>(unpacked_frag_[mapped_idx_base + 2 * j]),
437-
// mma_n_iter, static_cast<float>(scale_frag_[mma_n_iter]));
438-
output_frag[mma_n_iter * kOutputColumns + j] = static_cast<ElementOperand>(scaled_value);
439-
// CUTLASS_TRACE_DEVICE("scale_frag_[0] = %f, scaled_value[%d] = %f", float(scale_frag_[0]), mma_n_iter * kOutputColumns + j, static_cast<float>(output_frag[mma_n_iter * kOutputColumns + j]));
440-
441-
442-
// CUTLASS_TRACE_DEVICE(" ref_idx = %d, output_frag[%d] = %f", mapped_idx_base + 2 * j, mma_n_iter * kOutputColumns + j, static_cast<float>(output_frag[mma_n_iter * kOutputColumns + j]));
443-
// output_frag[mma_n_iter * kOutputColumns + j] = static_cast<ElementOperand>(1.0);
444-
}
445-
}
446-
//}
447-
448-
unsigned long long end = clock64();
449-
// CUTLASS_TRACE_DEVICE(" unpack_B: %llu, dequant_local_scale: %llu, unscale: %llu, dequantize: %llu",
450-
// unpack_b - start, unpack_local_scale - unpack_b, end - unpack_local_scale, end - start);
412+
}
451413
}
452414

453415
/// Add an offset to pointer in units of elements.

0 commit comments

Comments
 (0)