Skip to content

Commit 1c1f8a1

Browse files
kkHuang-amd and wunhuang authored
Combine fp4.py and mxfp4.py into one file and support dynamic mxfp4 quantization in mxfp4.py (sgl-project#9049)
Co-authored-by: wunhuang <wunhuang@amd.com>
1 parent 384f8ab commit 1c1f8a1

File tree

7 files changed

+760
-557
lines changed

7 files changed

+760
-557
lines changed

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ def weight_loader(
474474
not expert_id
475475
and self.quant_config is not None
476476
and self.quant_config.get_name() == "mxfp4"
477+
and self.quant_config.is_static_cfg()
477478
):
478479
if "bias" in weight_name:
479480
dim1 = loaded_weight.shape[1]
@@ -724,7 +725,11 @@ def weight_loader_fused(
724725
) -> None:
725726
tp_rank = self.moe_tp_rank
726727

727-
if self.quant_config is not None and self.quant_config.get_name() == "mxfp4":
728+
if (
729+
self.quant_config is not None
730+
and self.quant_config.get_name() == "mxfp4"
731+
and self.quant_config.is_static_cfg()
732+
):
728733
if "bias" in weight_name:
729734
dim1 = loaded_weight.shape[1]
730735
param.data[:, :dim1].copy_(loaded_weight)

python/sglang/srt/layers/quantization/__init__.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,6 @@ def override_quantization_method(self, *args, **kwargs):
4848
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
4949
CompressedTensorsConfig,
5050
)
51-
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
52-
53-
is_mxfp_supported = mxfp_supported()
54-
if is_mxfp_supported:
55-
from sglang.srt.layers.quantization.fp4 import MxFp4Config
56-
5751
from sglang.srt.layers.quantization.fp8 import Fp8Config
5852
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
5953
from sglang.srt.layers.quantization.modelopt_quant import (
@@ -67,6 +61,9 @@ def override_quantization_method(self, *args, **kwargs):
6761
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
6862
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
6963
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
64+
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
65+
66+
_is_mxfp_supported = mxfp_supported()
7067

7168
if TYPE_CHECKING:
7269
from sglang.srt.layers.moe.topk import TopKOutput
@@ -98,11 +95,13 @@ def override_quantization_method(self, *args, **kwargs):
9895
"mxfp4": Mxfp4Config,
9996
}
10097
)
101-
elif is_mxfp_supported and is_hip():
98+
elif _is_mxfp_supported and is_hip():
99+
from sglang.srt.layers.quantization.quark.quark import QuarkConfig
100+
102101
BASE_QUANTIZATION_METHODS.update(
103102
{
104-
"quark": MxFp4Config,
105-
"mxfp4": MxFp4Config,
103+
"quark": QuarkConfig,
104+
"mxfp4": Mxfp4Config,
106105
}
107106
)
108107
# VLLM-dependent quantization methods

0 commit comments

Comments (0)