From 270e373ca54a40360152ac8aac0de7f1363d8623 Mon Sep 17 00:00:00 2001
From: bukejiyu <395822456@qq.com>
Date: Mon, 11 Aug 2025 06:19:35 +0000
Subject: [PATCH 1/7] support qwen3

---
 fastdeploy/config.py                          |   2 +
 fastdeploy/model_executor/layers/linear.py    | 138 +++++++++++-----
 .../layers/moe/fused_moe_backend_base.py      |   2 +
 fastdeploy/model_executor/layers/moe/moe.py   | 112 ++++++++-----
 .../layers/quantization/weight_only.py        |  46 +++++-
 fastdeploy/model_executor/layers/utils.py     |   1 -
 .../model_executor/load_weight_utils.py       |   4 +
 .../model_executor/model_loader/__init__.py   |   5 +
 .../model_loader/default_loader_v1.py         |  12 +-
 .../model_loader/inflight_quant_loader.py     | 149 ++++++++++++++++++
 fastdeploy/model_executor/models/qwen3.py     |  31 +++-
 fastdeploy/model_executor/models/qwen3moe.py  |  28 +++-
 fastdeploy/model_executor/models/utils.py     |  74 +++++----
 fastdeploy/model_executor/utils.py            |  43 +++++
 fastdeploy/worker/worker_process.py           |   5 +-
 15 files changed, 517 insertions(+), 135 deletions(-)
 create mode 100644 fastdeploy/model_executor/model_loader/inflight_quant_loader.py
 create mode 100644 fastdeploy/model_executor/utils.py

diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 865a090822..03eb2c86a0 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -666,6 +666,7 @@ class LoadChoices(str, Enum):
     DEFAULT = "default"
     # only support qwen3-bf16 now
     DEFAULT_V1 = "default_v1"
+    INFLIGHT_QUANT = "inflight_quant"


 class LoadConfig:
@@ -685,6 +686,7 @@ def __init__(
         args,
     ):
         self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
+        self.is_inflight_quant = False
         self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
         self.dynamic_load_weight: bool = False
         self.load_strategy: Optional[Literal["ipc", "ipc_snapshot"]] = None
diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index fe89102119..b721e322ca 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -23,8 +23,10 @@
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.model_executor.models.utils import (
-    default_weight_loader,
+    default_load_weights_into_param,
+    default_weights_processor,
     set_weight_attrs,
+    slice_fn,
 )
 from fastdeploy.platforms import current_platform

@@ -37,8 +39,10 @@ class UnquantizedLinearMethod(QuantMethodBase):
     def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         extra_weight_attrs is a dictionary that may include parameters like:
+            - split_axis: axis along which to split the tensor in a distributed environment
             - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
-            - weight_loader: a callable or method responsible for loading the weight data
+            - weights_processor: a callable or method responsible for processing weight data
+            - load_weights_into_param: loads the given weight tensor into the specified model parameter.
""" layer.weight = layer.create_parameter( shape=layer.weight_shape, @@ -46,12 +50,21 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): is_bias=False, default_initializer=paddle.nn.initializer.Constant(0), ) + split_axis = extra_weight_attrs.get("split_axis") + if hasattr(layer, "nranks") and layer.nranks > 0: + _set_var_distributed(layer.weight, split_axis=split_axis) set_weight_attrs( layer.weight, - {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))}, + { + **extra_weight_attrs, + "weights_processor": extra_weight_attrs.get( + "weights_processor", default_weights_processor(layer.fd_config) + ), + "load_weights_into_param": extra_weight_attrs.get( + "load_weights_into_param", default_load_weights_into_param() + ), + }, ) - if hasattr(layer, "nranks") and layer.nranks > 1: - set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")}) def process_loaded_weights(self, layer, weights) -> None: # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation @@ -158,6 +171,7 @@ def __init__( is_bias=True, ) + self.is_quantized = fd_config.model_config.is_quantized # smooth quant self.linear_shift = None self.linear_smooth = None @@ -270,9 +284,17 @@ def __init__( assert self.quant_method is not None self.quant_method.create_weights( self, - weight_loader=( - self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config) + weights_processor=( + self.weights_processor + if hasattr(self, "weights_processor") + else default_weights_processor(self.fd_config) ), + load_weights_into_param=( + self.load_weights_into_param + if hasattr(self, "load_weights_into_param") + else default_load_weights_into_param() + ), + inflight_quant=fd_config.quant_config and not skip_quant, ) @@ -327,17 +349,23 @@ def __init__( self.quant_method.create_weights( self, output_dim=True, - weight_loader=( - self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config) + weights_processor=( + self.weights_processor + if hasattr(self, "weights_processor") + else default_weights_processor(self.fd_config) + ), + load_weights_into_param=( + self.load_weights_into_param + if hasattr(self, "load_weights_into_param") + else default_load_weights_into_param() ), + inflight_quant=fd_config.quant_config and not skip_quant, ) + if self.nranks > 0: - _set_var_distributed(self.weight, split_axis=1) if self.with_bias: # col parallel _set_var_distributed(self.bias, split_axis=1) - if self.nranks > 1: - set_weight_attrs(self.bias, {"output_dim": True}) class MergedColumnParallelLinear(ColumnParallelLinear): @@ -390,31 +418,33 @@ def __init__( skip_quant=skip_quant, ) - def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): + def load_weights_into_param(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): + assert loaded_shard_id in ["gate", "up"] + output_dim = getattr(param, "output_dim", None) + if loaded_shard_id == "gate": + param = slice_fn(param, output_dim, start=0, end=self.output_size // 2) + elif loaded_shard_id == "up": + param = slice_fn(param, output_dim, start=self.output_size // 2, end=self.output_size) + assert param.shape == loaded_weight.shape, ( + f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" + ) + param.copy_(loaded_weight, False) + + def weights_processor(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): # 1.fused gate_up in disk # 2.split gate 
up assert loaded_shard_id in ["gate", "up"] output_dim = getattr(param, "output_dim", None) # Tensor parallelism splits the weight along the output_dim - if output_dim is not None: + if output_dim is not None and self.nranks > 1: dim = -1 size = loaded_weight.get_shape()[dim] block_size = size // self.nranks shard_offset = self.local_rank * block_size shard_size = (self.local_rank + 1) * block_size - loaded_weight = loaded_weight[..., shard_offset:shard_size] - + loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size) loaded_weight = get_tensor(loaded_weight) - - if loaded_shard_id == "gate": - param = param[:, : self.output_size // 2] - elif loaded_shard_id == "up": - param = param[:, self.output_size // 2 :] - - assert param.shape == loaded_weight.shape, ( - f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" - ) - param.copy_(loaded_weight, False) + yield loaded_weight def load_state_dict(self, state_dict: dict): """ @@ -484,33 +514,44 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True): add_bias=add_bias, ) - def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): + def weights_processor(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): # 1.fused qkv in disk # 2.split q k v assert loaded_shard_id in ["q", "k", "v"] output_dim = getattr(param, "output_dim", None) # Tensor parallelism splits the weight along the output_dim - if output_dim is not None: + if output_dim is not None and self.nranks > 1: dim = -1 size = loaded_weight.get_shape()[dim] block_size = size // self.nranks shard_offset = self.local_rank * block_size shard_size = (self.local_rank + 1) * block_size - loaded_weight = loaded_weight[..., shard_offset:shard_size] + loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size) loaded_weight = get_tensor(loaded_weight) + yield loaded_weight + def load_weights_into_param(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): + assert loaded_shard_id in ["q", "k", "v"] + output_dim = getattr(param, "output_dim", None) if loaded_shard_id == "q": - param = param[:, : self.num_heads_per_rank * self.head_dim] + param = slice_fn(param, output_dim, 0, self.num_heads_per_rank * self.head_dim) + elif loaded_shard_id == "k": - param = param[ - :, - self.num_heads_per_rank - * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank) - * self.head_dim, - ] + param = slice_fn( + param, + output_dim, + self.num_heads_per_rank * self.head_dim, + (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim, + ) + elif loaded_shard_id == "v": - param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :] + param = slice_fn( + param, + output_dim, + (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim, + (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * self.head_dim, + ) assert param.shape == loaded_weight.shape, ( f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" @@ -653,9 +694,17 @@ def __init__( self, split_axis=0, output_dim=False, - weight_loader=( - self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config) + weights_processor=( + self.weights_processor + if hasattr(self, "weights_processor") + else default_weights_processor(self.fd_config) + ), + load_weights_into_param=( + self.load_weights_into_param + if hasattr(self, "load_weights_into_param") + else 
default_load_weights_into_param() ), + inflight_quant=fd_config.quant_config and not skip_quant, ) if self.nranks > 0: _set_var_distributed(self.weight, split_axis=0) @@ -670,6 +719,17 @@ def __init__( }, ) + if self.nranks > 0: + if self.with_bias: + # col parallel + _set_var_distributed(self.bias, split_axis=0) + set_weight_attrs( + self.bias, + { + "output_dim": False, + }, + ) + self.reduce_results = reduce_results def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index 6a57b20079..7a8548f2e0 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -185,9 +185,11 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): if current_platform.is_cuda(): self.up_gate_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2] self.down_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size, layer.hidden_size] + extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1}} else: self.up_gate_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size * 2, layer.hidden_size] self.down_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size] + extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}} layer.up_gate_proj_weight = layer.create_parameter( shape=self.up_gate_proj_weight_shape, diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 16b75e9e24..6f2dfb101f 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -22,6 +22,7 @@ from fastdeploy import envs from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.model_executor.models.utils import slice_fn from fastdeploy.platforms import current_platform from fastdeploy.worker.experts_manager import RedundantExpertManger @@ -35,7 +36,6 @@ def get_moe_method(): """ return moe method based on device platform """ - from fastdeploy.platforms import current_platform if current_platform.is_cuda(): from .fused_moe_cutlass_backend import CutlassMoEMethod @@ -152,10 +152,18 @@ def __init__( and is_supported_moe_backend is not None and is_supported_moe_backend(self.quant_method) ): - self.quant_method.create_weights(self, weight_loader=self.weight_loader) + self.quant_method.create_weights( + self, + weights_processor=self.weights_processor, + load_weights_into_param=self.load_weights_into_param, + ) else: # w_fp16 a_fp16 - self.quant_method.create_weights(self, weight_loader=self.weight_loader) + self.quant_method.create_weights( + self, + weights_processor=self.weights_processor, + load_weights_into_param=self.load_weights_into_param, + ) logger.info( f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset + self.num_local_experts}), \ @@ -164,32 +172,58 @@ def __init__( tp_size={self.tp_size}." 
) - def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str] = None): - from fastdeploy.platforms import current_platform + def _load_down_weight_into_param(self, expert_param, shard_dim: int, loaded_weight, shard_id: str): + assert expert_param.shape == loaded_weight.shape, ( + f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})" + ) + expert_param.copy_(loaded_weight, False) + + def _load_gate_up_weight_into_param(self, expert_param, shard_dim: int, loaded_weight, shard_id: str): + tensor_size = expert_param.shape[shard_dim] // 2 + if shard_id == "gate": + expert_param = slice_fn(expert_param, shard_dim, start=0, end=tensor_size) + elif shard_id == "up": + expert_param = slice_fn(expert_param, shard_dim, start=tensor_size, end=tensor_size * 2) + + assert expert_param.shape == loaded_weight.shape, ( + f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})" + ) + expert_param.copy_(loaded_weight, False) + + def _load_expert_weight_into_param(self, expert_param, shard_dim: int, loaded_weight, shard_id: str): + if shard_id == "down": + self._load_down_weight_into_param(expert_param, shard_dim, loaded_weight, shard_id) + elif shard_id in ["gate", "up"]: + self._load_gate_up_weight_into_param(expert_param, shard_dim, loaded_weight, shard_id) + + def load_weights_into_param(self, param, loaded_weight, expert_id: int, shard_id: Optional[str] = None): + assert shard_id in ["gate", "down", "up"] + SHARD_ID_TO_SHARDED_DIM = getattr(param, "SHARD_ID_TO_SHARDED_DIM") + expert_param = param[expert_id] + self._load_expert_weight_into_param( + expert_param=expert_param, + shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id], + loaded_weight=loaded_weight, + shard_id=shard_id, + ) + def weights_processor(self, param, loaded_weight, expert_id: int, shard_id: Optional[str] = None): if shard_id is None: # 1.gate up fused in disk return # 2.gate up splited in disk assert shard_id in ["gate", "down", "up"] + SHARD_ID_TO_SHARDED_DIM = getattr(param, "SHARD_ID_TO_SHARDED_DIM") expert_param = param[expert_id] - if current_platform.is_cuda(): - SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1} - else: - SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0} - self._load_expert_weight( + + yield from self._processed_expert_weight( expert_param=expert_param, shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id], loaded_weight=loaded_weight, shard_id=shard_id, ) - def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id): - tensor_size = expert_param.shape[shard_dim] // 2 - if shard_id == "gate": - expert_param = expert_param[..., :tensor_size] if shard_dim else expert_param[:tensor_size, ...] - elif shard_id == "up": - expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...] 
+ def _processed_gate_up_weight(self, expert_param, shard_dim: int, loaded_weight, shard_id: str): if self.tp_size > 1: size = loaded_weight.get_shape()[-1] @@ -199,15 +233,13 @@ def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id) loaded_weight = loaded_weight[..., shard_offset:shard_size] loaded_weight = get_tensor(loaded_weight) + # To ensure compatibility across backends, apply an extra transpose for GCU and XPU - if expert_param.shape != loaded_weight.shape: + if not current_platform.is_cuda(): loaded_weight = loaded_weight.transpose([1, 0]) - assert expert_param.shape == loaded_weight.shape, ( - f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})" - ) - expert_param.copy_(loaded_weight, False) + yield loaded_weight - def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id): + def _processed_down_weight(self, expert_param, shard_dim: int, loaded_weight, shard_id: str): if self.tp_size > 1: size = loaded_weight.get_shape()[shard_dim] block_size = size // self.tp_size @@ -215,45 +247,45 @@ def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id): shard_size = (self.tp_rank + 1) * block_size loaded_weight = loaded_weight[shard_offset:shard_size, ...] loaded_weight = get_tensor(loaded_weight) + # To ensure compatibility across backends, apply an extra transpose for GCU and XPU - if expert_param.shape != loaded_weight.shape: + if not current_platform.is_cuda(): loaded_weight = loaded_weight.transpose([1, 0]) - assert expert_param.shape == loaded_weight.shape, ( - f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})" - ) - expert_param.copy_(loaded_weight, False) + yield loaded_weight - def _load_expert_weight( + def _processed_expert_weight( self, expert_param, - shard_dim, + shard_dim: int, loaded_weight, - shard_id, + shard_id: str, ): if shard_id == "down": - self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id) + yield from self._processed_down_weight(expert_param, shard_dim, loaded_weight, shard_id) + elif shard_id in ["gate", "up"]: - self._load_gate_up_weight(expert_param, shard_dim, loaded_weight, shard_id) + + yield from self._processed_gate_up_weight(expert_param, shard_dim, loaded_weight, shard_id) @classmethod def make_expert_params_mapping( cls, - ckpt_gate_proj_name: str, ckpt_down_proj_name: str, - ckpt_up_proj_name: str, param_gate_up_proj_name: str, param_down_proj_name: str, num_experts: int, ckpt_expert_key_name: str = "experts", + ckpt_gate_proj_name: Optional[str] = None, + ckpt_up_proj_name: Optional[str] = None, ckpt_gate_up_proj_name: Optional[str] = None, ) -> list[tuple[str, str, int, str]]: - param_name_maping = [ - ("gate", ckpt_gate_proj_name), - ("down", ckpt_down_proj_name), - ("up", ckpt_up_proj_name), - ] - if ckpt_gate_up_proj_name: + param_name_maping = [("down", ckpt_down_proj_name)] + if ckpt_gate_up_proj_name is not None: param_name_maping.append((None, ckpt_gate_up_proj_name)) + elif ckpt_gate_proj_name is not None: + param_name_maping.append(("gate", ckpt_gate_proj_name)) + elif ckpt_up_proj_name is not None: + param_name_maping.append(("up", ckpt_up_proj_name)) return [ # (param_name, weight_name, expert_id, shard_id) diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index a221dca106..a7ef9156dc 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ 
b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -21,6 +21,7 @@ import paddle from paddle.nn.quant import weight_only_linear, weight_quantize +from fastdeploy.model_executor.models.utils import set_weight_attrs from fastdeploy.platforms import current_platform from ..moe import FusedMoE @@ -172,25 +173,52 @@ def create_weights(self, layer, **extra_weight_attrs): # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. weight_scale_shape = [layer.weight_shape[1]] - layer.weight_shape.reverse() if self.quant_config.name() == "wint4": layer.weight_shape[0] //= 2 layer.weight_dtype = "int8" - layer.weight = layer.create_parameter( + layer.quant_weight = layer.create_parameter( shape=layer.weight_shape, dtype=layer.weight_dtype, is_bias=False, default_initializer=paddle.nn.initializer.Constant(0), ) + inflight_quant = extra_weight_attrs.get("inflight_quant", None) + output_dim = extra_weight_attrs.get("output_dim") + output_dim = not output_dim + weight_loader = extra_weight_attrs.get("weight_loader") + load_weights_into_param = extra_weight_attrs.get("load_weights_into_param") + weight_attrs = { + "weight_loader": weight_loader, + "load_weights_into_param": load_weights_into_param, + "output_dim": output_dim, + } + if inflight_quant: + weight_attrs = {**weight_attrs, "quant_method": self.apply_weight_quantization} + set_weight_attrs( + layer.quant_weight, + { + **weight_attrs, + }, + ) + layer.weight_scale = layer.create_parameter( shape=weight_scale_shape, dtype=layer._dtype, is_bias=False, ) + set_weight_attrs( + layer.weight_scale, + { + "weight_loader": weight_loader, + "output_dim": output_dim, + "load_weights_into_param": load_weights_into_param, + }, + ) + @abstractmethod def process_loaded_weights(self, layer, weights) -> None: raise NotImplementedError @@ -198,7 +226,7 @@ def process_loaded_weights(self, layer, weights) -> None: def apply(self, layer, x): linear_out = weight_only_linear( x, - weight=layer.weight, + weight=layer.quant_weight, bias=layer.bias if layer.add_bias else None, weight_scale=layer.weight_scale, weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"), @@ -230,9 +258,17 @@ def process_prequanted_weights(self, layer, state_dict) -> None: """ quant_weight = get_tensor(state_dict.pop(layer.weight_key)) weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key)) - layer.weight.set_value(quant_weight) + layer.quant_weight.set_value(quant_weight) layer.weight_scale.set_value(weight_scale.astype(paddle.get_default_dtype())) + def apply_weight_quantization(self, unquantized_weight): + quanted_weight_tensor, weight_scale_tensor = weight_quantize( + unquantized_weight, + algo=self.quant_config.algo, + arch=self.quant_config.weight_only_linear_arch, + ) + return (quanted_weight_tensor, weight_scale_tensor.astype(paddle.get_default_dtype())) + def process_loaded_weights(self, layer, weight) -> None: quanted_weight_tensor, weight_scale_tensor = weight_quantize( @@ -241,5 +277,5 @@ def process_loaded_weights(self, layer, weight) -> None: arch=self.quant_config.weight_only_linear_arch, ) - layer.weight.set_value(quanted_weight_tensor) + layer.quant_weight.set_value(quanted_weight_tensor) layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype())) diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py index b5e1c2ad0e..4044fb5288 100644 --- a/fastdeploy/model_executor/layers/utils.py +++ b/fastdeploy/model_executor/layers/utils.py @@ 
-127,7 +127,6 @@ def get_tensor(input: Union[paddle.Tensor, np.ndarray, str], model_path=None) -> """ if "PySafeSlice" in str(type(input)): input = input.get() - if isinstance(input, paddle.Tensor): if input.place.is_cpu_place(): return input.to(paddle.device.get_device()) diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index 6aacb3a59c..5ac37b4c8c 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -34,6 +34,10 @@ ) from fastdeploy.platforms import current_platform +ORI_WEIGHT_NAME = "weight" +QUANT_WEIGHT_NAME = "quant_weight" +QUANT_SCALE_NAME = "weight_scale" + def measure_time(func): def wrapper(*args, **kwargs): diff --git a/fastdeploy/model_executor/model_loader/__init__.py b/fastdeploy/model_executor/model_loader/__init__.py index 4a9c3fec9c..96dd439233 100644 --- a/fastdeploy/model_executor/model_loader/__init__.py +++ b/fastdeploy/model_executor/model_loader/__init__.py @@ -20,6 +20,9 @@ from fastdeploy.model_executor.model_loader.default_loader_v1 import ( DefaultModelLoaderV1, ) +from fastdeploy.model_executor.model_loader.inflight_quant_loader import ( + InflightQuantModelLoader, +) def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: @@ -27,6 +30,8 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_choices == LoadChoices.DEFAULT_V1: return DefaultModelLoaderV1(load_config) + elif load_config.load_choices == LoadChoices.INFLIGHT_QUANT: + return InflightQuantModelLoader(load_config) return DefaultModelLoader(load_config) diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py index 4d79772e52..e14af077d3 100644 --- a/fastdeploy/model_executor/model_loader/default_loader_v1.py +++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py @@ -28,6 +28,7 @@ ) from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader from fastdeploy.model_executor.models.model_base import ModelRegistry +from fastdeploy.model_executor.models.utils import default_load_weights_into_param from fastdeploy.platforms import current_platform @@ -50,7 +51,16 @@ def clean_memory_fragments(self) -> None: def load_weights(self, model, fd_config: FDConfig) -> None: _, safetensor_files = get_all_safetensors(fd_config.model_config.model) weights_iterator = fast_weights_iterator(safetensor_files) - model.load_weights(weights_iterator) + params_dict = dict(model.named_parameters()) + processed_weights_iterator = model.processed_weights(weights_iterator, params_dict) + for loaded_weight_name, _, model_param, preprocessed_weight, shard_id, expert_id in processed_weights_iterator: + load_weights_into_param = getattr( + model_param, "load_weights_into_param", default_load_weights_into_param() + ) + if expert_id is not None: + load_weights_into_param(model_param, preprocessed_weight, expert_id, shard_id) + else: + load_weights_into_param(model_param, preprocessed_weight, shard_id) self.clean_memory_fragments() def load_model(self, fd_config: FDConfig) -> nn.Layer: diff --git a/fastdeploy/model_executor/model_loader/inflight_quant_loader.py b/fastdeploy/model_executor/model_loader/inflight_quant_loader.py new file mode 100644 index 0000000000..2c1d1e1909 --- /dev/null +++ b/fastdeploy/model_executor/model_loader/inflight_quant_loader.py @@ -0,0 +1,149 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import contextlib + +import paddle +from paddle import nn +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig, LoadConfig, ModelConfig +from fastdeploy.model_executor.load_weight_utils import ( + ORI_WEIGHT_NAME, + QUANT_SCALE_NAME, + QUANT_WEIGHT_NAME, + get_all_safetensors, + measure_time, + safetensors_weights_iterator, +) +from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader +from fastdeploy.model_executor.models.model_base import ModelRegistry +from fastdeploy.model_executor.models.utils import default_load_weights_into_param +from fastdeploy.model_executor.utils import switch_config_context +from fastdeploy.platforms import current_platform + + +class InflightQuantModelLoader(BaseModelLoader): + """ModelLoader that can load registered models""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + assert load_config.is_inflight_quant, "InflightQuantModelLoader can only be used for dynamic quantization." + logger.info("Load the model and weights using InflightModelLoader") + + def download_model(self, model_config: ModelConfig) -> None: + pass + + def clean_memory_fragments(self) -> None: + """clean_memory_fragments""" + if current_platform.is_cuda(): + paddle.device.cuda.empty_cache() + paddle.device.synchronize() + + def create_model(self, fd_config, architectures): + with paddle.LazyGuard(): + model_cls = ModelRegistry.get_class(architectures) + model = model_cls(fd_config) + model.eval() + return model + + def _get_quantized_weights_iterator(self, quantized_params_dict, fd_config: FDConfig): + """ + Construct an unquantized model, perform weight preprocessing (e.g., tensor parallel splitting) + on its parameters, and return an iterator of quantized weights. + + Args: + quantized_params_dict (dict): A dictionary containing quantized parameter names and their tensors. + fd_config (FDConfig): Configuration object with settings needed for weight processing. + + Returns: + Iterator: Yields tuples of (weight_name, quantized_weight_tensor) for each quantized weight. + """ + + # 1.Create an unquantized model + with switch_config_context(fd_config, "quant_config", None): + architectures = fd_config.model_config.architectures[0] + with paddle.LazyGuard(): + model_cls = ModelRegistry.get_class(architectures) + unquantized_model = model_cls(fd_config) + unquantized_model.eval() + # 2.Get weight iterator + _, safetensor_files = get_all_safetensors(fd_config.model_config.model) + weights_iterator = safetensors_weights_iterator(safetensor_files) + # 3.Get an iterator over the processed weights (e.g., tensor parallel splitting) . + unquantized_params_dict = dict(unquantized_model.named_parameters()) + processed_weights_iterator = unquantized_model.processed_weights(weights_iterator, unquantized_params_dict) + # 4.Quantize using the parameter that has a quantization method. 
+ for loaded_weight_name, model_param_name, _, preprocessed_weight, _, _ in processed_weights_iterator: + if model_param_name in quantized_params_dict: + yield loaded_weight_name, preprocessed_weight + else: + model_quant_weight_name = model_param_name.replace(ORI_WEIGHT_NAME, QUANT_WEIGHT_NAME) + model_param = quantized_params_dict[model_quant_weight_name] + quant_method = getattr(model_param, "quant_method", None) + assert quant_method is not None, f"{model_quant_weight_name} lacks an implementation of quant_method." + quant_weight_name = loaded_weight_name.replace(ORI_WEIGHT_NAME, QUANT_WEIGHT_NAME) + quant_res = quant_method(preprocessed_weight) + if len(quant_res) == 2: + quant_scale_name = loaded_weight_name.replace(ORI_WEIGHT_NAME, QUANT_SCALE_NAME) + quant_weight = quant_res[0] + weight_scale = quant_res[1] + yield quant_weight_name, quant_weight + yield quant_scale_name, weight_scale + else: + yield quant_weight_name, quant_weight + + @measure_time + def load_weights(self, model, fd_config: FDConfig) -> None: + quantized_params_dict = dict(model.named_parameters()) + quanted_weights_iterator = self._get_quantized_weights_iterator(quantized_params_dict, fd_config) + processed_weights_iterator = model.processed_weights( + quanted_weights_iterator, quantized_params_dict, is_processed=True + ) + for loaded_weight_name, _, model_param, preprocessed_weight, shard_id, expert_id in processed_weights_iterator: + load_weights_into_param = getattr( + model_param, "load_weights_into_param", default_load_weights_into_param() + ) + if expert_id is not None: + load_weights_into_param(model_param, preprocessed_weight, expert_id, shard_id) + else: + load_weights_into_param(model_param, preprocessed_weight, shard_id) + self.clean_memory_fragments() + + def load_model(self, fd_config: FDConfig) -> nn.Layer: + architectures = fd_config.model_config.architectures[0] + logger.info(f"Starting to load model {architectures}") + if fd_config.load_config.dynamic_load_weight: + # register rl model + import fastdeploy.rl # noqa + + architectures = architectures + "RL" + context = paddle.LazyGuard() + + else: + context = contextlib.nullcontext() + + with context: + model_cls = ModelRegistry.get_class(architectures) + model = model_cls(fd_config) + + model.eval() + + # RL model not need set_state_dict + if fd_config.load_config.dynamic_load_weight: + return model + self.load_weights(model, fd_config) + return model diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index 04988740df..862a5c22c2 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -246,7 +246,7 @@ def name(self): return "Qwen3ForCausalLM" @paddle.no_grad() - def load_weights(self, weights_iterator) -> None: + def processed_weights(self, weights_iterator, params_dict, is_processed=False) -> None: """ Load model parameters from a given weights_iterator object. @@ -254,7 +254,7 @@ def load_weights(self, weights_iterator) -> None: weights_iterator (Iterator): An iterator yielding (name, weight) pairs. 
""" - from fastdeploy.model_executor.models.utils import default_weight_loader + from fastdeploy.model_executor.models.utils import default_weights_processor stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -267,7 +267,6 @@ def load_weights(self, weights_iterator) -> None: ("lm_head.linear", "lm_head", None), ] - params_dict = dict(self.named_parameters()) for loaded_weight_name, loaded_weight in weights_iterator: for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in loaded_weight_name: @@ -276,15 +275,27 @@ def load_weights(self, weights_iterator) -> None: if model_param_name not in params_dict: continue param = params_dict[model_param_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) - weight_loader(param, loaded_weight, shard_id) + if is_processed: + yield loaded_weight_name, model_param_name, param, loaded_weight, shard_id, None + else: + weights_processor = getattr(param, "weights_processor", default_weights_processor(self.fd_config)) + yield from ( + (loaded_weight_name, model_param_name, param, preprocessed_weight, shard_id, None) + for preprocessed_weight in weights_processor(param, loaded_weight, shard_id) + ) break else: if loaded_weight_name not in params_dict: continue param = params_dict[loaded_weight_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) - weight_loader(param, loaded_weight) + if is_processed: + yield loaded_weight_name, loaded_weight_name, param, loaded_weight, None, None + else: + weights_processor = getattr(param, "weights_processor", default_weights_processor(self.fd_config)) + yield from ( + (loaded_weight_name, loaded_weight_name, param, preprocessed_weight, None, None) + for preprocessed_weight in weights_processor(param, loaded_weight, None) + ) if self.tie_word_embeddings: self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0])) @@ -301,7 +312,11 @@ def set_state_dict(self, state_dict): """ self.model.load_state_dict(state_dict) if self.tie_word_embeddings: - self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0])) + if hasattr(self.lm_head.linear, "weight"): + self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0])) + else: + # for ep + self.lm_head.linear.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0])) else: self.lm_head.load_state_dict(state_dict) diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index e9d85bbeb8..40e8f2b52e 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -326,7 +326,7 @@ def get_expert_mapping( ) @paddle.no_grad() - def load_weights(self, weights_iterator) -> None: + def processed_weights(self, weights_iterator, params_dict) -> None: """ Load model parameters from a given weights_iterator object. @@ -334,7 +334,7 @@ def load_weights(self, weights_iterator) -> None: weights_iterator (Iterator): An iterator yielding (name, weight) pairs. 
""" - from fastdeploy.model_executor.models.utils import default_weight_loader + from fastdeploy.model_executor.models.utils import default_weights_processor stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -358,8 +358,11 @@ def load_weights(self, weights_iterator) -> None: if model_param_name not in params_dict: continue param = params_dict[model_param_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) - weight_loader(param, loaded_weight, shard_id) + weights_processor = getattr(param, "weights_processor", default_weights_processor(self.fd_config)) + yield from ( + (loaded_weight_name, model_param_name, param, preprocessed_weight, shard_id, None) + for preprocessed_weight in weights_processor(param, loaded_weight, shard_id) + ) break else: for mapping in expert_params_mapping: @@ -370,15 +373,24 @@ def load_weights(self, weights_iterator) -> None: if model_param_name not in params_dict: continue param = params_dict[model_param_name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id) + weights_processor = param.weights_processor + yield from ( + (loaded_weight_name, model_param_name, param, preprocessed_weight, shard_id, expert_id) + for preprocessed_weight in weights_processor( + param, loaded_weight, shard_id=shard_id, expert_id=expert_id + ) + ) break else: if loaded_weight_name not in params_dict: continue param = params_dict[loaded_weight_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) - weight_loader(param, loaded_weight) + weights_processor = getattr(param, "weights_processor", default_weights_processor(self.fd_config)) + weights_processor(param, loaded_weight) + yield from ( + (loaded_weight_name, model_param_name, param, preprocessed_weight, None, None) + for preprocessed_weight in weights_processor(param, loaded_weight) + ) @paddle.no_grad() def set_state_dict(self, state_dict): diff --git a/fastdeploy/model_executor/models/utils.py b/fastdeploy/model_executor/models/utils.py index 1d2f21a824..148f78d8a8 100644 --- a/fastdeploy/model_executor/models/utils.py +++ b/fastdeploy/model_executor/models/utils.py @@ -54,41 +54,53 @@ def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]): setattr(param, key, value) -def default_weight_loader(fd_config: FDConfig) -> None: +def slice_fn(weight_or_paramter, output_dim, start, end, step=1): + if hasattr(weight_or_paramter, "get_shape"): + shape = weight_or_paramter.get_shape() + else: + shape = weight_or_paramter.shape + if len(shape) == 1: + weight_or_paramter = weight_or_paramter[start:end] + elif output_dim: + weight_or_paramter = weight_or_paramter[..., start:end] + else: + weight_or_paramter = weight_or_paramter[start:end, ...] 
+ return weight_or_paramter + + +def default_weights_processor(fd_config: FDConfig) -> None: """Default weight loader""" def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None): """fn""" - try: - output_dim = getattr(param, "output_dim", None) - # Tensor parallelism splits the weight along the output_dim - if output_dim is not None: - dim = -1 if output_dim else 0 - size = loaded_weight.get_shape()[dim] - block_size = size // fd_config.parallel_config.tensor_parallel_size - shard_offset = fd_config.parallel_config.tensor_parallel_rank * block_size - shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size - if output_dim: - loaded_weight = loaded_weight[..., shard_offset:shard_size] - else: - loaded_weight = loaded_weight[shard_offset:shard_size, ...] - - loaded_weight = get_tensor(loaded_weight) - # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation - if param.dtype != loaded_weight.dtype: - loaded_weight = loaded_weight.cast(param.dtype) - - if param.shape != loaded_weight.shape: - try: - param = param.reshape(loaded_weight.shape) - except ValueError as e: - raise ValueError( - f" Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape}). {e}" - ) - - param.copy_(loaded_weight, False) - except Exception: - raise + + output_dim = getattr(param, "output_dim", None) + # Tensor parallelism splits the weight along the output_dim + if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1: + dim = -1 if output_dim else 0 + size = loaded_weight.get_shape()[dim] + block_size = size // fd_config.parallel_config.tensor_parallel_size + shard_offset = fd_config.parallel_config.tensor_parallel_rank * block_size + shard_size = (fd_config.parallel_config.tensor_parallel_rank + 1) * block_size + loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size) + + loaded_weight = get_tensor(loaded_weight) + # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation + if param.dtype != loaded_weight.dtype: + loaded_weight = loaded_weight.cast(param.dtype) + yield loaded_weight + + return fn + + +def default_load_weights_into_param(): + def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None): + if param.dtype != loaded_weight.dtype: + loaded_weight = loaded_weight.cast(param.dtype) + assert param.shape == loaded_weight.shape, ( + f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" + ) + param.copy_(loaded_weight, False) return fn diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py new file mode 100644 index 0000000000..9731f886b0 --- /dev/null +++ b/fastdeploy/model_executor/utils.py @@ -0,0 +1,43 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from contextlib import contextmanager + +import paddle + + +@contextmanager +def device_guard(device="cpu", dev_id=0): + origin_device = paddle.device.get_device() + if device == "cpu": + paddle.set_device(device) + elif device in ["gpu", "xpu", "npu"]: + paddle.set_device("{}:{}".format(device, dev_id)) + try: + yield + finally: + paddle.set_device(origin_device) + + +@contextmanager +def switch_config_context(config_obj, config_attr_name, value): + """switch_config_context""" + origin_value = getattr(config_obj, config_attr_name) + setattr(config_obj, config_attr_name, value) + try: + yield + finally: + setattr(config_obj, config_attr_name, origin_value) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 8ddd4bc903..c148d0d512 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -656,10 +656,11 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: quant_config_name = None if quantization_config is not None and quantization_config.get("quantization", None) is None: raise ValueError("quantization_config should have a key named 'quantization' for specify quant config.") - + is_inflight_quant = False if quantization_config is not None: quant_config_name = quantization_config["quantization"] elif args.quantization != "None": + is_inflight_quant = True quantization_config = {} quant_config_name = args.quantization quantization_config["quantization"] = quant_config_name @@ -678,7 +679,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: else: quant_cls = get_quantization_config(quant_config_name) quant_config = quant_cls.from_config(quantization_config) - + load_config.is_inflight_quant = is_inflight_quant # Log quantization info logger.info("===========quantization_config==============") if quant_config is not None: From 19e0d4b7f6a0b98ca8b673e13c8a22e2bff3114f Mon Sep 17 00:00:00 2001 From: bukejiyu <395822456@qq.com> Date: Mon, 11 Aug 2025 06:37:11 +0000 Subject: [PATCH 2/7] update --- fastdeploy/model_executor/layers/linear.py | 1 - fastdeploy/model_executor/layers/moe/moe.py | 4 +++- .../model_loader/inflight_quant_loader.py | 4 ++-- fastdeploy/model_executor/models/qwen3.py | 2 +- fastdeploy/model_executor/models/qwen3moe.py | 2 +- fastdeploy/model_executor/models/utils.py | 15 +-------------- fastdeploy/model_executor/utils.py | 14 ++++++++++++++ 7 files changed, 22 insertions(+), 20 deletions(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index b721e322ca..bccce80df0 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -707,7 +707,6 @@ def __init__( inflight_quant=fd_config.quant_config and not skip_quant, ) if self.nranks > 0: - _set_var_distributed(self.weight, split_axis=0) if self.with_bias: # col parallel _set_var_distributed(self.bias, split_axis=0) diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 6f2dfb101f..bb60629cfd 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -22,7 +22,9 @@ from fastdeploy import envs from fastdeploy.model_executor.layers.utils import get_tensor -from fastdeploy.model_executor.models.utils import slice_fn +from fastdeploy.model_executor.utils import slice_fn + +# from fastdeploy.model_executor.models.utils import slice_fn from fastdeploy.platforms import current_platform from 
fastdeploy.worker.experts_manager import RedundantExpertManger diff --git a/fastdeploy/model_executor/model_loader/inflight_quant_loader.py b/fastdeploy/model_executor/model_loader/inflight_quant_loader.py index 2c1d1e1909..751d748363 100644 --- a/fastdeploy/model_executor/model_loader/inflight_quant_loader.py +++ b/fastdeploy/model_executor/model_loader/inflight_quant_loader.py @@ -25,9 +25,9 @@ ORI_WEIGHT_NAME, QUANT_SCALE_NAME, QUANT_WEIGHT_NAME, + fast_weights_iterator, get_all_safetensors, measure_time, - safetensors_weights_iterator, ) from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader from fastdeploy.model_executor.models.model_base import ModelRegistry @@ -82,7 +82,7 @@ def _get_quantized_weights_iterator(self, quantized_params_dict, fd_config: FDCo unquantized_model.eval() # 2.Get weight iterator _, safetensor_files = get_all_safetensors(fd_config.model_config.model) - weights_iterator = safetensors_weights_iterator(safetensor_files) + weights_iterator = fast_weights_iterator(safetensor_files) # 3.Get an iterator over the processed weights (e.g., tensor parallel splitting) . unquantized_params_dict = dict(unquantized_model.named_parameters()) processed_weights_iterator = unquantized_model.processed_weights(weights_iterator, unquantized_params_dict) diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index 862a5c22c2..55d04dcaa5 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -297,7 +297,7 @@ def processed_weights(self, weights_iterator, params_dict, is_processed=False) - for preprocessed_weight in weights_processor(param, loaded_weight, None) ) - if self.tie_word_embeddings: + if self.tie_word_embeddings and is_processed: self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0])) @paddle.no_grad() diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index 40e8f2b52e..d065ad05a7 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -36,7 +36,7 @@ RowParallelLinear, ) from fastdeploy.model_executor.layers.lm_head import ParallelLMHead -from fastdeploy.model_executor.layers.moe.moe import FusedMoE +from fastdeploy.model_executor.layers.moe import FusedMoE from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.models.model_base import ModelForCasualLM from fastdeploy.model_executor.models.qwen3 import Qwen3Attention diff --git a/fastdeploy/model_executor/models/utils.py b/fastdeploy/model_executor/models/utils.py index 148f78d8a8..5abb481ecc 100644 --- a/fastdeploy/model_executor/models/utils.py +++ b/fastdeploy/model_executor/models/utils.py @@ -42,6 +42,7 @@ from fastdeploy.config import FDConfig from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.model_executor.utils import slice_fn MAX_BSZ = 512 MAX_DRAFT_TOKENS = 6 @@ -54,20 +55,6 @@ def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]): setattr(param, key, value) -def slice_fn(weight_or_paramter, output_dim, start, end, step=1): - if hasattr(weight_or_paramter, "get_shape"): - shape = weight_or_paramter.get_shape() - else: - shape = weight_or_paramter.shape - if len(shape) == 1: - weight_or_paramter = weight_or_paramter[start:end] - elif output_dim: - weight_or_paramter = weight_or_paramter[..., start:end] - else: - weight_or_paramter = weight_or_paramter[start:end, 
...]
-    return weight_or_paramter
-
-
 def default_weights_processor(fd_config: FDConfig) -> None:
     """Default weight loader"""
diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py
index 9731f886b0..2b3c7c79e1 100644
--- a/fastdeploy/model_executor/utils.py
+++ b/fastdeploy/model_executor/utils.py
@@ -19,6 +19,20 @@
 import paddle


+def slice_fn(weight_or_paramter, output_dim, start, end, step=1):
+    if hasattr(weight_or_paramter, "get_shape"):
+        shape = weight_or_paramter.get_shape()
+    else:
+        shape = weight_or_paramter.shape
+    if len(shape) == 1:
+        weight_or_paramter = weight_or_paramter[start:end]
+    elif output_dim:
+        weight_or_paramter = weight_or_paramter[..., start:end]
+    else:
+        weight_or_paramter = weight_or_paramter[start:end, ...]
+    return weight_or_paramter
+
+
 @contextmanager
 def device_guard(device="cpu", dev_id=0):
     origin_device = paddle.device.get_device()

From 318c662861c13555bda672cc8e510e7bda464342 Mon Sep 17 00:00:00 2001
From: bukejiyu <395822456@qq.com>
Date: Mon, 11 Aug 2025 06:47:47 +0000
Subject: [PATCH 3/7] update

---
 fastdeploy/model_executor/model_loader/default_loader_v1.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py
index e14af077d3..38bcce46e2 100644
--- a/fastdeploy/model_executor/model_loader/default_loader_v1.py
+++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py
@@ -36,6 +36,10 @@ class DefaultModelLoaderV1(BaseModelLoader):
     """ModelLoader that can load registered models"""

     def __init__(self, load_config: LoadConfig):
+        assert not load_config.is_inflight_quant, (
+            "Dynamic quantization requires running with --load_choices 'inflight_quant' "
+            "or load_choices='inflight_quant'."
+ ) super().__init__(load_config) def download_model(self, model_config: ModelConfig) -> None: From 6ea05c841144749ea30278bde9d7cb04d5241ed5 Mon Sep 17 00:00:00 2001 From: bukejiyu <395822456@qq.com> Date: Mon, 11 Aug 2025 07:11:08 +0000 Subject: [PATCH 4/7] update --- .../model_loader/default_loader_v1.py | 24 ++++++++++------- .../model_loader/inflight_quant_loader.py | 26 ++++++++++++------- fastdeploy/model_executor/models/qwen3.py | 8 +++--- 3 files changed, 36 insertions(+), 22 deletions(-) diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py index 38bcce46e2..e0dc94e74e 100644 --- a/fastdeploy/model_executor/model_loader/default_loader_v1.py +++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py @@ -45,18 +45,14 @@ def __init__(self, load_config: LoadConfig): def download_model(self, model_config: ModelConfig) -> None: pass - def clean_memory_fragments(self) -> None: + def _clean_memory_fragments(self) -> None: """clean_memory_fragments""" if current_platform.is_cuda(): paddle.device.cuda.empty_cache() paddle.device.synchronize() - @measure_time - def load_weights(self, model, fd_config: FDConfig) -> None: - _, safetensor_files = get_all_safetensors(fd_config.model_config.model) - weights_iterator = fast_weights_iterator(safetensor_files) - params_dict = dict(model.named_parameters()) - processed_weights_iterator = model.processed_weights(weights_iterator, params_dict) + def _load_weights_into_param(self, model, processed_weights_iterator): + """_load_weights_into_param""" for loaded_weight_name, _, model_param, preprocessed_weight, shard_id, expert_id in processed_weights_iterator: load_weights_into_param = getattr( model_param, "load_weights_into_param", default_load_weights_into_param() @@ -65,7 +61,17 @@ def load_weights(self, model, fd_config: FDConfig) -> None: load_weights_into_param(model_param, preprocessed_weight, expert_id, shard_id) else: load_weights_into_param(model_param, preprocessed_weight, shard_id) - self.clean_memory_fragments() + if hasattr(model, "after_load_weights"): + model.after_load_weights() + + @measure_time + def _load_weights(self, model, fd_config: FDConfig) -> None: + _, safetensor_files = get_all_safetensors(fd_config.model_config.model) + weights_iterator = fast_weights_iterator(safetensor_files) + params_dict = dict(model.named_parameters()) + processed_weights_iterator = model.processed_weights(weights_iterator, params_dict) + self._load_weights_into_param(model, processed_weights_iterator) + self._clean_memory_fragments() def load_model(self, fd_config: FDConfig) -> nn.Layer: architectures = fd_config.model_config.architectures[0] @@ -90,5 +96,5 @@ def load_model(self, fd_config: FDConfig) -> nn.Layer: if fd_config.load_config.dynamic_load_weight: return model - self.load_weights(model, fd_config) + self._load_weights(model, fd_config) return model diff --git a/fastdeploy/model_executor/model_loader/inflight_quant_loader.py b/fastdeploy/model_executor/model_loader/inflight_quant_loader.py index 751d748363..9d9a1ad6ba 100644 --- a/fastdeploy/model_executor/model_loader/inflight_quant_loader.py +++ b/fastdeploy/model_executor/model_loader/inflight_quant_loader.py @@ -47,7 +47,7 @@ def __init__(self, load_config: LoadConfig): def download_model(self, model_config: ModelConfig) -> None: pass - def clean_memory_fragments(self) -> None: + def _clean_memory_fragments(self) -> None: """clean_memory_fragments""" if current_platform.is_cuda(): 
paddle.device.cuda.empty_cache() @@ -106,13 +106,8 @@ def _get_quantized_weights_iterator(self, quantized_params_dict, fd_config: FDCo else: yield quant_weight_name, quant_weight - @measure_time - def load_weights(self, model, fd_config: FDConfig) -> None: - quantized_params_dict = dict(model.named_parameters()) - quanted_weights_iterator = self._get_quantized_weights_iterator(quantized_params_dict, fd_config) - processed_weights_iterator = model.processed_weights( - quanted_weights_iterator, quantized_params_dict, is_processed=True - ) + def _load_weights_into_param(self, model, processed_weights_iterator): + """_load_weights_into_param""" for loaded_weight_name, _, model_param, preprocessed_weight, shard_id, expert_id in processed_weights_iterator: load_weights_into_param = getattr( model_param, "load_weights_into_param", default_load_weights_into_param() @@ -121,7 +116,18 @@ def load_weights(self, model, fd_config: FDConfig) -> None: load_weights_into_param(model_param, preprocessed_weight, expert_id, shard_id) else: load_weights_into_param(model_param, preprocessed_weight, shard_id) - self.clean_memory_fragments() + if hasattr(model, "after_load_weights"): + model.after_load_weights() + + @measure_time + def _load_weights(self, model, fd_config: FDConfig) -> None: + quantized_params_dict = dict(model.named_parameters()) + quanted_weights_iterator = self._get_quantized_weights_iterator(quantized_params_dict, fd_config) + processed_weights_iterator = model.processed_weights( + quanted_weights_iterator, quantized_params_dict, is_processed=True + ) + self._load_weights_into_param(model, processed_weights_iterator) + self._clean_memory_fragments() def load_model(self, fd_config: FDConfig) -> nn.Layer: architectures = fd_config.model_config.architectures[0] @@ -145,5 +151,5 @@ def load_model(self, fd_config: FDConfig) -> nn.Layer: # RL model not need set_state_dict if fd_config.load_config.dynamic_load_weight: return model - self.load_weights(model, fd_config) + self._load_weights(model, fd_config) return model diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index 55d04dcaa5..0e44707365 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -245,6 +245,11 @@ def name(self): """ """ return "Qwen3ForCausalLM" + @paddle.no_grad() + def after_load_weights(self) -> None: + if self.tie_word_embeddings: + self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0])) + @paddle.no_grad() def processed_weights(self, weights_iterator, params_dict, is_processed=False) -> None: """ @@ -297,9 +302,6 @@ def processed_weights(self, weights_iterator, params_dict, is_processed=False) - for preprocessed_weight in weights_processor(param, loaded_weight, None) ) - if self.tie_word_embeddings and is_processed: - self.lm_head.linear.weight.set_value(self.model.embed_tokens.embeddings.weight.transpose([1, 0])) - @paddle.no_grad() def set_state_dict(self, state_dict): """ From 098091b62931c4e084571d4e4e6527fb9ad37a30 Mon Sep 17 00:00:00 2001 From: bukejiyu <395822456@qq.com> Date: Mon, 11 Aug 2025 10:52:53 +0000 Subject: [PATCH 5/7] support qwen3_moe --- .../model_executor/layers/embeddings.py | 2 +- fastdeploy/model_executor/layers/linear.py | 3 +- fastdeploy/model_executor/layers/lm_head.py | 2 +- .../layers/moe/fused_moe_backend_base.py | 2 +- .../layers/moe/fused_moe_cutlass_backend.py | 32 ++++++-- fastdeploy/model_executor/layers/moe/moe.py | 81 ++++++------------- 
.../layers/quantization/weight_only.py | 10 +-- fastdeploy/model_executor/models/qwen3.py | 13 ++- fastdeploy/model_executor/models/qwen3moe.py | 49 +++++++---- fastdeploy/model_executor/models/utils.py | 9 +-- fastdeploy/model_executor/utils.py | 8 ++ 11 files changed, 114 insertions(+), 97 deletions(-) diff --git a/fastdeploy/model_executor/layers/embeddings.py b/fastdeploy/model_executor/layers/embeddings.py index ba68c9ed00..5c26437ded 100644 --- a/fastdeploy/model_executor/layers/embeddings.py +++ b/fastdeploy/model_executor/layers/embeddings.py @@ -22,7 +22,7 @@ from paddle.distributed import fleet from fastdeploy.config import FDConfig -from fastdeploy.model_executor.models.utils import set_weight_attrs +from fastdeploy.model_executor.utils import set_weight_attrs from .utils import get_tensor diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index bccce80df0..3d99044b81 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -25,9 +25,8 @@ from fastdeploy.model_executor.models.utils import ( default_load_weights_into_param, default_weights_processor, - set_weight_attrs, - slice_fn, ) +from fastdeploy.model_executor.utils import set_weight_attrs, slice_fn from fastdeploy.platforms import current_platform from .utils import _set_var_distributed, divide, get_tensor diff --git a/fastdeploy/model_executor/layers/lm_head.py b/fastdeploy/model_executor/layers/lm_head.py index 6a76a72f75..fc3851f57a 100644 --- a/fastdeploy/model_executor/layers/lm_head.py +++ b/fastdeploy/model_executor/layers/lm_head.py @@ -22,7 +22,7 @@ from paddle.distributed import fleet from fastdeploy.config import FDConfig -from fastdeploy.model_executor.models.utils import set_weight_attrs +from fastdeploy.model_executor.utils import set_weight_attrs from .utils import get_tensor diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index 7a8548f2e0..5b3b1c6a4c 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -19,7 +19,7 @@ import paddle from paddle import nn -from fastdeploy.model_executor.layers.utils import set_weight_attrs +from fastdeploy.model_executor.utils import set_weight_attrs from fastdeploy.platforms import current_platform from ..quantization.quant_base import QuantMethodBase diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 2be90f8f99..5b3f83e598 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -21,6 +21,7 @@ import fastdeploy from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce +from fastdeploy.model_executor.utils import set_weight_attrs from fastdeploy.platforms import current_platform from ..utils import get_tensor @@ -93,8 +94,8 @@ def compute_ffn( return fastdeploy.model_executor.ops.iluvatar.moe_expert_ffn( permute_input, token_nums_per_expert, - layer.up_gate_proj_weight, - layer.down_proj_weight, + getattr(layer, self.added_weight_attrs[0]), + getattr(layer, self.added_weight_attrs[1]), None, (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), @@ 
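Note: compute_ffn no longer references layer.up_gate_proj_weight directly; the weight-only cutlass backend registers its packed parameters under different names, so the shared FFN path resolves them through added_weight_attrs. Minimal illustration, a fragment that sits inside the backend class, with the attribute names as defined in create_weights below:

    # self.added_weight_attrs = ["up_gate_proj_quant_weight", "down_proj_quant_weight"]
    up_gate_proj_w = getattr(layer, self.added_weight_attrs[0])   # packed expert up_gate weight
    down_proj_w = getattr(layer, self.added_weight_attrs[1])      # packed expert down weight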
-106,8 +107,8 @@ def compute_ffn( return fastdeploy.model_executor.ops.gpu.moe_expert_ffn( permute_input, token_nums_per_expert, - layer.up_gate_proj_weight, - layer.down_proj_weight, + getattr(layer, self.added_weight_attrs[0]), + getattr(layer, self.added_weight_attrs[1]), None, (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), @@ -627,6 +628,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): self.default_dtype = layer._helper.get_default_dtype() self.weight_dtype = "int8" + self.added_weight_attrs = ["up_gate_proj_quant_weight", "down_proj_quant_weight"] up_gate_proj_weight_name = self.added_weight_attrs[0] down_proj_weight_name = self.added_weight_attrs[1] if self.moe_quant_type == "weight_only_int4": @@ -653,6 +655,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): layer.hidden_size, layer.moe_intermediate_size, ] + # layer.up_gate_proj_quant_weight setattr( layer, up_gate_proj_weight_name, @@ -662,6 +665,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): default_initializer=paddle.nn.initializer.Constant(0), ), ) + # layer.down_proj_quant_weight setattr( layer, down_proj_weight_name, @@ -671,7 +675,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): default_initializer=paddle.nn.initializer.Constant(0), ), ) - # weight_scale + # layer.up_gate_proj_weight_scale setattr( layer, self.added_scale_attrs[0], @@ -681,6 +685,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): default_initializer=paddle.nn.initializer.Constant(0), ), ) + # layer.down_proj_weight_scale setattr( layer, self.added_scale_attrs[1], @@ -691,6 +696,23 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): ), ) + inflight_quant = extra_weight_attrs.get("inflight_quant", None) + moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}} + if inflight_quant is True: + moe_extra_weight_attrs = {**moe_extra_weight_attrs, "quant_method": self.apply_weight_quantization} + set_weight_attrs(layer.up_gate_proj_quant_weight, moe_extra_weight_attrs) + set_weight_attrs(layer.down_proj_quant_weight, moe_extra_weight_attrs) + scale_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "up": 0}} + set_weight_attrs(layer.up_gate_proj_weight_scale, scale_extra_weight_attrs) + set_weight_attrs(layer.down_proj_weight_scale, scale_extra_weight_attrs) + + def apply_weight_quantization(self, unquantized_weight): + quanted_weight_tensor, weight_scale_tensor = weight_quantize( + unquantized_weight, + algo=self.moe_quant_type, + ) + return (quanted_weight_tensor, weight_scale_tensor) + def process_loaded_weights(self, layer: nn.Layer, state_dict): """ Paddle cutlass load weight process. 
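Note: when inflight_quant is requested, apply_weight_quantization is attached to the packed expert parameters under the "quant_method" attribute. The consumer of that attribute is the in-flight quantization path, which is not part of this hunk; the sketch below only illustrates the intended call shape, and bf16_shard is a made-up variable name.

    quant_fn = getattr(param, "quant_method", None)   # set via set_weight_attrs above
    if quant_fn is not None:
        # internally wraps paddle.nn.quant.weight_quantize(shard, algo=self.moe_quant_type)
        quant_weight, weight_scale = quant_fn(bf16_shard)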
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index bb60629cfd..4236942b6c 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -23,8 +23,6 @@ from fastdeploy import envs from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import slice_fn - -# from fastdeploy.model_executor.models.utils import slice_fn from fastdeploy.platforms import current_platform from fastdeploy.worker.experts_manager import RedundantExpertManger @@ -141,6 +139,11 @@ def __init__( ) self.quant_method.init_ep(self) + inflight_quant = moe_quant_config and ( + not self.fd_config.model_config.is_quantized + or self.fd_config.model_config.is_quantized + and not getattr(self.fd_config.quant_config, "is_permuted", True) + ) if fd_config.load_config.dynamic_load_weight: # It's for RL to build model self.init_moe_weights() @@ -158,6 +161,7 @@ def __init__( self, weights_processor=self.weights_processor, load_weights_into_param=self.load_weights_into_param, + inflight_quant=inflight_quant, ) else: # w_fp16 a_fp16 @@ -165,6 +169,7 @@ def __init__( self, weights_processor=self.weights_processor, load_weights_into_param=self.load_weights_into_param, + inflight_quant=inflight_quant, ) logger.info( @@ -202,73 +207,37 @@ def load_weights_into_param(self, param, loaded_weight, expert_id: int, shard_id assert shard_id in ["gate", "down", "up"] SHARD_ID_TO_SHARDED_DIM = getattr(param, "SHARD_ID_TO_SHARDED_DIM") expert_param = param[expert_id] - self._load_expert_weight_into_param( - expert_param=expert_param, - shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id], - loaded_weight=loaded_weight, - shard_id=shard_id, - ) + if shard_id in SHARD_ID_TO_SHARDED_DIM: + self._load_expert_weight_into_param( + expert_param=expert_param, + shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id], + loaded_weight=loaded_weight, + shard_id=shard_id, + ) + else: + # for down_scale + expert_param.copy_(loaded_weight, False) - def weights_processor(self, param, loaded_weight, expert_id: int, shard_id: Optional[str] = None): + def weights_processor(self, param, loaded_weight, shard_id: Optional[str] = None): if shard_id is None: # 1.gate up fused in disk return # 2.gate up splited in disk assert shard_id in ["gate", "down", "up"] SHARD_ID_TO_SHARDED_DIM = getattr(param, "SHARD_ID_TO_SHARDED_DIM") - expert_param = param[expert_id] - - yield from self._processed_expert_weight( - expert_param=expert_param, - shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id], - loaded_weight=loaded_weight, - shard_id=shard_id, - ) - - def _processed_gate_up_weight(self, expert_param, shard_dim: int, loaded_weight, shard_id: str): - - if self.tp_size > 1: - size = loaded_weight.get_shape()[-1] - block_size = size // self.tp_size - shard_offset = self.tp_rank * block_size - shard_size = (self.tp_rank + 1) * block_size - loaded_weight = loaded_weight[..., shard_offset:shard_size] - - loaded_weight = get_tensor(loaded_weight) - - # To ensure compatibility across backends, apply an extra transpose for GCU and XPU - if not current_platform.is_cuda(): - loaded_weight = loaded_weight.transpose([1, 0]) - yield loaded_weight - - def _processed_down_weight(self, expert_param, shard_dim: int, loaded_weight, shard_id: str): - if self.tp_size > 1: - size = loaded_weight.get_shape()[shard_dim] + if shard_id in SHARD_ID_TO_SHARDED_DIM and self.tp_size > 1: + shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + dim = -1 if shard_dim else 0 + size = 
loaded_weight.get_shape()[dim] block_size = size // self.tp_size shard_offset = self.tp_rank * block_size shard_size = (self.tp_rank + 1) * block_size - loaded_weight = loaded_weight[shard_offset:shard_size, ...] + loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size) loaded_weight = get_tensor(loaded_weight) - - # To ensure compatibility across backends, apply an extra transpose for GCU and XPU if not current_platform.is_cuda(): loaded_weight = loaded_weight.transpose([1, 0]) yield loaded_weight - def _processed_expert_weight( - self, - expert_param, - shard_dim: int, - loaded_weight, - shard_id: str, - ): - if shard_id == "down": - yield from self._processed_down_weight(expert_param, shard_dim, loaded_weight, shard_id) - - elif shard_id in ["gate", "up"]: - - yield from self._processed_gate_up_weight(expert_param, shard_dim, loaded_weight, shard_id) - @classmethod def make_expert_params_mapping( cls, @@ -284,9 +253,9 @@ def make_expert_params_mapping( param_name_maping = [("down", ckpt_down_proj_name)] if ckpt_gate_up_proj_name is not None: param_name_maping.append((None, ckpt_gate_up_proj_name)) - elif ckpt_gate_proj_name is not None: + if ckpt_gate_proj_name is not None: param_name_maping.append(("gate", ckpt_gate_proj_name)) - elif ckpt_up_proj_name is not None: + if ckpt_up_proj_name is not None: param_name_maping.append(("up", ckpt_up_proj_name)) return [ diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index a7ef9156dc..6f6fd562b4 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -21,7 +21,7 @@ import paddle from paddle.nn.quant import weight_only_linear, weight_quantize -from fastdeploy.model_executor.models.utils import set_weight_attrs +from fastdeploy.model_executor.utils import set_weight_attrs from fastdeploy.platforms import current_platform from ..moe import FusedMoE @@ -188,14 +188,14 @@ def create_weights(self, layer, **extra_weight_attrs): inflight_quant = extra_weight_attrs.get("inflight_quant", None) output_dim = extra_weight_attrs.get("output_dim") output_dim = not output_dim - weight_loader = extra_weight_attrs.get("weight_loader") + weights_processor = extra_weight_attrs.get("weights_processor") load_weights_into_param = extra_weight_attrs.get("load_weights_into_param") weight_attrs = { - "weight_loader": weight_loader, + "weights_processor": weights_processor, "load_weights_into_param": load_weights_into_param, "output_dim": output_dim, } - if inflight_quant: + if inflight_quant is True: weight_attrs = {**weight_attrs, "quant_method": self.apply_weight_quantization} set_weight_attrs( layer.quant_weight, @@ -213,7 +213,7 @@ def create_weights(self, layer, **extra_weight_attrs): set_weight_attrs( layer.weight_scale, { - "weight_loader": weight_loader, + "weights_processor": weights_processor, "output_dim": output_dim, "load_weights_into_param": load_weights_into_param, }, diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index 0e44707365..7ecba76190 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -253,10 +253,19 @@ def after_load_weights(self) -> None: @paddle.no_grad() def processed_weights(self, weights_iterator, params_dict, is_processed=False) -> None: """ - Load model parameters from a given weights_iterator object. 
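Note on the make_expert_params_mapping change above: checkpoints such as Qwen3-MoE store gate_proj and up_proj as separate tensors, so both mappings have to be emitted; with the previous elif chain the "up" entry was silently dropped whenever a gate name was supplied. The fixed shape of the mapping list, with the reason in comments:

    param_name_maping = [("down", ckpt_down_proj_name)]
    if ckpt_gate_up_proj_name is not None:      # fused gate_up checkpoints
        param_name_maping.append((None, ckpt_gate_up_proj_name))
    if ckpt_gate_proj_name is not None:         # split checkpoints need both of
        param_name_maping.append(("gate", ckpt_gate_proj_name))
    if ckpt_up_proj_name is not None:           # these, hence if instead of elif
        param_name_maping.append(("up", ckpt_up_proj_name))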
+ process weights from a given weights_iterator object. Args: - weights_iterator (Iterator): An iterator yielding (name, weight) pairs. + weights_iterator: iterator yielding weight tuples + params_dict: dict of model parameters by name + is_processed: whether weights are already preprocessed (default False) + weights_iterator (Iterator): + Yield the following items: + loaded_weight_name: name of the weight as loaded from storage + model_param_name: parameter name used in the model + param: model parameter object + preprocessed_weight: weight after preprocessing (e.g., slicing or normalization) + shard_id: ID for tensor parallel shard or partition """ from fastdeploy.model_executor.models.utils import default_weights_processor diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index d065ad05a7..eb8ee55572 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -326,12 +326,22 @@ def get_expert_mapping( ) @paddle.no_grad() - def processed_weights(self, weights_iterator, params_dict) -> None: + def processed_weights(self, weights_iterator, params_dict, is_processed=False): """ - Load model parameters from a given weights_iterator object. + process weights from a given weights_iterator object. Args: - weights_iterator (Iterator): An iterator yielding (name, weight) pairs. + weights_iterator: iterator yielding weight tuples + params_dict: dict of model parameters by name + is_processed: whether weights are already preprocessed (default False) + weights_iterator (Iterator): + Yield the following items: + loaded_weight_name: name of the weight as loaded from storage + model_param_name: parameter name used in the model + param: model parameter object + preprocessed_weight: weight after preprocessing (e.g., slicing or normalization) + shard_id: ID for tensor parallel shard or partition + expert_id: expert index in MoE models;may be None if not applicable. 
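            Example of one yielded item (all names below are made up for
            illustration, they are not taken from a real checkpoint):

                ("model.layers.0.mlp.experts.7.down_proj.weight",  # loaded_weight_name
                 "model.layers.0.mlp.experts.down_proj_weight",    # model_param_name
                 param,                                            # parameter object
                 preprocessed_weight,                              # sliced / quantized shard
                 "down",                                           # shard_id
                 7)                                                # expert_id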
""" from fastdeploy.model_executor.models.utils import default_weights_processor @@ -359,10 +369,13 @@ def processed_weights(self, weights_iterator, params_dict) -> None: continue param = params_dict[model_param_name] weights_processor = getattr(param, "weights_processor", default_weights_processor(self.fd_config)) - yield from ( - (loaded_weight_name, model_param_name, param, preprocessed_weight, shard_id, None) - for preprocessed_weight in weights_processor(param, loaded_weight, shard_id) - ) + if is_processed: + yield loaded_weight_name, model_param_name, param, loaded_weight, shard_id, None + else: + yield from ( + (loaded_weight_name, model_param_name, param, preprocessed_weight, shard_id, None) + for preprocessed_weight in weights_processor(param, loaded_weight, shard_id) + ) break else: for mapping in expert_params_mapping: @@ -374,12 +387,13 @@ def processed_weights(self, weights_iterator, params_dict) -> None: continue param = params_dict[model_param_name] weights_processor = param.weights_processor - yield from ( - (loaded_weight_name, model_param_name, param, preprocessed_weight, shard_id, expert_id) - for preprocessed_weight in weights_processor( - param, loaded_weight, shard_id=shard_id, expert_id=expert_id + if is_processed: + yield loaded_weight_name, model_param_name, param, loaded_weight, shard_id, expert_id + else: + yield from ( + (loaded_weight_name, model_param_name, param, preprocessed_weight, shard_id, expert_id) + for preprocessed_weight in weights_processor(param, loaded_weight, shard_id=shard_id) ) - ) break else: if loaded_weight_name not in params_dict: @@ -387,10 +401,13 @@ def processed_weights(self, weights_iterator, params_dict) -> None: param = params_dict[loaded_weight_name] weights_processor = getattr(param, "weights_processor", default_weights_processor(self.fd_config)) weights_processor(param, loaded_weight) - yield from ( - (loaded_weight_name, model_param_name, param, preprocessed_weight, None, None) - for preprocessed_weight in weights_processor(param, loaded_weight) - ) + if is_processed: + yield loaded_weight_name, loaded_weight_name, param, loaded_weight, None, None + else: + yield from ( + (loaded_weight_name, loaded_weight_name, param, preprocessed_weight, None, None) + for preprocessed_weight in weights_processor(param, loaded_weight) + ) @paddle.no_grad() def set_state_dict(self, state_dict): diff --git a/fastdeploy/model_executor/models/utils.py b/fastdeploy/model_executor/models/utils.py index 5abb481ecc..a8a1594934 100644 --- a/fastdeploy/model_executor/models/utils.py +++ b/fastdeploy/model_executor/models/utils.py @@ -24,7 +24,7 @@ import re import struct from functools import partial -from typing import Any, NamedTuple, Optional, Union +from typing import NamedTuple, Optional, Union import numpy as np import paddle @@ -48,13 +48,6 @@ MAX_DRAFT_TOKENS = 6 -def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]): - if param_attr_map is None: - return - for key, value in param_attr_map.items(): - setattr(param, key, value) - - def default_weights_processor(fd_config: FDConfig) -> None: """Default weight loader""" diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index 2b3c7c79e1..db2687bc40 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -15,10 +15,18 @@ """ from contextlib import contextmanager +from typing import Any, Optional import paddle +def set_weight_attrs(param, param_attr_map: Optional[dict[str, Any]]): + if param_attr_map is None: + 
return + for key, value in param_attr_map.items(): + setattr(param, key, value) + + def slice_fn(weight_or_paramter, output_dim, start, end, step=1): if hasattr(weight_or_paramter, "get_shape"): shape = weight_or_paramter.get_shape() From 69a64e1647ea9bed26cb3f8f239ef1c3465f81b9 Mon Sep 17 00:00:00 2001 From: bukejiyu <395822456@qq.com> Date: Mon, 11 Aug 2025 11:13:11 +0000 Subject: [PATCH 6/7] update --- .../layers/backends/gcu/quantization/weight_only.py | 8 ++++---- .../layers/backends/xpu/quantization/weight_only.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/fastdeploy/model_executor/layers/backends/gcu/quantization/weight_only.py b/fastdeploy/model_executor/layers/backends/gcu/quantization/weight_only.py index 9aebf64ce0..fc4289eb3b 100644 --- a/fastdeploy/model_executor/layers/backends/gcu/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/backends/gcu/quantization/weight_only.py @@ -46,7 +46,7 @@ def create_weights(self, layer, **extra_weight_attrs): layer.weight_shape[0] //= 2 layer.weight_dtype = "int8" - layer.weight = layer.create_parameter( + layer.quant_weight = layer.create_parameter( shape=layer.weight_shape, dtype=layer.weight_dtype, is_bias=False, @@ -69,7 +69,7 @@ def process_prequanted_weights(self, layer, state_dict) -> None: """ quant_weight = get_tensor(state_dict.pop(layer.weight_key)) weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key)) - layer.weight.set_value(quant_weight) + layer.quant_weight.set_value(quant_weight) layer.weight_scale.set_value(weight_scale.astype(paddle.get_default_dtype())) def process_loaded_weights(self, layer, weight) -> None: @@ -79,14 +79,14 @@ def process_loaded_weights(self, layer, weight) -> None: self.group_size, # group_size ) - layer.weight.set_value(quanted_weight_tensor) + layer.quant_weight.set_value(quanted_weight_tensor) layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype())) @paddle.no_grad() def apply(self, layer, x): linear_out = linear_quant( lhs=x, - rhs=layer.weight, + rhs=layer.quant_weight, scale=layer.weight_scale, bias=None, group_size=self.group_size, diff --git a/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py b/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py index b010f958f0..434c3e162c 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py @@ -45,7 +45,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs) -> None: if self.quant_config.name() == "weight_only_int4": layer.weight_shape[0] //= 2 layer.weight_dtype = "int8" - layer.weight = layer.create_parameter( + layer.quant_weight = layer.create_parameter( shape=layer.weight_shape, dtype=layer.weight_dtype, is_bias=False, @@ -62,5 +62,5 @@ def process_loaded_weights(self, layer: nn.Layer, weight: paddle.Tensor) -> None loaded_weights using xpu special quantization """ quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(weight, self.quant_config.algo, -1, -1) - layer.weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0])) + layer.quant_weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0])) layer.weight_scale.set_value(weight_scale_tensor) From bdcf6f3bb7925bae28e3bc5a1610b9b79d8e62bc Mon Sep 17 00:00:00 2001 From: bukejiyu <395822456@qq.com> Date: Mon, 11 Aug 2025 12:13:31 +0000 Subject: [PATCH 7/7] fix wint4 --- fastdeploy/model_executor/layers/linear.py | 16 
+++++++++------- .../layers/quantization/weight_only.py | 1 - 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 3d99044b81..8ad53fdb00 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -420,10 +420,11 @@ def __init__( def load_weights_into_param(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): assert loaded_shard_id in ["gate", "up"] output_dim = getattr(param, "output_dim", None) + tensor_size = param.shape[output_dim] // 2 if loaded_shard_id == "gate": - param = slice_fn(param, output_dim, start=0, end=self.output_size // 2) + param = slice_fn(param, output_dim, start=0, end=tensor_size) elif loaded_shard_id == "up": - param = slice_fn(param, output_dim, start=self.output_size // 2, end=self.output_size) + param = slice_fn(param, output_dim, start=tensor_size, end=tensor_size * 2) assert param.shape == loaded_weight.shape, ( f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" ) @@ -533,23 +534,24 @@ def weights_processor(self, param, loaded_weight, loaded_shard_id: Optional[str] def load_weights_into_param(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): assert loaded_shard_id in ["q", "k", "v"] output_dim = getattr(param, "output_dim", None) + head_dim = param.shape[output_dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) if loaded_shard_id == "q": - param = slice_fn(param, output_dim, 0, self.num_heads_per_rank * self.head_dim) + param = slice_fn(param, output_dim, 0, self.num_heads_per_rank * head_dim) elif loaded_shard_id == "k": param = slice_fn( param, output_dim, - self.num_heads_per_rank * self.head_dim, - (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim, + self.num_heads_per_rank * head_dim, + (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim, ) elif loaded_shard_id == "v": param = slice_fn( param, output_dim, - (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim, - (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * self.head_dim, + (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim, + (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * head_dim, ) assert param.shape == loaded_weight.shape, ( diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index 6f6fd562b4..79abd71c1e 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -177,7 +177,6 @@ def create_weights(self, layer, **extra_weight_attrs): if self.quant_config.name() == "wint4": layer.weight_shape[0] //= 2 layer.weight_dtype = "int8" - layer.quant_weight = layer.create_parameter( shape=layer.weight_shape, dtype=layer.weight_dtype,
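Note on the wint4 fix: the weight-only quant method creates the packed parameter with layer.weight_shape[0] //= 2 when the algorithm is wint4 (two 4-bit values per int8 element, see the weight_only.py hunks), so the configured self.output_size and self.head_dim can disagree with the shape of the parameter actually being written. Deriving the split points from param.shape, as the hunks above do, keeps the gate/up and q/k/v offsets valid for both the bf16 and the packed layouts. A worked example with assumed sizes:

    # MergedColumnParallelLinear under wint4, hypothetical per-rank gate/up width 8192:
    # the packed parameter axis is 4096, so
    tensor_size = 4096 // 2                 # gate -> [0:2048), up -> [2048:4096)
    # offsets computed from self.output_size // 2 (= 4096) would no longer line up.

    # QKVParallelLinear, hypothetical 16 q heads and 2 kv heads per rank, packed width 1280:
    head_dim = 1280 // (16 + 2 * 2)         # 64, i.e. the packed half of a 128-wide head
    # q -> [0:1024), k -> [1024:1152), v -> [1152:1280)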