
Commit b4e9d6b

dqweight
1 parent ac03ad3 commit b4e9d6b

1 file changed: +11, -6 lines

inference_lib/src/fp_quant/module/linear.py

Lines changed: 11 additions & 6 deletions
@@ -47,6 +47,9 @@ def __init__(
         self.weight = nn.Parameter(
             torch.empty((out_features, in_features), **factory_kwargs)
         )
+        self.dqweight = nn.Parameter(
+            torch.empty((out_features, in_features), **factory_kwargs)
+        )
         if bias:
             self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
         else:
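
For orientation, here is a minimal, self-contained sketch of what the constructor looks like after this hunk. The class name and surrounding details are simplified assumptions, not the repo's actual module: the point is that a second parameter, `dqweight`, is allocated with the same shape as `weight` and reserved for the dequantized copy produced later in `pre_forward`.

```python
import torch
import torch.nn as nn

class PseudoQuantLinearSketch(nn.Module):
    """Hypothetical, simplified stand-in for the modified linear module."""

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        # Master weights, as before.
        self.weight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        # New in this commit: a same-shaped parameter that will hold the
        # pseudo-quantized (quantize-then-dequantize) weights once
        # pre_forward runs.
        self.dqweight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
```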
@@ -133,17 +136,18 @@ def pre_forward(self):
         if self.config.store_master_weights:
             self.qweight = None
             self.scales = None
+            self.dqweight = None
         elif self.config.pseudoquantization:
-            self.qweight = None
-            self.scales = None
             weight_dq, _ = forward_pseudoquantize(
                 self.weight.data,
                 self.forward_hadamard_matrix,
                 self.config.forward_dtype,
                 self.config.forward_method,
             )
-            self.weight.data = weight_dq
-            self.weight.requires_grad = False
+            self.dqweight = nn.Parameter(weight_dq, requires_grad=False)
+            self.weight = None
+            self.qweight = None
+            self.scales = None
         else:
             weight_q, scales, _ = forward_quantize(
                 self.weight,
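
The pseudoquantization branch of `pre_forward` no longer mutates `self.weight.data` in place; it stores the dequantized tensor in the dedicated `dqweight` parameter and releases the other representations. Below is a hedged sketch of that control flow, with the repo's `forward_pseudoquantize` replaced by a stub so the snippet runs on its own.

```python
import torch
import torch.nn as nn

def forward_pseudoquantize_stub(weight, hadamard, dtype, method):
    # Stand-in for the repo's helper, which quantizes and immediately
    # dequantizes the weights; here it just returns a copy.
    return weight.clone(), None

def pre_forward_pseudoquant(module: nn.Module) -> None:
    weight_dq, _ = forward_pseudoquantize_stub(
        module.weight.data,
        getattr(module, "forward_hadamard_matrix", None),
        None,
        None,
    )
    # Dequantized weights live in their own non-trainable parameter ...
    module.dqweight = nn.Parameter(weight_dq, requires_grad=False)
    # ... and the master / real-quantized representations are released.
    module.weight = None
    module.qweight = None
    module.scales = None
```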
@@ -156,6 +160,7 @@ def pre_forward(self):
                 scales.view(dtype=torch.uint8), requires_grad=False
             )
             self.weight = None
+            self.dqweight = None

     def forward(self, x) -> torch.Tensor:
         match (
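
With this hunk, every branch of `pre_forward` leaves one weight representation in place: `weight` when master weights are stored, `dqweight` under pseudoquantization, and `qweight`/`scales` in the real-quantized path. A small hypothetical helper stating that invariant (not part of the repo):

```python
def live_weight_repr(module) -> str:
    """Hypothetical check: report which representation survived pre_forward."""
    present = {
        "weight": getattr(module, "weight", None) is not None,
        "dqweight": getattr(module, "dqweight", None) is not None,
        "qweight": getattr(module, "qweight", None) is not None,
    }
    live = [name for name, ok in present.items() if ok]
    assert len(live) == 1, f"expected one live weight representation, got {live}"
    return live[0]
```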
@@ -186,7 +191,7 @@ def forward(self, x) -> torch.Tensor:
             case (FPQuantDtype.MXFP4, FPQuantDtype.BF16, True, True):
                 return PseudoQuant4x16MasterFn.apply(
                     x,
-                    self.weight,
+                    self.dqweight,
                     self.bias,
                     self.forward_hadamard_matrix,
                     self.config.forward_dtype,
@@ -195,7 +200,7 @@ def forward(self, x) -> torch.Tensor:
             case (FPQuantDtype.MXFP4, FPQuantDtype.BF16, False, True):
                 return PseudoQuant4x16NoMasterFn.apply(
                     x,
-                    self.weight,
+                    self.dqweight,
                     self.bias,
                     self.forward_hadamard_matrix,
                     self.config.forward_dtype,
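
Both pseudoquant dispatch cases in `forward` (the master and no-master variants) now hand `self.dqweight` to the autograd functions instead of `self.weight`. As a rough illustration of why a pre-dequantized weight is enough at this point, here is a heavily simplified autograd function; it assumes a 2D input and is not the repo's PseudoQuant4x16MasterFn/NoMasterFn implementation (those also take the Hadamard matrix and forward dtype).

```python
import torch

class PseudoQuantLinearFnSketch(torch.autograd.Function):
    """Hypothetical pseudo-quantized linear: `dqweight` already carries the
    quantization error, so the forward pass is a plain matmul.
    Assumes input of shape (batch, in_features)."""

    @staticmethod
    def forward(ctx, x, dqweight, bias):
        ctx.save_for_backward(x, dqweight)
        ctx.has_bias = bias is not None
        out = x @ dqweight.t()
        return out + bias if ctx.has_bias else out

    @staticmethod
    def backward(ctx, grad_out):
        x, dqweight = ctx.saved_tensors
        grad_x = grad_out @ dqweight              # (batch, in_features)
        grad_w = grad_out.t() @ x                 # (out_features, in_features)
        grad_b = grad_out.sum(dim=0) if ctx.has_bias else None
        return grad_x, grad_w, grad_b
```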
