 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
-from fastdeploy.model_executor.models.utils import (
+from fastdeploy.model_executor.utils import (
     default_weight_loader,
     set_weight_attrs,
+    slice_fn,
 )
 from fastdeploy.platforms import current_platform
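
The call sites in this diff only rely on slice_fn choosing an axis from output_dim and slicing with plain indexing, so it works uniformly on numpy arrays, paddle tensors, and lazy checkpoint slices. A minimal sketch of that contract (hypothetical; the real helper lives in fastdeploy.model_executor.utils):

def slice_fn(weight, output_dim, start, end):
    # Hypothetical reconstruction from the call sites below, not the shipped helper:
    # slice the last (output) axis when output_dim is truthy, else the first axis.
    # Plain __getitem__ keeps lazy checkpoint slices lazy until get_tensor() is called.
    if output_dim:
        return weight[..., start:end]
    return weight[start:end, ...]
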
@@ -38,6 +39,7 @@ class UnquantizedLinearMethod(QuantMethodBase):
     def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         extra_weight_attrs is a dictionary that may include parameters like:
+        - split_axis: axis along which to split the tensor in a distributed environment
         - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
         - weight_loader: a callable or method responsible for loading the weight data
         """
@@ -47,12 +49,16 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
+        split_axis = extra_weight_attrs.get("split_axis")
+        if hasattr(layer, "nranks") and layer.nranks > 0:
+            _set_var_distributed(layer.weight, split_axis=split_axis)
         set_weight_attrs(
             layer.weight,
-            {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
+            {
+                **extra_weight_attrs,
+                "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
+            },
         )
-        if hasattr(layer, "nranks") and layer.nranks > 1:
-            set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})

     def process_loaded_weights(self, layer, weights) -> None:
         # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
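
With create_weights now forwarding the whole extra_weight_attrs dict onto the parameter and calling _set_var_distributed itself, a layer passes its sharding metadata in one place. An illustrative call site (hypothetical values assembled from the attribute names this hunk reads; not part of the PR):

layer.quant_method.create_weights(
    layer,
    output_dim=True,                    # weight_loader shards along the last axis
    split_axis=1,                       # column-parallel marking for _set_var_distributed
    weight_loader=layer.weight_loader,  # omit to fall back to default_weight_loader(layer.fd_config)
)
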
@@ -334,7 +340,6 @@ def __init__(
             ),
         )
         if self.nranks > 0:
-            _set_var_distributed(self.weight, split_axis=1)
             if self.with_bias:
                 # col parallel
                 _set_var_distributed(self.bias, split_axis=1)
@@ -393,44 +398,47 @@ def __init__(
         )

     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        output_dim = getattr(param, "output_dim", None)
+        shard_dim = -1 if output_dim else 0
+        output_size = param.shape[shard_dim] // 2
         if loaded_shard_id is None:
             # Loaded weight is already fused on disk.
-            if self.nranks != 1:
-                shard_offsets = [
-                    # (shard_id, shard_offset, shard_size)
-                    ("gate", 0, self.output_size * self.nranks // 2),
-                    ("up", self.output_size * self.nranks // 2, self.output_size * self.nranks // 2),
-                ]
-                for shard_id, shard_offset, shard_size in shard_offsets:
-                    loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
-                    self.weight_loader(param, loaded_weight_shard, shard_id)
-            else:
-                loaded_weight = get_tensor(loaded_weight)
-                param.copy_(loaded_weight, False)
+            shard_offsets = [
+                # (shard_id, shard_offset, shard_size); output_size is per rank, so each
+                # projection spans output_size * nranks columns of the fused checkpoint
+                ("gate", 0, output_size * self.nranks),
+                ("up", output_size * self.nranks, output_size * self.nranks),
+            ]
+            for shard_id, shard_offset, shard_size in shard_offsets:
+                loaded_weight_shard = slice_fn(
+                    loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size
+                )
+                self.weight_loader(param, loaded_weight_shard, shard_id)
         else:
-            # 1.fused gate_up in disk
-            # 2.split gate up
+            # split gate up
             assert loaded_shard_id in ["gate", "up"]
-            output_dim = getattr(param, "output_dim", None)
             # Tensor parallelism splits the weight along the output_dim
-            if output_dim is not None:
-                dim = -1
+            if self.nranks != 1:
+                dim = -1 if output_dim else 0
                 if isinstance(loaded_weight, np.ndarray):
                     size = loaded_weight.shape[dim]
                 else:
                     size = loaded_weight.get_shape()[dim]
                 block_size = size // self.nranks
                 shard_offset = self.local_rank * block_size
                 shard_size = (self.local_rank + 1) * block_size
-                loaded_weight = loaded_weight[..., shard_offset:shard_size]
+                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)

             loaded_weight = get_tensor(loaded_weight)
-
+            if not param._is_initialized():
+                param.initialize()
             if loaded_shard_id == "gate":
-                param = param[:, : self.output_size // 2]
-            elif loaded_shard_id == "up":
-                param = param[:, self.output_size // 2 :]
-
+                param_shard_offset = 0
+            else:
+                # loaded_shard_id == "up"
+                param_shard_offset = output_size
+            if hasattr(param, "tensor_track"):
+                param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + output_size)
+            param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + output_size)
         assert param.shape == loaded_weight.shape, (
             f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
         )
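
The gate/up offset arithmetic is easiest to verify with concrete numbers. A standalone sketch with hypothetical sizes, using numpy in place of paddle tensors:

import numpy as np

nranks, local_rank = 2, 1                # 2-way tensor parallelism, rank 1
output_size = 8                          # per-rank width of gate (and of up)
fused = np.arange(2 * output_size * nranks).reshape(1, -1)   # [1, 32] fused gate_up on disk

shard_offsets = [
    ("gate", 0, output_size * nranks),                    # columns 0..16
    ("up", output_size * nranks, output_size * nranks),   # columns 16..32
]
for shard_id, offset, size in shard_offsets:
    shard = fused[..., offset : offset + size]            # one whole projection
    block = size // nranks                                # tensor-parallel block
    local = shard[..., local_rank * block : (local_rank + 1) * block]
    print(shard_id, local.shape)                          # (1, 8) for each shard
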
@@ -505,53 +513,54 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
         )

     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        output_dim = getattr(param, "output_dim", None)
+        shard_dim = -1 if output_dim else 0
+        head_dim = param.shape[shard_dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
         if loaded_shard_id is None:
             # Loaded weight is already fused on disk
-            if self.nranks != 1:
-                shard_offsets = [
-                    # (shard_id, shard_offset, shard_size)
-                    ("q", 0, self.num_heads * self.head_dim),
-                    ("k", self.num_heads * self.head_dim, self.kv_num_heads * self.head_dim),
-                    ("v", (self.num_heads + self.kv_num_heads) * self.head_dim, self.kv_num_heads * self.head_dim),
-                ]
-                for shard_id, shard_offset, shard_size in shard_offsets:
-                    loaded_weight_shard = loaded_weight[..., shard_offset : shard_offset + shard_size]
-                    self.weight_loader(param, loaded_weight_shard, shard_id)
-            else:
-                loaded_weight = get_tensor(loaded_weight)
-                split_loaded_weight = loaded_weight
-                param.copy_(split_loaded_weight, False)
+            shard_offsets = [
+                # (shard_id, shard_offset, shard_size); num_heads and kv_num_heads are
+                # global counts, so these offsets address the full fused checkpoint
+                ("q", 0, self.num_heads * head_dim),
+                ("k", self.num_heads * head_dim, self.kv_num_heads * head_dim),
+                ("v", (self.num_heads + self.kv_num_heads) * head_dim, self.kv_num_heads * head_dim),
+            ]
+            for shard_id, shard_offset, shard_size in shard_offsets:
+                loaded_weight_shard = slice_fn(
+                    loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size
+                )
+                self.weight_loader(param, loaded_weight_shard, shard_id)
         else:
-            # 1.fused qkv in disk
-            # 2.split q k v
+            # split q k v
             assert loaded_shard_id in ["q", "k", "v"]
-            output_dim = getattr(param, "output_dim", None)
             # Tensor parallelism splits the weight along the output_dim
-            if output_dim is not None:
-                dim = -1
+            if self.nranks != 1:
+                dim = -1 if output_dim else 0
                 if isinstance(loaded_weight, np.ndarray):
                     size = loaded_weight.shape[dim]
                 else:
                     size = loaded_weight.get_shape()[dim]
                 block_size = size // self.nranks
                 shard_offset = self.local_rank * block_size
                 shard_size = (self.local_rank + 1) * block_size
-                loaded_weight = loaded_weight[..., shard_offset:shard_size]
+                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)

             loaded_weight = get_tensor(loaded_weight)
+            if not param._is_initialized():
+                param.initialize()

             if loaded_shard_id == "q":
-                param = param[:, : self.num_heads_per_rank * self.head_dim]
-            elif loaded_shard_id == "k":
-                param = param[
-                    :,
-                    self.num_heads_per_rank
-                    * self.head_dim : (self.num_heads_per_rank + self.kv_num_heads_per_rank)
-                    * self.head_dim,
-                ]
-            elif loaded_shard_id == "v":
-                param = param[:, (self.num_heads_per_rank + self.kv_num_heads_per_rank) * self.head_dim :]
-
+                param_shard_offset = 0
+                param_shard_size = self.num_heads_per_rank * head_dim
+            elif loaded_shard_id == "k":
+                param_shard_offset = self.num_heads_per_rank * head_dim
+                param_shard_size = self.kv_num_heads_per_rank * head_dim
+            else:
+                # loaded_shard_id == "v"
+                param_shard_offset = (self.num_heads_per_rank + self.kv_num_heads_per_rank) * head_dim
+                param_shard_size = self.kv_num_heads_per_rank * head_dim
+            if hasattr(param, "tensor_track"):
+                param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
+            param = slice_fn(param, output_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size)
         assert param.shape == loaded_weight.shape, (
             f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
         )
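
The same sanity check for the fused qkv layout, with a small grouped-query configuration (hypothetical sizes):

num_heads, kv_num_heads, head_dim = 8, 2, 4   # 8 query heads, 2 kv heads

offsets = [
    ("q", 0, num_heads * head_dim),
    ("k", num_heads * head_dim, kv_num_heads * head_dim),
    ("v", (num_heads + kv_num_heads) * head_dim, kv_num_heads * head_dim),
]
for shard_id, offset, size in offsets:
    print(shard_id, offset, size)   # q 0 32, k 32 8, v 40 8; fused width 48
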
@@ -698,7 +707,6 @@ def __init__(
             ),
         )
         if self.nranks > 0:
-            _set_var_distributed(self.weight, split_axis=0)
             if self.with_bias:
                 # col parallel
                 _set_var_distributed(self.bias, split_axis=0)