@@ -130,10 +130,12 @@ __global__ void __launch_bounds__(256) dequantize_weights(
130
130
int * __restrict__ qzeros,
131
131
OutputT* __restrict__ output,
132
132
int group_size,
133
- int qweight_cols) {
133
+ int qweight_cols,
134
+ int qweight_rows) {
134
135
#if CUDA_VERSION >= 12000
135
136
int col = blockIdx .x * blockDim .x + threadIdx .x ;
136
137
int row = blockIdx .y * blockDim .y + threadIdx .y ;
138
+ if (col >= qweight_cols || row >= qweight_rows) return ;
137
139
138
140
int group_idx = row / group_size;
139
141
int scale_offset = 8 * col + group_idx * qweight_cols * 8 ;
@@ -188,8 +190,8 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
188
190
189
191
int x_num_threads = 16 ;
190
192
int y_num_threads = 16 ;
191
- int x_blocks = qweight_cols / x_num_threads;
192
- int y_blocks = qweight_rows / y_num_threads;
193
+ int x_blocks = ( qweight_cols + x_num_threads - 1 ) / x_num_threads;
194
+ int y_blocks = ( qweight_rows + y_num_threads - 1 ) / y_num_threads;
193
195
194
196
const at::cuda::OptionalCUDAGuard device_guard (device_of (qweight));
195
197
@@ -206,13 +208,13 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
206
208
if (scales.scalar_type () == at::ScalarType::Half) {
207
209
auto _scales = reinterpret_cast <half*>(scales.data_ptr <at::Half>());
208
210
auto _output = reinterpret_cast <half*>(output.data_ptr <at::Half>());
209
- dequantize_weights<half>
210
- <<<num_blocks, threads_per_block, 0 , stream>>> ( _qweight, _scales, _zeros, _output, group_size, qweight_cols);
211
+ dequantize_weights<half><<<num_blocks, threads_per_block, 0 , stream>>> (
212
+ _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows );
211
213
} else {
212
214
auto _scales = reinterpret_cast <__nv_bfloat16*>(scales.data_ptr <at::BFloat16>());
213
215
auto _output = reinterpret_cast <__nv_bfloat16*>(output.data_ptr <at::BFloat16>());
214
- dequantize_weights<__nv_bfloat16>
215
- <<<num_blocks, threads_per_block, 0 , stream>>> ( _qweight, _scales, _zeros, _output, group_size, qweight_cols);
216
+ dequantize_weights<__nv_bfloat16><<<num_blocks, threads_per_block, 0 , stream>>> (
217
+ _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows );
216
218
}
217
219
218
220
return output;
0 commit comments