|
16 | 16 | QuantizeMethodBase,
|
17 | 17 | )
|
18 | 18 | from sglang.srt.layers.quantization.utils import is_layer_skipped
|
| 19 | +from sglang.srt.layers.utils import is_sm100_supported |
19 | 20 | from sglang.srt.managers.schedule_batch import global_server_args_dict
|
20 | 21 | from sglang.srt.utils import (
|
21 | 22 | direct_register_custom_op,
|
|
28 | 29 | set_weight_attrs,
|
29 | 30 | )
|
30 | 31 |
|
| 32 | +_is_sm100_supported = is_cuda() and is_sm100_supported() |
31 | 33 | has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
32 | 34 |
|
33 | 35 |
|
@@ -244,13 +246,17 @@ def create_weights(
|
244 | 246 |
|
245 | 247 | # pad the intermediate size to be a multiple of 2 * mxfp4_block
|
246 | 248 | # so it can hold non-uniformly sharded tensors as well as support swizzling
|
247 |
| - if self.use_flashinfer: |
248 |
| - intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256) |
249 |
| - hidden_size = round_up(hidden_size, 256) |
250 |
| - elif is_hip(): |
251 |
| - intermediate_size_per_partition_after_pad = round_up(intermediate_size, 128) |
252 |
| - else: |
253 |
| - intermediate_size_per_partition_after_pad = round_up(intermediate_size, 64) |
| 249 | + intermediate_size_per_partition_after_pad = intermediate_size |
| 250 | + if _is_sm100_supported: |
| 251 | + if self.use_flashinfer: |
| 252 | + intermediate_size_per_partition_after_pad = round_up( |
| 253 | + intermediate_size, 256 |
| 254 | + ) |
| 255 | + hidden_size = round_up(hidden_size, 256) |
| 256 | + else: |
| 257 | + intermediate_size_per_partition_after_pad = round_up( |
| 258 | + intermediate_size, 64 |
| 259 | + ) |
254 | 260 |
|
255 | 261 | self.intermediate_size = intermediate_size_per_partition_after_pad
|
256 | 262 |
|
|
0 commit comments