Skip to content

Commit 3eb4a80

Browse files
authored
Fix AWQ Dequant and Weight Loading of deepseek v2 (sgl-project#6842)
1 parent e726131 commit 3eb4a80

File tree

3 files changed: 18 additions and 11 deletions.

python/sglang/srt/models/deepseek_v2.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2137,8 +2137,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
21372137
):
21382138
q_a_proj_weight = cached_a_proj[q_a_proj_name]
21392139
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
2140+
cat_dim = 0
2141+
if (
2142+
self.quant_config.get_name() == "awq"
2143+
or self.quant_config.get_name() == "moe_wna16"
2144+
):
2145+
cat_dim = 1
21402146
fused_weight = torch.cat(
2141-
[q_a_proj_weight, kv_a_proj_weight], dim=0
2147+
[q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
21422148
)
21432149
param_name = (
21442150
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")

sgl-kernel/csrc/gemm/awq_kernel.cu

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -130,10 +130,12 @@ __global__ void __launch_bounds__(256) dequantize_weights(
130130
int* __restrict__ qzeros,
131131
OutputT* __restrict__ output,
132132
int group_size,
133-
int qweight_cols) {
133+
int qweight_cols,
134+
int qweight_rows) {
134135
#if CUDA_VERSION >= 12000
135136
int col = blockIdx.x * blockDim.x + threadIdx.x;
136137
int row = blockIdx.y * blockDim.y + threadIdx.y;
138+
if (col >= qweight_cols || row >= qweight_rows) return;
137139

138140
int group_idx = row / group_size;
139141
int scale_offset = 8 * col + group_idx * qweight_cols * 8;
@@ -188,8 +190,8 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
188190

189191
int x_num_threads = 16;
190192
int y_num_threads = 16;
191-
int x_blocks = qweight_cols / x_num_threads;
192-
int y_blocks = qweight_rows / y_num_threads;
193+
int x_blocks = (qweight_cols + x_num_threads - 1) / x_num_threads;
194+
int y_blocks = (qweight_rows + y_num_threads - 1) / y_num_threads;
193195

194196
const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
195197

@@ -206,13 +208,13 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
206208
if (scales.scalar_type() == at::ScalarType::Half) {
207209
auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
208210
auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
209-
dequantize_weights<half>
210-
<<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
211+
dequantize_weights<half><<<num_blocks, threads_per_block, 0, stream>>>(
212+
_qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
211213
} else {
212214
auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
213215
auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
214-
dequantize_weights<__nv_bfloat16>
215-
<<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
216+
dequantize_weights<__nv_bfloat16><<<num_blocks, threads_per_block, 0, stream>>>(
217+
_qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
216218
}
217219

218220
return output;

sgl-kernel/tests/test_awq_dequant.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -67,8 +67,8 @@ def sglang_awq_dequantize(
6767
"qweight_row,qweight_col,is_bf16_act",
6868
list(
6969
itertools.product(
70-
[3584, 18944, 128, 256, 512, 1024],
71-
[448, 576, 4736, 16, 32, 64, 128],
70+
[3584, 18944, 128, 256, 512, 1024, 1536],
71+
[448, 576, 4736, 16, 32, 64, 128, 72],
7272
[True, False],
7373
)
7474
),
@@ -77,7 +77,6 @@ def test_awq_dequant_compare_implementations(
7777
qweight_row: int, qweight_col: int, is_bf16_act: bool
7878
):
7979
device = torch.device("cuda")
80-
8180
qweight = torch.randint(
8281
0,
8382
torch.iinfo(torch.int32).max,

0 commit comments

Comments
 (0)