|
1 | 1 | #include <ATen/cuda/CUDAContext.h>
|
2 | 2 |
|
3 | 3 | #include <cmath>
|
4 |
| -#include <cub/block/block_reduce.cuh> |
5 | 4 | #include <flashinfer/vec_dtypes.cuh>
|
6 | 5 |
|
7 | 6 | #include "utils.h"
|
8 | 7 |
|
9 |
| -template <typename T> |
static constexpr int kWarpSize = 32;

// ---------------------------------------------------------------------------
// 1. Warp-local per-token FP8 quantization — no shared memory.
//    • One warp handles one token; kTokensPerCTA warps (default 8) per CTA,
//      i.e. a 256-thread block with the default template arguments.
//    • Preconditions (enforced by the host launcher): hidden_dim % kVecSize
//      == 0, and kVecSize FP8 bytes form one 16-byte (uint4) store, so
//      rows must be 16-byte aligned.
// ---------------------------------------------------------------------------
template <typename T, int kTokensPerCTA = 8, int kVecSize = 16>
__global__ void per_token_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  const int warp_id = threadIdx.x / kWarpSize;           // 0..kTokensPerCTA-1
  const int lane_id = threadIdx.x & (kWarpSize - 1);     // 0..31
  const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
  if (token_id >= num_tokens) return;

  // Row pointers for the token owned by this warp.
  const T* token_input = input + token_id * hidden_dim;
  FP8_TYPE* token_output = output_q + token_id * hidden_dim;
  float* token_scale = output_s + token_id;

  //
  // Pass 1: warp-reduce the absolute maximum over this token's hidden_dim.
  //
  float max_value = 0.f;
  using vec_t = flashinfer::vec_t<T, kVecSize>;
  const int32_t num_vec_elems = hidden_dim / kVecSize;

  for (int32_t i = lane_id; i < num_vec_elems; i += kWarpSize) {
    vec_t input_vec;
    input_vec.cast_load(token_input + i * kVecSize);

#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
    }
  }

  float warp_max = warpReduceMax(max_value);
  // Broadcast lane 0's reduced maximum so every lane holds the same value.
  // NOTE: `scale` must stay in a register. The previous version funneled it
  // through a single `__shared__ float`, which all kTokensPerCTA warps of the
  // CTA wrote concurrently — a data race, since each warp quantizes a
  // different token with a different scale.
  warp_max = __shfl_sync(0xffffffff, warp_max, 0);
  const float scale = warp_max / FP8_E4M3_MAX;

  // Lane 0 publishes the per-token scale.
  if (lane_id == 0) {
    token_scale[0] = scale;
  }
  // An all-zero token yields scale == 0; map it to 0 so the quantized row is
  // all zeros instead of NaN/Inf from a division by zero.
  const float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;

  //
  // Pass 2: quantize and write back, one 16-byte vector per lane-iteration.
  //
  for (int32_t i = lane_id; i < num_vec_elems; i += kWarpSize) {
    vec_t input_vec;
    input_vec.cast_load(token_input + i * kVecSize);
    FP8_TYPE output_arr[kVecSize];
#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      float val = static_cast<float>(input_vec[j]) * scale_inv;
      // Clamp to the representable FP8 E4M3 range before conversion.
      val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);

#ifndef USE_ROCM
      output_arr[j] = static_cast<FP8_TYPE>(val);
#else
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
    }
    // kVecSize FP8 bytes == one uint4; row alignment is guaranteed by the
    // host-side `hidden_dim % 16 == 0` check.
    *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
  }
}
| 82 | + |
| 83 | +// --------------------------------------------------------------------------- |
| 84 | +// 2. Baseline kernel (1 token / CTA, CUB block reduce) |
| 85 | +// --------------------------------------------------------------------------- |
| 86 | +template <typename T> |
| 87 | +__global__ void per_token_quant_fp8_small_batch_kernel( |
| 88 | + const T* __restrict__ input, |
| 89 | + FP8_TYPE* __restrict__ output_q, |
| 90 | + float* __restrict__ output_s, |
| 91 | + const int64_t hidden_dim, |
| 92 | + const int64_t num_tokens) { |
16 | 93 | const int token_idx = blockIdx.x;
|
17 | 94 | if (token_idx >= num_tokens) return;
|
18 | 95 |
|
// ---------------------------------------------------------------------------
// Host entry point: per-token dynamic FP8 quantization.
//   input    — (num_tokens, hidden_dim), fp32/fp16/bf16, contiguous CUDA.
//   output_q — (num_tokens, hidden_dim) FP8 quantized values.
//   output_s — (num_tokens,) float per-token scales.
// Picks the warp-per-token kernel for large batches (enough tokens to
// saturate the device) and the 1-token-per-CTA baseline otherwise.
// ---------------------------------------------------------------------------
void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const auto input_sizes = input.sizes();
  const int64_t num_tokens = input_sizes[0];
  const int64_t hidden_dim = input_sizes[1];
  // Both kernels load/store 16-element vectors (one uint4 of FP8 bytes).
  TORCH_CHECK(hidden_dim % 16 == 0, "Hidden dimension must be divisible by 16, but got ", hidden_dim);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Query the SM count of the current device instead of hard-coding it
  // (the previous constant, 132, is specific to H100 and mis-sizes the
  // heuristic on every other GPU).
  const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  constexpr int TOKENS_PER_CTA = 8;
  // Use the warp-local kernel only when there are enough tokens to put at
  // least ~2 CTAs on every SM.
  const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    if (use_warp_kernel) {
      // -------- warp-local: 8 tokens per 256-thread CTA --------------------
      constexpr int THREADS = TOKENS_PER_CTA * kWarpSize;  // 256
      dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);  // ceil-div
      dim3 block(THREADS);
      per_token_quant_fp8_kernel<scalar_t, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
          static_cast<const scalar_t*>(input.data_ptr()),
          static_cast<FP8_TYPE*>(output_q.data_ptr()),
          static_cast<float*>(output_s.data_ptr()),
          hidden_dim,
          num_tokens);
    } else {
      // -------- baseline: 1 token per CTA ----------------------------------
      constexpr int THREADS = 256;
      dim3 grid(num_tokens);
      dim3 block(THREADS);
      per_token_quant_fp8_small_batch_kernel<scalar_t><<<grid, block, 0, stream>>>(
          static_cast<const scalar_t*>(input.data_ptr()),
          static_cast<FP8_TYPE*>(output_q.data_ptr()),
          static_cast<float*>(output_s.data_ptr()),
          hidden_dim,
          num_tokens);
    }
    return true;
  });
}
|
0 commit comments