2 files changed, +5 −8 lines changed
First file:

@@ -16,7 +16,7 @@
 seed = 1234
 
 rotation_trick = True # rotation trick instead of straight-through
-use_mlp = True # use a one layer mlp with relu instead of linear
+use_mlp = True # use a one layer mlp with relu instead of linear
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
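The use_mlp flag in the first file toggles the projection used before quantization between a plain linear layer and a one-layer MLP with a ReLU, as its comment states. A minimal sketch of how such a flag is commonly wired up; the names project_in, dim and codebook_dim are illustrative assumptions, not taken from the diff:

import torch.nn as nn

use_mlp = True                 # from the config above
dim, codebook_dim = 512, 64    # hypothetical dimensions, for illustration only

if use_mlp:
    # one layer mlp with relu instead of linear
    project_in = nn.Sequential(
        nn.Linear(dim, dim),
        nn.ReLU(),
        nn.Linear(dim, codebook_dim),
    )
else:
    # plain linear projection
    project_in = nn.Linear(dim, codebook_dim)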
Second file:

@@ -118,18 +118,15 @@ def forward(
 
         # commit loss and straight through, as was done in the paper
 
-        commit_loss = F.mse_loss(x.detach(), quantized)
+        commit_loss = (
+            F.mse_loss(x.detach(), quantized) +
+            F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
+        )
 
         if self.rotation_trick:
             # rotation trick from @cfifty
             quantized = rotate_to(x, quantized)
         else:
-
-            commit_loss = (
-                commit_loss +
-                F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
-            )
-
             quantized = (quantized - x).detach() + x
 
         quantized = inverse_pack(quantized)
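The second hunk changes where the input-to-quantize commitment term is applied: previously F.mse_loss(x, quantized.detach()), weighted by input_to_quantize_commit_loss_weight, was added to commit_loss only in the straight-through (else) branch; after the change it is folded into commit_loss up front, so it applies whether the rotation trick or the straight-through estimator is used. Below is a self-contained sketch of the two gradient paths under that reading; rotate_to here is an illustrative re-implementation of the rotation trick (credited in the comment to @cfifty), not the repository's own code, and the helper name and weight default are assumptions:

import torch
import torch.nn.functional as F

def rotate_to(x, q, eps = 1e-6):
    # rotation trick: rotate and rescale x onto q, treating the rotation and the
    # scale as constants so gradients flow through x instead of being copied
    x_norm = x.norm(dim = -1, keepdim = True).clamp(min = eps)
    q_norm = q.norm(dim = -1, keepdim = True).clamp(min = eps)

    e = x / x_norm
    q_hat = q / q_norm
    r = F.normalize(e + q_hat, dim = -1, eps = eps)

    # R x = x - 2 (x . r) r + 2 (x . e) q_hat, with R (i.e. r, e, q_hat) detached
    rotated = (
        x
        - 2 * (x * r.detach()).sum(dim = -1, keepdim = True) * r.detach()
        + 2 * (x * e.detach()).sum(dim = -1, keepdim = True) * q_hat.detach()
    )
    return rotated * (q_norm / x_norm).detach()

def quantize_sketch(x, quantized, rotation_trick = True, input_to_quantize_commit_loss_weight = 1.):
    # commitment loss, now identical for both branches: codes are pulled toward the
    # detached encoder output, and the encoder output toward the detached codes
    commit_loss = (
        F.mse_loss(x.detach(), quantized) +
        F.mse_loss(x, quantized.detach()) * input_to_quantize_commit_loss_weight
    )

    if rotation_trick:
        quantized = rotate_to(x, quantized)
    else:
        # classic straight-through estimator: forward pass uses quantized,
        # backward pass copies the gradient to x unchanged
        quantized = (quantized - x).detach() + x

    return quantized, commit_loss

# usage with random tensors standing in for the encoder output and the selected codes
x = torch.randn(2, 16, requires_grad = True)
codes = torch.randn(2, 16)
out, loss = quantize_sketch(x, codes)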