from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.models.utils import (
-    default_weight_loader,
+    default_load_weights_into_param,
+    default_weights_processor,
    set_weight_attrs,
+    slice_fn,
)
from fastdeploy.platforms import current_platform

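The newly imported slice_fn replaces the raw `[..., start:end]` / `[:, start:end]` indexing used further down in this diff. Its real implementation lives in fastdeploy.model_executor.models.utils and is not shown here; the snippet below is only a minimal sketch of the behaviour the call sites seem to assume (the second argument picks the axis, then a half-open [start, end) slice is taken), with a hypothetical name so it is not mistaken for the library helper.

import numpy as np

# Illustrative stand-in for slice_fn -- an assumption inferred from the call
# sites below, NOT the actual fastdeploy helper.
def slice_fn_sketch(tensor, output_dim, start, end):
    # Guess: a truthy "output_dim" means "slice the last axis",
    # otherwise slice the first axis.
    axis = -1 if output_dim else 0
    index = [slice(None)] * tensor.ndim
    index[axis] = slice(start, end)
    return tensor[tuple(index)]

w = np.zeros((128, 256))
print(slice_fn_sketch(w, True, 0, 128).shape)   # (128, 128): last-axis slice
print(slice_fn_sketch(w, False, 0, 64).shape)   # (64, 256): first-axis slice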
@@ -37,24 +39,29 @@ class UnquantizedLinearMethod(QuantMethodBase):
    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
        """
        extra_weight_attrs is a dictionary that may include parameters like:
-        - split_axis: specifies which axis to split the weight tensor on (for distributed weight partitioning)
-        - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
-        - weight_loader: a callable or method responsible for loading the weight data
+        - weights_processor: a callable or method responsible for loading the weight data
        """
        layer.weight = layer.create_parameter(
            shape=layer.weight_shape,
            dtype=layer.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
+        split_axis = extra_weight_attrs.get("split_axis")
+        if hasattr(layer, "nranks") and layer.nranks > 0:
+            _set_var_distributed(layer.weight, split_axis=split_axis)
        set_weight_attrs(
            layer.weight,
-            {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
+            {
+                **extra_weight_attrs,
+                "weights_processor": extra_weight_attrs.get(
+                    "weights_processor", default_weights_processor(layer.fd_config)
+                ),
+                "load_weights_into_param": extra_weight_attrs.get(
+                    "load_weights_into_param", default_load_weights_into_param()
+                ),
+            },
        )
-        if hasattr(layer, "nranks") and layer.nranks > 0:
-            split_axis = extra_weight_attrs.get("split_axis")
-            _set_var_distributed(layer.weight, split_axis=split_axis)
-            set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})

    def process_loaded_weights(self, layer, weights) -> None:
        # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
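create_weights now attaches two callables to layer.weight: weights_processor (defaulting to default_weights_processor(layer.fd_config)) and load_weights_into_param (defaulting to default_load_weights_into_param()). The sketch below shows how a loading loop could chain them, assuming the processor is a generator over processed tensors and the loader copies each result into the (possibly narrowed) parameter; the name load_checkpoint_weight and the exact signatures are illustrative, not part of the FastDeploy API.

# Hypothetical driver for the attribute contract set up in create_weights.
def load_checkpoint_weight(param, raw_weight, shard_id=None):
    processor = getattr(param, "weights_processor")
    loader = getattr(param, "load_weights_into_param")
    # The processor may shard along the tensor-parallel axis, cast dtypes or
    # dequantize; it yields one or more ready-to-copy tensors.
    for processed in processor(param, raw_weight, shard_id):
        # The loader narrows the parameter (e.g. the gate/up half of a fused
        # weight) and copies the processed tensor into it in place.
        loader(param, processed, shard_id)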
@@ -157,6 +164,7 @@ def __init__(
                is_bias=True,
            )

+        self.is_quantized = fd_config.model_config.is_quantized
        # smooth quant
        self.linear_shift = None
        self.linear_smooth = None
@@ -274,9 +282,17 @@ def __init__(
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
-            weight_loader=(
-                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            weights_processor=(
+                self.weights_processor
+                if hasattr(self, "weights_processor")
+                else default_weights_processor(self.fd_config)
+            ),
+            load_weights_into_param=(
+                self.load_weights_into_param
+                if hasattr(self, "load_weights_into_param")
+                else default_load_weights_into_param()
            ),
+            inflight_quant=fd_config.quant_config and not skip_quant,
        )

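Every call site falls back to default_weights_processor(self.fd_config) and default_load_weights_into_param(); both come from fastdeploy.model_executor.models.utils and their bodies are outside this diff. Judging only from how they are used (factories returning callables, the processor behaving as a generator), they plausibly have roughly the shape sketched below; every detail here is an assumption, not the real implementation.

# Assumed shapes only -- the real helpers may differ substantially.
def default_weights_processor_sketch(fd_config):
    def _process(param, loaded_weight, shard_id=None):
        # No sharding or conversion by default: hand the tensor through.
        yield loaded_weight

    return _process


def default_load_weights_into_param_sketch():
    def _load(param, loaded_weight, shard_id=None):
        assert param.shape == loaded_weight.shape, (
            f"Attempted to load weight ({loaded_weight.shape}) "
            f"into parameter ({param.shape})"
        )
        param.copy_(loaded_weight, False)

    return _load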
@@ -335,16 +351,23 @@ def __init__(
            self,
            split_axis=1,
            output_dim=True,
-            weight_loader=(
-                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            weights_processor=(
+                self.weights_processor
+                if hasattr(self, "weights_processor")
+                else default_weights_processor(self.fd_config)
            ),
+            load_weights_into_param=(
+                self.load_weights_into_param
+                if hasattr(self, "load_weights_into_param")
+                else default_load_weights_into_param()
+            ),
+            inflight_quant=fd_config.quant_config and not skip_quant,
        )

-        if self.with_bias:
-            if self.nranks > 0:
+        if self.nranks > 0:
+            if self.with_bias:
                # col parallel
                _set_var_distributed(self.bias, split_axis=1)
-                set_weight_attrs(self.bias, {"output_dim": True})


class MergedColumnParallelLinear(ColumnParallelLinear):
@@ -397,31 +420,33 @@ def __init__(
            skip_quant=skip_quant,
        )

-    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+    def load_weights_into_param(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        assert loaded_shard_id in ["gate", "up"]
+        output_dim = getattr(param, "output_dim", None)
+        if loaded_shard_id == "gate":
+            param = slice_fn(param, output_dim, start=0, end=self.output_size // 2)
+        elif loaded_shard_id == "up":
+            param = slice_fn(param, output_dim, start=self.output_size // 2, end=self.output_size)
+        assert param.shape == loaded_weight.shape, (
+            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+        )
+        param.copy_(loaded_weight, False)
+
+    def weights_processor(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        # 1.fused gate_up in disk
        # 2.split gate up
        assert loaded_shard_id in ["gate", "up"]
        output_dim = getattr(param, "output_dim", None)
        # Tensor parallelism splits the weight along the output_dim
-        if output_dim is not None:
+        if output_dim is not None and self.nranks > 1:
            dim = -1
            size = loaded_weight.get_shape()[dim]
            block_size = size // self.nranks
            shard_offset = self.local_rank * block_size
            shard_size = (self.local_rank + 1) * block_size
-            loaded_weight = loaded_weight[..., shard_offset:shard_size]
-
+            loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
        loaded_weight = get_tensor(loaded_weight)
-
-        if loaded_shard_id == "gate":
-            param = param[:, : self.output_size // 2]
-        elif loaded_shard_id == "up":
-            param = param[:, self.output_size // 2 :]
-
-        assert param.shape == loaded_weight.shape, (
-            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
-        )
-        param.copy_(loaded_weight, False)
+        yield loaded_weight

    def load_state_dict(self, state_dict: dict):
        """
@@ -491,33 +516,44 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
            add_bias=add_bias,
        )

-    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+    def weights_processor(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        # 1.fused qkv in disk
        # 2.split q k v
        assert loaded_shard_id in ["q", "k", "v"]
        output_dim = getattr(param, "output_dim", None)
        # Tensor parallelism splits the weight along the output_dim
-        if output_dim is not None:
+        if output_dim is not None and self.nranks > 1:
            dim = -1
            size = loaded_weight.get_shape()[dim]
            block_size = size // self.nranks
            shard_offset = self.local_rank * block_size
            shard_size = (self.local_rank + 1) * block_size
-            loaded_weight = loaded_weight[..., shard_offset:shard_size]
+            loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)

        loaded_weight = get_tensor(loaded_weight)
+        yield loaded_weight

+    def load_weights_into_param(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        assert loaded_shard_id in ["q", "k", "v"]
+        output_dim = getattr(param, "output_dim", None)
        if loaded_shard_id == "q":
-            param = param[:, : self.num_heads_per_rank * self.head_dim]
+            param = slice_fn(param, output_dim, 0, self.num_heads_per_rank * self.head_dim)
+
        elif loaded_shard_id == "k":
-            param = param[
-                :,
-                self.num_heads_per_rank
-                * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
-                * self.head_dim,
-            ]
+            param = slice_fn(
+                param,
+                output_dim,
+                self.num_heads_per_rank * self.head_dim,
+                (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim,
+            )
+
        elif loaded_shard_id == "v":
-            param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
+            param = slice_fn(
+                param,
+                output_dim,
+                (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim,
+                (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * self.head_dim,
+            )

        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
@@ -665,19 +701,30 @@ def __init__(
            self,
            split_axis=0,
            output_dim=False,
-            weight_loader=(
-                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            weights_processor=(
+                self.weights_processor
+                if hasattr(self, "weights_processor")
+                else default_weights_processor(self.fd_config)
            ),
+            load_weights_into_param=(
+                self.load_weights_into_param
+                if hasattr(self, "load_weights_into_param")
+                else default_load_weights_into_param()
+            ),
+            inflight_quant=fd_config.quant_config and not skip_quant,
        )

-        if self.with_bias:
-            _set_var_distributed(self.bias, split_axis=0)
-            set_weight_attrs(
-                self.bias,
-                {
-                    "output_dim": False,
-                },
-            )
+        if self.nranks > 0:
+            if self.with_bias:
+                # col parallel
+                _set_var_distributed(self.bias, split_axis=0)
+                set_weight_attrs(
+                    self.bias,
+                    {
+                        "output_dim": False,
+                    },
+                )
+
        self.reduce_results = reduce_results

    def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
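The same kind of spot check helps for the QKVParallelLinear slices earlier in the diff, whose boundaries are expressed in heads per rank. The head counts below are hypothetical (not from this PR) and only serve to make the three half-open ranges concrete:

# Hypothetical per-rank head counts, to make the QKV offsets concrete.
num_heads_per_rank = 8
kv_num_heads_per_rank = 2
head_dim = 128

q_end = num_heads_per_rank * head_dim                                # 1024
k_end = (num_heads_per_rank + kv_num_heads_per_rank) * head_dim      # 1280
v_end = (num_heads_per_rank + 2 * kv_num_heads_per_rank) * head_dim  # 1536

print((0, q_end))      # q occupies columns [0, 1024)
print((q_end, k_end))  # k occupies columns [1024, 1280)
print((k_end, v_end))  # v occupies columns [1280, 1536)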