Skip to content

Commit 3ae33fc

Browse files
authored
Fix illegal memory access when launching gpt-oss models on Hopper (sgl-project#8908)
1 parent 500b15c commit 3ae33fc

File tree

1 file changed

+13
-7
lines changed
  • python/sglang/srt/layers/quantization

1 file changed

+13
-7
lines changed

python/sglang/srt/layers/quantization/mxfp4.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
QuantizeMethodBase,
1717
)
1818
from sglang.srt.layers.quantization.utils import is_layer_skipped
19+
from sglang.srt.layers.utils import is_sm100_supported
1920
from sglang.srt.managers.schedule_batch import global_server_args_dict
2021
from sglang.srt.utils import (
2122
direct_register_custom_op,
@@ -28,6 +29,7 @@
2829
set_weight_attrs,
2930
)
3031

32+
_is_sm100_supported = is_cuda() and is_sm100_supported()
3133
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
3234

3335

@@ -244,13 +246,17 @@ def create_weights(
244246

245247
# pad the intermediate size to be a multiple of 2 * mxfp4_block
246248
# for to hold non-uniform sharded tensor as well as swizzling
247-
if self.use_flashinfer:
248-
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
249-
hidden_size = round_up(hidden_size, 256)
250-
elif is_hip():
251-
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 128)
252-
else:
253-
intermediate_size_per_partition_after_pad = round_up(intermediate_size, 64)
249+
intermediate_size_per_partition_after_pad = intermediate_size
250+
if _is_sm100_supported:
251+
if self.use_flashinfer:
252+
intermediate_size_per_partition_after_pad = round_up(
253+
intermediate_size, 256
254+
)
255+
hidden_size = round_up(hidden_size, 256)
256+
else:
257+
intermediate_size_per_partition_after_pad = round_up(
258+
intermediate_size, 64
259+
)
254260

255261
self.intermediate_size = intermediate_size_per_partition_after_pad
256262

0 commit comments

Comments
 (0)