[feat]support inflight quant #3277

Closed: wants to merge 7 commits
Changes from all commits:
2 changes: 2 additions & 0 deletions fastdeploy/config.py
@@ -666,6 +666,7 @@ class LoadChoices(str, Enum):
DEFAULT = "default"
# only support qwen3-bf16 now
DEFAULT_V1 = "default_v1"
INFLIGHT_QUANT = "inflight_quant"


class LoadConfig:
@@ -685,6 +686,7 @@ def __init__(
args,
):
self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
self.is_inflight_quant = False
self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
self.dynamic_load_weight: bool = False
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot"]] = None
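
For reference, a minimal usage sketch of the new enum member (illustrative only, not part of this diff; the actual selection plumbing through FastDeploy's CLI and config may differ):

```python
# LoadChoices is a str-valued Enum (see the hunk above), so the new member can
# be resolved from its string value as well as by name.
from fastdeploy.config import LoadChoices

choice = LoadChoices("inflight_quant")
assert choice is LoadChoices.INFLIGHT_QUANT
assert choice.value == "inflight_quant"
```
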
@@ -46,7 +46,7 @@ def create_weights(self, layer, **extra_weight_attrs):
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"

layer.weight = layer.create_parameter(
layer.quant_weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
@@ -69,7 +69,7 @@ def process_prequanted_weights(self, layer, state_dict) -> None:
"""
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
layer.weight.set_value(quant_weight)
layer.quant_weight.set_value(quant_weight)
layer.weight_scale.set_value(weight_scale.astype(paddle.get_default_dtype()))

def process_loaded_weights(self, layer, weight) -> None:
@@ -79,14 +79,14 @@ def process_loaded_weights(self, layer, weight) -> None:
self.group_size, # group_size
)

layer.weight.set_value(quanted_weight_tensor)
layer.quant_weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))

@paddle.no_grad()
def apply(self, layer, x):
linear_out = linear_quant(
lhs=x,
rhs=layer.weight,
rhs=layer.quant_weight,
scale=layer.weight_scale,
bias=None,
group_size=self.group_size,
@@ -45,7 +45,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs) -> None:
if self.quant_config.name() == "weight_only_int4":
layer.weight_shape[0] //= 2
layer.weight_dtype = "int8"
layer.weight = layer.create_parameter(
layer.quant_weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
@@ -62,5 +62,5 @@ def process_loaded_weights(self, layer: nn.Layer, weight: paddle.Tensor) -> None:
loaded_weights using xpu special quantization
"""
quanted_weight_tensor, weight_scale_tensor = weight_quantize_xpu(weight, self.quant_config.algo, -1, -1)
layer.weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0]))
layer.quant_weight.set_value(paddle.transpose(quanted_weight_tensor, [1, 0]))
layer.weight_scale.set_value(weight_scale_tensor)
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/layers/embeddings.py
@@ -22,7 +22,7 @@
from paddle.distributed import fleet

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.utils import set_weight_attrs
from fastdeploy.model_executor.utils import set_weight_attrs

from .utils import get_tensor

142 changes: 101 additions & 41 deletions fastdeploy/model_executor/layers/linear.py
@@ -23,9 +23,10 @@
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.models.utils import (
default_weight_loader,
set_weight_attrs,
default_load_weights_into_param,
default_weights_processor,
)
from fastdeploy.model_executor.utils import set_weight_attrs, slice_fn
from fastdeploy.platforms import current_platform

from .utils import _set_var_distributed, divide, get_tensor
@@ -37,21 +38,32 @@ class UnquantizedLinearMethod(QuantMethodBase):
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
extra_weight_attrs is a dictionary that may include parameters like:
- split_axis: axis along which to split the tensor in a distributed environment
- output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
- weight_loader: a callable or method responsible for loading the weight data
- weights_processor: a callable or method responsible for processing weight data
- load_weights_into_param: loads the given weight tensor into the specified model parameter
"""
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
split_axis = extra_weight_attrs.get("split_axis")
if hasattr(layer, "nranks") and layer.nranks > 0:
_set_var_distributed(layer.weight, split_axis=split_axis)
set_weight_attrs(
layer.weight,
{"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
{
**extra_weight_attrs,
"weights_processor": extra_weight_attrs.get(
"weights_processor", default_weights_processor(layer.fd_config)
),
"load_weights_into_param": extra_weight_attrs.get(
"load_weights_into_param", default_load_weights_into_param()
),
},
)
if hasattr(layer, "nranks") and layer.nranks > 1:
set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})

def process_loaded_weights(self, layer, weights) -> None:
# mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
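
The docstring above describes a contract where loader callbacks are attached directly to the parameter; a standalone, simplified sketch of that pattern follows (the real `default_weights_processor` / `default_load_weights_into_param` are factories that return callables, and the FastDeploy `set_weight_attrs` lives in `fastdeploy.model_executor.utils`; the stand-ins below only mirror the shape of the idea):

```python
# Simplified, self-contained sketch (not FastDeploy code): a generic checkpoint
# loader can process and copy a weight via callbacks hung on the parameter,
# without knowing anything about the owning layer.
class FakeParam:
    """Stand-in for a paddle Parameter."""
    pass

def set_weight_attrs(param, attrs):
    # Mirrors the helper used above: attach arbitrary attributes to a parameter.
    for key, value in attrs.items():
        setattr(param, key, value)

def weights_processor(loaded_weight):
    # Stand-in for default_weights_processor(fd_config): dtype casts or
    # tensor-parallel slicing would happen here before yielding.
    yield loaded_weight

def load_weights_into_param(param, processed_weight):
    # Stand-in for default_load_weights_into_param(): the real version copies
    # the processed tensor into the parameter in place.
    param.data = processed_weight

param = FakeParam()
set_weight_attrs(param, {
    "weights_processor": weights_processor,
    "load_weights_into_param": load_weights_into_param,
})

for processed in param.weights_processor([1.0, 2.0]):
    param.load_weights_into_param(param, processed)
```
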
@@ -158,6 +170,7 @@ def __init__(
is_bias=True,
)

self.is_quantized = fd_config.model_config.is_quantized
# smooth quant
self.linear_shift = None
self.linear_smooth = None
@@ -270,9 +283,17 @@ def __init__(
assert self.quant_method is not None
self.quant_method.create_weights(
self,
weight_loader=(
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
weights_processor=(
self.weights_processor
if hasattr(self, "weights_processor")
else default_weights_processor(self.fd_config)
),
load_weights_into_param=(
self.load_weights_into_param
if hasattr(self, "load_weights_into_param")
else default_load_weights_into_param()
),
inflight_quant=fd_config.quant_config and not skip_quant,
)


@@ -327,17 +348,23 @@ def __init__(
self.quant_method.create_weights(
self,
output_dim=True,
weight_loader=(
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
weights_processor=(
self.weights_processor
if hasattr(self, "weights_processor")
else default_weights_processor(self.fd_config)
),
load_weights_into_param=(
self.load_weights_into_param
if hasattr(self, "load_weights_into_param")
else default_load_weights_into_param()
),
inflight_quant=fd_config.quant_config and not skip_quant,
)

if self.nranks > 0:
_set_var_distributed(self.weight, split_axis=1)
if self.with_bias:
# col parallel
_set_var_distributed(self.bias, split_axis=1)
if self.nranks > 1:
set_weight_attrs(self.bias, {"output_dim": True})


class MergedColumnParallelLinear(ColumnParallelLinear):
@@ -390,31 +417,34 @@ def __init__(
skip_quant=skip_quant,
)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
def load_weights_into_param(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
assert loaded_shard_id in ["gate", "up"]
output_dim = getattr(param, "output_dim", None)
tensor_size = param.shape[output_dim] // 2
if loaded_shard_id == "gate":
param = slice_fn(param, output_dim, start=0, end=tensor_size)
elif loaded_shard_id == "up":
param = slice_fn(param, output_dim, start=tensor_size, end=tensor_size * 2)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
param.copy_(loaded_weight, False)

def weights_processor(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
# 1.fused gate_up in disk
# 2.split gate up
assert loaded_shard_id in ["gate", "up"]
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
if output_dim is not None and self.nranks > 1:
dim = -1
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]

loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
loaded_weight = get_tensor(loaded_weight)

if loaded_shard_id == "gate":
param = param[:, : self.output_size // 2]
elif loaded_shard_id == "up":
param = param[:, self.output_size // 2 :]

assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
param.copy_(loaded_weight, False)
yield loaded_weight
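
A small worked sketch of the fused gate/up placement performed by `load_weights_into_param` above (assumptions: NumPy arrays stand in for paddle tensors, `output_dim == 1`, single rank, hypothetical shapes):

```python
import numpy as np

def place_shard(fused_param, loaded_weight, shard_id, output_dim=1):
    # The fused parameter holds [gate | up] along output_dim; each shard is
    # copied into its half, mirroring the slice_fn/copy_ logic above.
    tensor_size = fused_param.shape[output_dim] // 2
    if shard_id == "gate":
        view = fused_param[:, :tensor_size]
    elif shard_id == "up":
        view = fused_param[:, tensor_size:tensor_size * 2]
    else:
        raise ValueError(f"unknown shard id: {shard_id}")
    assert view.shape == loaded_weight.shape
    view[...] = loaded_weight  # in-place copy into the fused parameter

fused = np.zeros((4, 8))                        # hypothetical fused gate_up weight
place_shard(fused, np.ones((4, 4)), "gate")     # left half
place_shard(fused, np.full((4, 4), 2.0), "up")  # right half
```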

def load_state_dict(self, state_dict: dict):
"""
@@ -484,33 +514,45 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
add_bias=add_bias,
)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
def weights_processor(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
# 1.fused qkv in disk
# 2.split q k v
assert loaded_shard_id in ["q", "k", "v"]
output_dim = getattr(param, "output_dim", None)
# Tensor parallelism splits the weight along the output_dim
if output_dim is not None:
if output_dim is not None and self.nranks > 1:
dim = -1
size = loaded_weight.get_shape()[dim]
block_size = size // self.nranks
shard_offset = self.local_rank * block_size
shard_size = (self.local_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]
loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)

loaded_weight = get_tensor(loaded_weight)
yield loaded_weight

def load_weights_into_param(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
assert loaded_shard_id in ["q", "k", "v"]
output_dim = getattr(param, "output_dim", None)
head_dim = param.shape[output_dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
if loaded_shard_id == "q":
param = param[:, : self.num_heads_per_rank * self.head_dim]
param = slice_fn(param, output_dim, 0, self.num_heads_per_rank * head_dim)

elif loaded_shard_id == "k":
param = param[
:,
self.num_heads_per_rank
* self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
* self.head_dim,
]
param = slice_fn(
param,
output_dim,
self.num_heads_per_rank * head_dim,
(self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim,
)

elif loaded_shard_id == "v":
param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
param = slice_fn(
param,
output_dim,
(self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim,
(self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * head_dim,
)

assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
@@ -653,12 +695,19 @@ def __init__(
self,
split_axis=0,
output_dim=False,
weight_loader=(
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
weights_processor=(
self.weights_processor
if hasattr(self, "weights_processor")
else default_weights_processor(self.fd_config)
),
load_weights_into_param=(
self.load_weights_into_param
if hasattr(self, "load_weights_into_param")
else default_load_weights_into_param()
),
inflight_quant=fd_config.quant_config and not skip_quant,
)
if self.nranks > 0:
_set_var_distributed(self.weight, split_axis=0)
if self.with_bias:
# col parallel
_set_var_distributed(self.bias, split_axis=0)
Expand All @@ -670,6 +719,17 @@ def __init__(
},
)

if self.nranks > 0:
if self.with_bias:
# col parallel
_set_var_distributed(self.bias, split_axis=0)
set_weight_attrs(
self.bias,
{
"output_dim": False,
},
)

self.reduce_results = reduce_results

def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/layers/lm_head.py
@@ -22,7 +22,7 @@
from paddle.distributed import fleet

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.utils import set_weight_attrs
from fastdeploy.model_executor.utils import set_weight_attrs

from .utils import get_tensor

@@ -19,7 +19,7 @@
import paddle
from paddle import nn

from fastdeploy.model_executor.layers.utils import set_weight_attrs
from fastdeploy.model_executor.utils import set_weight_attrs
from fastdeploy.platforms import current_platform

from ..quantization.quant_base import QuantMethodBase
@@ -185,9 +185,11 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
if current_platform.is_cuda():
self.up_gate_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2]
self.down_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size, layer.hidden_size]
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1}}
else:
self.up_gate_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size * 2, layer.hidden_size]
self.down_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size]
extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}

layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,