
Commit 2508956

support inflight quant
1 parent afff4d3 commit 2508956

File tree

15 files changed: +501 additions, -133 deletions


fastdeploy/config.py

Lines changed: 2 additions & 0 deletions
@@ -665,6 +665,7 @@ class LoadChoices(str, Enum):
     DEFAULT = "default"
     # only support qwen3-bf16 now
     DEFAULT_V1 = "default_v1"
+    INFLIGHT_QUANT = "inflight_quant"


 class LoadConfig:
@@ -684,6 +685,7 @@ def __init__(
         args,
     ):
         self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
+        self.is_inflight_quant = False
         self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
         self.dynamic_load_weight: bool = False
         self.load_strategy: Optional[Literal["ipc", "ipc_snapshot"]] = None
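
Note (not part of the diff): a minimal sketch of how the new load choice and flag might be selected. How `load_choices` is actually parsed from the engine arguments is an assumption here, not shown in this commit.

from fastdeploy.config import LoadChoices, LoadConfig

def build_load_config(args):
    # Hypothetical wiring only; the real argument plumbing is not in this commit.
    load_config = LoadConfig(args)  # `args` as in the __init__ signature above
    if getattr(args, "load_choices", None) == LoadChoices.INFLIGHT_QUANT.value:
        load_config.load_choices = LoadChoices.INFLIGHT_QUANT.value
        load_config.is_inflight_quant = True  # new flag introduced by this commit
    return load_config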

fastdeploy/model_executor/layers/linear.py

Lines changed: 98 additions & 51 deletions
@@ -23,8 +23,10 @@
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.model_executor.models.utils import (
-    default_weight_loader,
+    default_load_weights_into_param,
+    default_weights_processor,
     set_weight_attrs,
+    slice_fn,
 )
 from fastdeploy.platforms import current_platform

@@ -37,24 +39,29 @@ class UnquantizedLinearMethod(QuantMethodBase):
     def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         extra_weight_attrs is a dictionary that may include parameters like:
-            - split_axis: specifies which axis to split the weight tensor on (for distributed weight partitioning)
-            - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
-            - weight_loader: a callable or method responsible for loading the weight data
+            - weights_processor: a callable or method responsible for loading the weight data
         """
         layer.weight = layer.create_parameter(
             shape=layer.weight_shape,
             dtype=layer.weight_dtype,
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
+        split_axis = extra_weight_attrs.get("split_axis")
+        if hasattr(layer, "nranks") and layer.nranks > 0:
+            _set_var_distributed(layer.weight, split_axis=split_axis)
         set_weight_attrs(
             layer.weight,
-            {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
+            {
+                **extra_weight_attrs,
+                "weights_processor": extra_weight_attrs.get(
+                    "weights_processor", default_weights_processor(layer.fd_config)
+                ),
+                "load_weights_into_param": extra_weight_attrs.get(
+                    "load_weights_into_param", default_load_weights_into_param()
+                ),
+            },
         )
-        if hasattr(layer, "nranks") and layer.nranks > 0:
-            split_axis = extra_weight_attrs.get("split_axis")
-            _set_var_distributed(layer.weight, split_axis=split_axis)
-            set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})

     def process_loaded_weights(self, layer, weights) -> None:
         # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
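
Note (not part of the diff): the hunk above splits the old weight_loader into two callables attached to the parameter, weights_processor (a generator that shards/prepares the raw checkpoint tensor) and load_weights_into_param (which copies each prepared piece into the parameter, or a slice of it). A sketch of how a loader might drive them, assuming the attributes set via set_weight_attrs are retrieved from the parameter at load time and that the default_* factories return callables with the same signature as the bound methods below:

def load_one_checkpoint_tensor(param, loaded_weight, shard_id=None):
    # Illustrative only; the actual model loader is not part of this diff.
    weights_processor = param.weights_processor              # shards / prepares the raw tensor
    load_weights_into_param = param.load_weights_into_param  # copies it into (a slice of) param
    for processed in weights_processor(param, loaded_weight, shard_id):
        load_weights_into_param(param, processed, shard_id)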
@@ -157,6 +164,7 @@ def __init__(
             is_bias=True,
         )

+        self.is_quantized = fd_config.model_config.is_quantized
         # smooth quant
         self.linear_shift = None
         self.linear_smooth = None
@@ -274,9 +282,17 @@ def __init__(
         assert self.quant_method is not None
         self.quant_method.create_weights(
             self,
-            weight_loader=(
-                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            weights_processor=(
+                self.weights_processor
+                if hasattr(self, "weights_processor")
+                else default_weights_processor(self.fd_config)
+            ),
+            load_weights_into_param=(
+                self.load_weights_into_param
+                if hasattr(self, "load_weights_into_param")
+                else default_load_weights_into_param()
             ),
+            inflight_quant=fd_config.quant_config and not skip_quant,
         )

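Note (not part of the diff): the new inflight_quant keyword tells the quant method whether quantization should happen while the (bf16) checkpoint is being loaded. A hypothetical illustration of the receiving side; the concrete quantized LinearMethod backends are not in this diff, and the helper names below are made up:

class SomeQuantLinearMethod:
    def create_weights(self, layer, **extra_weight_attrs):
        if extra_weight_attrs.get("inflight_quant"):
            # bf16 checkpoint: allocate quantized storage and quantize each
            # tensor as it streams in, instead of keeping a bf16 copy around
            self._create_quantized_buffers(layer, **extra_weight_attrs)  # hypothetical helper
        else:
            # checkpoint already quantized offline (or quantization skipped)
            self._create_buffers_for_prequantized(layer, **extra_weight_attrs)  # hypothetical helper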

@@ -335,16 +351,23 @@ def __init__(
             self,
             split_axis=1,
             output_dim=True,
-            weight_loader=(
-                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            weights_processor=(
+                self.weights_processor
+                if hasattr(self, "weights_processor")
+                else default_weights_processor(self.fd_config)
             ),
+            load_weights_into_param=(
+                self.load_weights_into_param
+                if hasattr(self, "load_weights_into_param")
+                else default_load_weights_into_param()
+            ),
+            inflight_quant=fd_config.quant_config and not skip_quant,
         )

-        if self.with_bias:
-            if self.nranks > 0:
+        if self.nranks > 0:
+            if self.with_bias:
                 # col parallel
                 _set_var_distributed(self.bias, split_axis=1)
-                set_weight_attrs(self.bias, {"output_dim": True})


 class MergedColumnParallelLinear(ColumnParallelLinear):
@@ -397,31 +420,33 @@ def __init__(
             skip_quant=skip_quant,
         )

-    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+    def load_weights_into_param(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        assert loaded_shard_id in ["gate", "up"]
+        output_dim = getattr(param, "output_dim", None)
+        if loaded_shard_id == "gate":
+            param = slice_fn(param, output_dim, start=0, end=self.output_size // 2)
+        elif loaded_shard_id == "up":
+            param = slice_fn(param, output_dim, start=self.output_size // 2, end=self.output_size)
+        assert param.shape == loaded_weight.shape, (
+            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+        )
+        param.copy_(loaded_weight, False)
+
+    def weights_processor(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
         # 1.fused gate_up in disk
         # 2.split gate up
         assert loaded_shard_id in ["gate", "up"]
         output_dim = getattr(param, "output_dim", None)
         # Tensor parallelism splits the weight along the output_dim
-        if output_dim is not None:
+        if output_dim is not None and self.nranks > 1:
             dim = -1
             size = loaded_weight.get_shape()[dim]
             block_size = size // self.nranks
             shard_offset = self.local_rank * block_size
             shard_size = (self.local_rank + 1) * block_size
-            loaded_weight = loaded_weight[..., shard_offset:shard_size]
-
+            loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
         loaded_weight = get_tensor(loaded_weight)
-
-        if loaded_shard_id == "gate":
-            param = param[:, : self.output_size // 2]
-        elif loaded_shard_id == "up":
-            param = param[:, self.output_size // 2 :]
-
-        assert param.shape == loaded_weight.shape, (
-            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
-        )
-        param.copy_(loaded_weight, False)
+        yield loaded_weight

     def load_state_dict(self, state_dict: dict):
         """
@@ -491,33 +516,44 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
             add_bias=add_bias,
         )

-    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+    def weights_processor(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
         # 1.fused qkv in disk
         # 2.split q k v
         assert loaded_shard_id in ["q", "k", "v"]
         output_dim = getattr(param, "output_dim", None)
         # Tensor parallelism splits the weight along the output_dim
-        if output_dim is not None:
+        if output_dim is not None and self.nranks > 1:
             dim = -1
             size = loaded_weight.get_shape()[dim]
             block_size = size // self.nranks
             shard_offset = self.local_rank * block_size
             shard_size = (self.local_rank + 1) * block_size
-            loaded_weight = loaded_weight[..., shard_offset:shard_size]
+            loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)

         loaded_weight = get_tensor(loaded_weight)
+        yield loaded_weight

+    def load_weights_into_param(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        assert loaded_shard_id in ["q", "k", "v"]
+        output_dim = getattr(param, "output_dim", None)
         if loaded_shard_id == "q":
-            param = param[:, : self.num_heads_per_rank * self.head_dim]
+            param = slice_fn(param, output_dim, 0, self.num_heads_per_rank * self.head_dim)
+
         elif loaded_shard_id == "k":
-            param = param[
-                :,
-                self.num_heads_per_rank
-                * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
-                * self.head_dim,
-            ]
+            param = slice_fn(
+                param,
+                output_dim,
+                self.num_heads_per_rank * self.head_dim,
+                (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim,
+            )
+
         elif loaded_shard_id == "v":
-            param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
+            param = slice_fn(
+                param,
+                output_dim,
+                (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim,
+                (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * self.head_dim,
+            )

         assert param.shape == loaded_weight.shape, (
             f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
@@ -665,19 +701,30 @@ def __init__(
             self,
             split_axis=0,
             output_dim=False,
-            weight_loader=(
-                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            weights_processor=(
+                self.weights_processor
+                if hasattr(self, "weights_processor")
+                else default_weights_processor(self.fd_config)
             ),
+            load_weights_into_param=(
+                self.load_weights_into_param
+                if hasattr(self, "load_weights_into_param")
+                else default_load_weights_into_param()
+            ),
+            inflight_quant=fd_config.quant_config and not skip_quant,
         )

-        if self.with_bias:
-            _set_var_distributed(self.bias, split_axis=0)
-            set_weight_attrs(
-                self.bias,
-                {
-                    "output_dim": False,
-                },
-            )
+        if self.nranks > 0:
+            if self.with_bias:
+                # col parallel
+                _set_var_distributed(self.bias, split_axis=0)
+                set_weight_attrs(
+                    self.bias,
+                    {
+                        "output_dim": False,
+                    },
+                )
+
         self.reduce_results = reduce_results

     def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:

fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py

Lines changed: 2 additions & 0 deletions
@@ -185,9 +185,11 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         if current_platform.is_cuda():
             self.up_gate_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size * 2]
             self.down_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size, layer.hidden_size]
+            extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1}}
         else:
             self.up_gate_proj_weight_shape = [layer.num_experts, layer.moe_intermediate_size * 2, layer.hidden_size]
             self.down_proj_weight_shape = [layer.num_experts, layer.hidden_size, layer.moe_intermediate_size]
+            extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}

         layer.up_gate_proj_weight = layer.create_parameter(
             shape=self.up_gate_proj_weight_shape,
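
Note (not part of the diff): SHARD_ID_TO_SHARDED_DIM records, per shard id ("gate"/"up"/"down"), which axis of the expert weight is sharded for the layout in use. A sketch of a possible consumer, assuming extra_weight_attrs is later attached to the expert parameters via set_weight_attrs (as the linear layers do above); the actual MoE weight loader is not part of this hunk:

def shard_dim_for(param, shard_id):
    # Assumption: SHARD_ID_TO_SHARDED_DIM ends up as an attribute on the expert
    # weight parameter; this consumer is illustrative, not from the diff.
    mapping = getattr(param, "SHARD_ID_TO_SHARDED_DIM", None)
    if mapping is None:
        return None
    # e.g. on the CUDA layout above: {"gate": 1, "down": 0, "up": 1}
    return mapping[shard_id]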
