Commit 0c8dab9

yuan-luo and luoyuan.luo authored
[sgl-kernel] Opt per_token_quant_fp8 with warp reduce (sgl-project#8130)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
1 parent f39037f commit 0c8dab9

File tree: 1 file changed, +106 −16 lines changed


sgl-kernel/csrc/gemm/per_token_quant_fp8.cu

Lines changed: 106 additions & 16 deletions
@@ -1,18 +1,95 @@
 #include <ATen/cuda/CUDAContext.h>
 
 #include <cmath>
-#include <cub/block/block_reduce.cuh>
 #include <flashinfer/vec_dtypes.cuh>
 
 #include "utils.h"
 
-template <typename T>
+static constexpr int kWarpSize = 32;
+
+// ---------------------------------------------------------------------------
+// 1. Warp-local, no shared memory
+//    • One warp handles one token.
+//    • Eight tokens per 256-thread CTA.
+// ---------------------------------------------------------------------------
+template <typename T, int kTokensPerCTA = 8, int kVecSize = 16>
 __global__ void per_token_quant_fp8_kernel(
     const T* __restrict__ input,
     FP8_TYPE* __restrict__ output_q,
     float* __restrict__ output_s,
     const int64_t hidden_dim,
     const int64_t num_tokens) {
+  const int warp_id = threadIdx.x / kWarpSize;        // 0-7 (8 warps)
+  const int lane_id = threadIdx.x & (kWarpSize - 1);  // 0-31
+  const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
+  if (token_id >= num_tokens) return;
+
+  // Global tensors for this token
+  const T* token_input = input + token_id * hidden_dim;
+  FP8_TYPE* token_output = output_q + token_id * hidden_dim;
+  float* token_scale = output_s + token_id;
+
+  //
+  // Pass 1: warp reduction to find the max |value| over the token's hidden_dim
+  //
+  float max_value = 0.f;
+  using vec_t = flashinfer::vec_t<T, kVecSize>;
+  const int32_t num_vec_elems = hidden_dim / kVecSize;
+
+  for (int32_t i = lane_id; i < num_vec_elems; i += kWarpSize) {
+    vec_t input_vec;
+    input_vec.cast_load(token_input + i * kVecSize);
+
+#pragma unroll
+    for (uint32_t j = 0; j < kVecSize; ++j) {
+      max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
+    }
+  }
+
+  float warp_max = warpReduceMax(max_value);
+
+  // warpReduceMax leaves the warp-wide max in every lane, so the scale
+  // stays in a register (a shared float would race between the 8 warps).
+  const float scale = warp_max / FP8_E4M3_MAX;
+  if (lane_id == 0) {
+    token_scale[0] = scale;
+  }
+  float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;
+
+  //
+  // Pass 2: quantize and write back
+  //
+  for (int32_t i = lane_id; i < num_vec_elems; i += kWarpSize) {
+    vec_t input_vec;
+    input_vec.cast_load(token_input + i * kVecSize);
+    FP8_TYPE output_arr[kVecSize];
+#pragma unroll
+    for (uint32_t j = 0; j < kVecSize; ++j) {
+      float val = static_cast<float>(input_vec[j]) * scale_inv;
+      val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
+
+#ifndef USE_ROCM
+      output_arr[j] = static_cast<FP8_TYPE>(val);
+#else
+      output_arr[j] = c10::Float8_e4m3fnuz(
+          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
+          c10::Float8_e4m3fnuz::from_bits());
+#endif
+    }
+    *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
+  }
+}
+
+// ---------------------------------------------------------------------------
+// 2. Baseline kernel (1 token / CTA, CUB block reduce)
+// ---------------------------------------------------------------------------
+template <typename T>
+__global__ void per_token_quant_fp8_small_batch_kernel(
+    const T* __restrict__ input,
+    FP8_TYPE* __restrict__ output_q,
+    float* __restrict__ output_s,
+    const int64_t hidden_dim,
+    const int64_t num_tokens) {
   const int token_idx = blockIdx.x;
   if (token_idx >= num_tokens) return;
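
The warp-local kernel relies on warpReduceMax from utils.h, which this diff does not show. Below is a minimal sketch of such a helper, assuming the usual butterfly (XOR-shuffle) implementation; the actual version in utils.h may differ:

// Hypothetical sketch, not part of this commit: warpReduceMax as a
// butterfly reduction. After log2(32) = 5 __shfl_xor_sync steps, every
// lane of the warp holds the warp-wide maximum.
__device__ __forceinline__ float warpReduceMax(float val) {
#pragma unroll
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset));
  }
  return val;
}

Because every lane ends up holding the maximum, each warp can compute its token's scale independently, which is what lets the kernel above keep the scale in a register rather than in shared memory.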

@@ -79,28 +156,41 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);
   CHECK_INPUT(output_s);
-
   const auto input_sizes = input.sizes();
   const int64_t num_tokens = input_sizes[0];
   const int64_t hidden_dim = input_sizes[1];
-
   TORCH_CHECK(hidden_dim % 16 == 0, "Hidden dimension must be divisible by 16, but got ", hidden_dim);
 
-  const int block_size = 256;
-  const int num_blocks = num_tokens;
-
-  dim3 grid(num_blocks);
-  dim3 block(block_size);
-
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // Hard-code sm_count
+  int sm_count = 132;
+  constexpr int TOKENS_PER_CTA = 8;
+  const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
-    per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        static_cast<scalar_t*>(input.data_ptr()),
-        static_cast<FP8_TYPE*>(output_q.data_ptr()),
-        static_cast<float*>(output_s.data_ptr()),
-        hidden_dim,
-        num_tokens);
+    if (use_warp_kernel) {
+      // -------- warp-local ---------------------------------------------------
+      constexpr int THREADS = TOKENS_PER_CTA * kWarpSize;  // 256
+      dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
+      dim3 block(THREADS);
+      per_token_quant_fp8_kernel<scalar_t, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
+          static_cast<const scalar_t*>(input.data_ptr()),
+          static_cast<FP8_TYPE*>(output_q.data_ptr()),
+          static_cast<float*>(output_s.data_ptr()),
+          hidden_dim,
+          num_tokens);
+    } else {
+      // -------- baseline -----------------------------------------------------
+      constexpr int THREADS = 256;
+      dim3 grid(num_tokens);
+      dim3 block(THREADS);
+      per_token_quant_fp8_small_batch_kernel<scalar_t><<<grid, block, 0, stream>>>(
+          static_cast<const scalar_t*>(input.data_ptr()),
+          static_cast<FP8_TYPE*>(output_q.data_ptr()),
+          static_cast<float*>(output_s.data_ptr()),
+          hidden_dim,
+          num_tokens);
+    }
     return true;
   });
 }
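
With sm_count hard-coded to 132 (the SM count of an H100), the warp-local path is taken once num_tokens >= 132 * 2 * 8 = 2112, i.e. once the grid can put at least two 8-token CTAs on every SM. A more portable variant would query the device at runtime; a sketch using standard CUDA runtime calls, not part of this commit:

// Hypothetical alternative to the hard-coded sm_count = 132: query the
// multiprocessor count of the current device via the CUDA runtime API.
#include <cuda_runtime.h>

static inline int get_sm_count() {
  int device = 0;
  cudaGetDevice(&device);
  int sm_count = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
  return sm_count;
}

Since this file already uses ATen, the same value should also be reachable as at::cuda::getCurrentDeviceProperties()->multiProcessorCount.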
