@@ -23,8 +23,10 @@
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.model_executor.models.utils import (
-    default_weight_loader,
+    default_load_weights_into_param,
+    default_weights_processor,
     set_weight_attrs,
+    slice_fn,
 )
 from fastdeploy.platforms import current_platform
 
@@ -37,21 +39,32 @@ class UnquantizedLinearMethod(QuantMethodBase):
     def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         extra_weight_attrs is a dictionary that may include parameters like:
+        - split_axis: axis along which to split the tensor in a distributed environment
         - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
-        - weight_loader: a callable or method responsible for loading the weight data
+        - weights_processor: a callable or method responsible for processing weight data
+        - load_weights_into_param: loads the given weight tensor into the specified model parameter
         """
         layer.weight = layer.create_parameter(
             shape=layer.weight_shape,
             dtype=layer.weight_dtype,
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
+        split_axis = extra_weight_attrs.get("split_axis")
+        if hasattr(layer, "nranks") and layer.nranks > 0:
+            _set_var_distributed(layer.weight, split_axis=split_axis)
         set_weight_attrs(
             layer.weight,
-            {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
+            {
+                **extra_weight_attrs,
+                "weights_processor": extra_weight_attrs.get(
+                    "weights_processor", default_weights_processor(layer.fd_config)
+                ),
+                "load_weights_into_param": extra_weight_attrs.get(
+                    "load_weights_into_param", default_load_weights_into_param()
+                ),
+            },
         )
-        if hasattr(layer, "nranks") and layer.nranks > 1:
-            set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})
 
     def process_loaded_weights(self, layer, weights) -> None:
         # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
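For orientation (not part of the patch): a rough sketch of how a checkpoint loader could consume the two attributes that create_weights now attaches via set_weight_attrs. The helper name load_one_weight is hypothetical; the attribute names and call signatures mirror the methods defined later in this diff.

# Illustrative only: consume the attributes attached by create_weights.
# `load_one_weight` is a made-up helper; `weights_processor` yields one or
# more processed shards and `load_weights_into_param` copies a shard into
# the parameter.
def load_one_weight(param, raw_weight, shard_id=None):
    weights_processor = getattr(param, "weights_processor")
    load_weights_into_param = getattr(param, "load_weights_into_param")
    for processed in weights_processor(param, raw_weight, shard_id):
        load_weights_into_param(param, processed, shard_id)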
@@ -158,6 +171,7 @@ def __init__(
             is_bias=True,
         )
 
+        self.is_quantized = fd_config.model_config.is_quantized
         # smooth quant
         self.linear_shift = None
         self.linear_smooth = None
@@ -270,9 +284,17 @@ def __init__(
         assert self.quant_method is not None
         self.quant_method.create_weights(
             self,
-            weight_loader=(
-                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            weights_processor=(
+                self.weights_processor
+                if hasattr(self, "weights_processor")
+                else default_weights_processor(self.fd_config)
             ),
+            load_weights_into_param=(
+                self.load_weights_into_param
+                if hasattr(self, "load_weights_into_param")
+                else default_load_weights_into_param()
+            ),
+            inflight_quant=fd_config.quant_config and not skip_quant,
         )
 
 
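The same hasattr-based dispatch appears again in the column-parallel and row-parallel hunks below: a layer that defines its own hooks wins, otherwise the config-driven defaults from models.utils are used, and inflight_quant is truthy exactly when a quant_config exists and quantization is not skipped. Reduced to its essentials (DummyLayer and the None fallback are illustrative, only the dispatch pattern is from the diff):

# Minimal illustration of the dispatch pattern used in these create_weights calls.
class DummyLayer:
    def weights_processor(self, param, loaded_weight, loaded_shard_id=None):
        yield loaded_weight  # custom preprocessing would go here

layer = DummyLayer()
processor = (
    layer.weights_processor
    if hasattr(layer, "weights_processor")
    else None  # stands in for default_weights_processor(layer.fd_config)
)
assert processor == layer.weights_processor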
@@ -327,17 +349,23 @@ def __init__(
         self.quant_method.create_weights(
             self,
             output_dim=True,
-            weight_loader=(
-                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            weights_processor=(
+                self.weights_processor
+                if hasattr(self, "weights_processor")
+                else default_weights_processor(self.fd_config)
+            ),
+            load_weights_into_param=(
+                self.load_weights_into_param
+                if hasattr(self, "load_weights_into_param")
+                else default_load_weights_into_param()
             ),
+            inflight_quant=fd_config.quant_config and not skip_quant,
         )
+
         if self.nranks > 0:
-            _set_var_distributed(self.weight, split_axis=1)
             if self.with_bias:
                 # col parallel
                 _set_var_distributed(self.bias, split_axis=1)
-                if self.nranks > 1:
-                    set_weight_attrs(self.bias, {"output_dim": True})
 
 
 class MergedColumnParallelLinear(ColumnParallelLinear):
@@ -390,31 +418,33 @@ def __init__(
             skip_quant=skip_quant,
         )
 
-    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+    def load_weights_into_param(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        assert loaded_shard_id in ["gate", "up"]
+        output_dim = getattr(param, "output_dim", None)
+        if loaded_shard_id == "gate":
+            param = slice_fn(param, output_dim, start=0, end=self.output_size // 2)
+        elif loaded_shard_id == "up":
+            param = slice_fn(param, output_dim, start=self.output_size // 2, end=self.output_size)
+        assert param.shape == loaded_weight.shape, (
+            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+        )
+        param.copy_(loaded_weight, False)
+
+    def weights_processor(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
         # 1.fused gate_up in disk
         # 2.split gate up
         assert loaded_shard_id in ["gate", "up"]
         output_dim = getattr(param, "output_dim", None)
         # Tensor parallelism splits the weight along the output_dim
-        if output_dim is not None:
+        if output_dim is not None and self.nranks > 1:
             dim = -1
             size = loaded_weight.get_shape()[dim]
             block_size = size // self.nranks
             shard_offset = self.local_rank * block_size
             shard_size = (self.local_rank + 1) * block_size
-            loaded_weight = loaded_weight[..., shard_offset:shard_size]
-
+            loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
         loaded_weight = get_tensor(loaded_weight)
-
-        if loaded_shard_id == "gate":
-            param = param[:, : self.output_size // 2]
-        elif loaded_shard_id == "up":
-            param = param[:, self.output_size // 2 :]
-
-        assert param.shape == loaded_weight.shape, (
-            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
-        )
-        param.copy_(loaded_weight, False)
+        yield loaded_weight
 
     def load_state_dict(self, state_dict: dict):
         """
@@ -484,33 +514,44 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
             add_bias=add_bias,
         )
 
-    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+    def weights_processor(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        # 1.fused qkv in disk
        # 2.split q k v
         assert loaded_shard_id in ["q", "k", "v"]
         output_dim = getattr(param, "output_dim", None)
         # Tensor parallelism splits the weight along the output_dim
-        if output_dim is not None:
+        if output_dim is not None and self.nranks > 1:
             dim = -1
             size = loaded_weight.get_shape()[dim]
             block_size = size // self.nranks
             shard_offset = self.local_rank * block_size
             shard_size = (self.local_rank + 1) * block_size
-            loaded_weight = loaded_weight[..., shard_offset:shard_size]
+            loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
 
         loaded_weight = get_tensor(loaded_weight)
+        yield loaded_weight
 
+    def load_weights_into_param(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        assert loaded_shard_id in ["q", "k", "v"]
+        output_dim = getattr(param, "output_dim", None)
         if loaded_shard_id == "q":
-            param = param[:, : self.num_heads_per_rank * self.head_dim]
+            param = slice_fn(param, output_dim, 0, self.num_heads_per_rank * self.head_dim)
+
         elif loaded_shard_id == "k":
-            param = param[
-                :,
-                self.num_heads_per_rank
-                * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
-                * self.head_dim,
-            ]
+            param = slice_fn(
+                param,
+                output_dim,
+                self.num_heads_per_rank * self.head_dim,
+                (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim,
+            )
+
         elif loaded_shard_id == "v":
-            param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
+            param = slice_fn(
+                param,
+                output_dim,
+                (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim,
+                (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank) * self.head_dim,
+            )
 
         assert param.shape == loaded_weight.shape, (
             f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
@@ -653,9 +694,17 @@ def __init__(
             self,
             split_axis=0,
             output_dim=False,
-            weight_loader=(
-                self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
+            weights_processor=(
+                self.weights_processor
+                if hasattr(self, "weights_processor")
+                else default_weights_processor(self.fd_config)
+            ),
+            load_weights_into_param=(
+                self.load_weights_into_param
+                if hasattr(self, "load_weights_into_param")
+                else default_load_weights_into_param()
             ),
+            inflight_quant=fd_config.quant_config and not skip_quant,
         )
         if self.nranks > 0:
             _set_var_distributed(self.weight, split_axis=0)
@@ -670,6 +719,17 @@ def __init__(
                 },
             )
 
+        if self.nranks > 0:
+            if self.with_bias:
+                # col parallel
+                _set_var_distributed(self.bias, split_axis=0)
+                set_weight_attrs(
+                    self.bias,
+                    {
+                        "output_dim": False,
+                    },
+                )
+
         self.reduce_results = reduce_results
 
     def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
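slice_fn itself is imported from fastdeploy.model_executor.models.utils and is not shown in this diff; the sketch below is only an assumption about the contract it is used with here, namely slicing [start, end) along the output (last) dimension when output_dim is truthy and along the leading dimension otherwise.

# Assumed contract only -- the real slice_fn lives in
# fastdeploy.model_executor.models.utils and may differ in detail.
def slice_fn_sketch(tensor, output_dim, start, end):
    if output_dim:
        return tensor[..., start:end]   # slice the output (last) dimension
    return tensor[start:end, ...]       # slice the input (leading) dimension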