Skip to content

Commit 7f8f948

Browse files
authored
Merge branch 'PaddlePaddle:develop' into pr0819
2 parents 516fedb + fef447e commit 7f8f948

File tree

4 files changed

+137
-18
lines changed

4 files changed

+137
-18
lines changed

fastdeploy/model_executor/layers/moe/check_backend_supported.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,26 @@
1818
from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
1919
CutlassMoEMethod,
2020
)
21+
from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import (
22+
DeepGemmFusedMoeMethod,
23+
)
24+
from fastdeploy.model_executor.layers.moe.fused_moe_marlin_backend import (
25+
MarlinWeightOnlyMoEMethod,
26+
)
2127
from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
2228
BlockWiseFP8MoEMethod,
2329
TensorWiseFP8MoEMethod,
2430
TritonWeightOnlyMoEMethod,
2531
)
2632

27-
pre_create_weights_list = (CutlassMoEMethod, TensorWiseFP8MoEMethod, BlockWiseFP8MoEMethod, TritonWeightOnlyMoEMethod)
33+
pre_create_weights_list = (
34+
CutlassMoEMethod,
35+
TensorWiseFP8MoEMethod,
36+
BlockWiseFP8MoEMethod,
37+
TritonWeightOnlyMoEMethod,
38+
DeepGemmFusedMoeMethod,
39+
MarlinWeightOnlyMoEMethod,
40+
)
2841

2942

3043
def is_supported_moe_backend(quant_method: MoEMethodBase):

fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from fastdeploy.model_executor.layers.utils import get_tensor
2424
from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
2525

26-
from ..utils import create_and_set_parameter
2726
from .fused_moe_backend_base import MoEMethodBase
2827

2928

@@ -32,11 +31,73 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
3231
DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend.
3332
"""
3433

35-
def create_weights(self, layer: nn.Layer, state_dict):
34+
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
3635
"""
3736
deepgemm create weight process.
3837
"""
38+
self.weight_dtype = paddle.float8_e4m3fn
39+
up_gate_proj_weight_name = self.added_weight_attrs[0]
40+
down_proj_weight_name = self.added_weight_attrs[1]
41+
self.ffn1_weight_shape = [
42+
layer.num_local_experts,
43+
layer.moe_intermediate_size * 2,
44+
layer.hidden_size,
45+
]
46+
self.ffn2_weight_shape = [
47+
layer.num_local_experts,
48+
layer.hidden_size,
49+
layer.moe_intermediate_size,
50+
]
51+
setattr(
52+
layer,
53+
up_gate_proj_weight_name,
54+
layer.create_parameter(
55+
shape=self.ffn1_weight_shape,
56+
dtype=self.weight_dtype,
57+
default_initializer=paddle.nn.initializer.Constant(0),
58+
),
59+
)
60+
setattr(
61+
layer,
62+
down_proj_weight_name,
63+
layer.create_parameter(
64+
shape=self.ffn2_weight_shape,
65+
dtype=self.weight_dtype,
66+
default_initializer=paddle.nn.initializer.Constant(0),
67+
),
68+
)
69+
# weight_scale
70+
setattr(
71+
layer,
72+
self.added_scale_attrs[0],
73+
layer.create_parameter(
74+
shape=[
75+
layer.num_local_experts,
76+
layer.moe_intermediate_size * 2 // self.quant_config.weight_block_size[0],
77+
layer.hidden_size // self.quant_config.weight_block_size[1],
78+
],
79+
dtype="float32",
80+
default_initializer=paddle.nn.initializer.Constant(0),
81+
),
82+
)
83+
setattr(
84+
layer,
85+
self.added_scale_attrs[1],
86+
layer.create_parameter(
87+
shape=[
88+
layer.num_local_experts,
89+
layer.hidden_size // self.quant_config.weight_block_size[0],
90+
layer.moe_intermediate_size // self.quant_config.weight_block_size[1],
91+
],
92+
dtype="float32",
93+
default_initializer=paddle.nn.initializer.Constant(0),
94+
),
95+
)
3996

97+
def process_loaded_weights(self, layer: nn.Layer, state_dict):
98+
"""
99+
deepgemm load weight process.
100+
"""
40101
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
41102

42103
self.check(layer, up_gate_proj_weights, down_proj_weights)
@@ -56,11 +117,11 @@ def create_weights(self, layer: nn.Layer, state_dict):
56117
weight_scale_list.append(scale)
57118
quanted_weight = paddle.stack(weight_list, axis=0)
58119
quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous()
59-
create_and_set_parameter(layer, weight_name, quanted_weight)
120+
getattr(layer, weight_name).copy_(quanted_weight, False)
60121

61122
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
62123
quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
63-
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
124+
getattr(layer, scale_name).set_value(quanted_weight_scale)
64125

65126
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
66127
"""
@@ -120,7 +181,7 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict):
120181
"down_proj_weight_scale": down_proj_weight_scale,
121182
}
122183
for name, tensor in name_tensor_map.items():
123-
create_and_set_parameter(layer, name, tensor)
184+
getattr(layer, name).set_value(tensor)
124185

125186
def apply_ep_prefill(
126187
self,

fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,63 @@ def __init__(self, quant_method=None):
139139
]
140140
self.added_zeros_attrs = ["zeros0", "zeros1"]
141141

142-
def create_weights(self, layer: nn.Layer, state_dict):
142+
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
143+
self.default_dtype = layer._helper.get_default_dtype()
144+
self.weight_dtype = "int32"
145+
146+
up_gate_proj_weight_name = self.added_weight_attrs[0]
147+
down_proj_weight_name = self.added_weight_attrs[1]
148+
self.ffn1_weight_shape = [
149+
layer.num_local_experts,
150+
layer.hidden_size // 16,
151+
layer.moe_intermediate_size * 4,
152+
]
153+
self.ffn2_weight_shape = [
154+
layer.num_local_experts,
155+
layer.moe_intermediate_size // 16,
156+
layer.hidden_size * 2,
157+
]
158+
setattr(
159+
layer,
160+
up_gate_proj_weight_name,
161+
layer.create_parameter(
162+
shape=self.ffn1_weight_shape,
163+
dtype=self.weight_dtype,
164+
default_initializer=paddle.nn.initializer.Constant(0),
165+
),
166+
)
167+
setattr(
168+
layer,
169+
down_proj_weight_name,
170+
layer.create_parameter(
171+
shape=self.ffn2_weight_shape,
172+
dtype=self.weight_dtype,
173+
default_initializer=paddle.nn.initializer.Constant(0),
174+
),
175+
)
176+
# weight_scale
177+
setattr(
178+
layer,
179+
self.added_scale_attrs[0],
180+
layer.create_parameter(
181+
shape=[layer.num_local_experts, 1, layer.moe_intermediate_size * 2],
182+
dtype=self.default_dtype,
183+
default_initializer=paddle.nn.initializer.Constant(0),
184+
),
185+
)
186+
setattr(
187+
layer,
188+
self.added_scale_attrs[1],
189+
layer.create_parameter(
190+
shape=[layer.num_local_experts, 1, layer.hidden_size],
191+
dtype=self.default_dtype,
192+
default_initializer=paddle.nn.initializer.Constant(0),
193+
),
194+
)
195+
196+
def process_loaded_weights(self, layer: nn.Layer, state_dict):
143197
"""
144-
Marlin MoE create weight process.
198+
Marlin MoE load weight process.
145199
"""
146200
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
147201
assert len(up_gate_proj_weights) == layer.num_local_experts
@@ -204,15 +258,6 @@ def create_weights(self, layer: nn.Layer, state_dict):
204258
(weight_name, quanted_weight),
205259
(scale_name, weight_scale),
206260
]:
207-
setattr(
208-
layer,
209-
name,
210-
layer.create_parameter(
211-
shape=tensor.shape,
212-
dtype=tensor.dtype,
213-
default_initializer=paddle.nn.initializer.Constant(0),
214-
),
215-
)
216261
getattr(layer, name).set_value(tensor)
217262

218263
def apply(

fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
630630
layer,
631631
down_proj_weight_name,
632632
layer.create_parameter(
633-
shape=self.ffn1_weight_shape,
633+
shape=self.ffn2_weight_shape,
634634
dtype=self.weight_dtype,
635635
default_initializer=paddle.nn.initializer.Constant(0),
636636
),

0 commit comments

Comments
 (0)