@@ -47,6 +47,9 @@ def __init__(
47
47
self .weight = nn .Parameter (
48
48
torch .empty ((out_features , in_features ), ** factory_kwargs )
49
49
)
50
+ self .dqweight = nn .Parameter (
51
+ torch .empty ((out_features , in_features ), ** factory_kwargs )
52
+ )
50
53
if bias :
51
54
self .bias = nn .Parameter (torch .empty (out_features , ** factory_kwargs ))
52
55
else :
@@ -133,17 +136,18 @@ def pre_forward(self):
133
136
if self .config .store_master_weights :
134
137
self .qweight = None
135
138
self .scales = None
139
+ self .dqweight = None
136
140
elif self .config .pseudoquantization :
137
- self .qweight = None
138
- self .scales = None
139
141
weight_dq , _ = forward_pseudoquantize (
140
142
self .weight .data ,
141
143
self .forward_hadamard_matrix ,
142
144
self .config .forward_dtype ,
143
145
self .config .forward_method ,
144
146
)
145
- self .weight .data = weight_dq
146
- self .weight .requires_grad = False
147
+ self .dqweight = nn .Parameter (weight_dq , requires_grad = False )
148
+ self .weight = None
149
+ self .qweight = None
150
+ self .scales = None
147
151
else :
148
152
weight_q , scales , _ = forward_quantize (
149
153
self .weight ,
@@ -156,6 +160,7 @@ def pre_forward(self):
156
160
scales .view (dtype = torch .uint8 ), requires_grad = False
157
161
)
158
162
self .weight = None
163
+ self .dqweight = None
159
164
160
165
def forward (self , x ) -> torch .Tensor :
161
166
match (
@@ -186,7 +191,7 @@ def forward(self, x) -> torch.Tensor:
186
191
case (FPQuantDtype .MXFP4 , FPQuantDtype .BF16 , True , True ):
187
192
return PseudoQuant4x16MasterFn .apply (
188
193
x ,
189
- self .weight ,
194
+ self .dqweight ,
190
195
self .bias ,
191
196
self .forward_hadamard_matrix ,
192
197
self .config .forward_dtype ,
@@ -195,7 +200,7 @@ def forward(self, x) -> torch.Tensor:
195
200
case (FPQuantDtype .MXFP4 , FPQuantDtype .BF16 , False , True ):
196
201
return PseudoQuant4x16NoMasterFn .apply (
197
202
x ,
198
- self .weight ,
203
+ self .dqweight ,
199
204
self .bias ,
200
205
self .forward_hadamard_matrix ,
201
206
self .config .forward_dtype ,
0 commit comments