Commit 3aa6867 (parent 7be8916)

bring back the other half of the commit loss even in the presence of rotation trick, addressing #177
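The commit loss here has two halves: F.mse_loss(x.detach(), quantized), which pulls the codebook side toward the encoder output, and F.mse_loss(x, quantized.detach()), which pulls the encoder toward the codebook. Before this commit the second half was only applied on the straight-through branch; it now applies on the rotation-trick path as well. A minimal sketch of how the two terms split gradients, using stand-in tensors rather than the library's actual forward pass:

import torch
import torch.nn.functional as F

# stand-in tensors with invented shapes, not the library's modules
x = torch.randn(4, 8, requires_grad=True)          # encoder output
quantized = torch.randn(4, 8, requires_grad=True)  # selected codebook vectors

input_to_quantize_commit_loss_weight = 1.0  # assumed weight, for illustration only

commit_loss = (
    F.mse_loss(x.detach(), quantized) +                                        # trains the codebook side
    F.mse_loss(x, quantized.detach()) * input_to_quantize_commit_loss_weight   # trains the encoder side
)
commit_loss.backward()

assert quantized.grad is not None  # first term reaches the codebook side
assert x.grad is not None          # second term reaches the encoder side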

File tree

2 files changed: +5 -8 lines changed

examples/autoencoder_sim_vq.py

1 addition & 1 deletion

@@ -16,7 +16,7 @@
 seed = 1234

 rotation_trick = True # rotation trick instead of straight-through
-use_mlp = True # use a one layer mlp with relu instead of linear
+use_mlp = True # use a one layer mlp with relu instead of linear

 device = "cuda" if torch.cuda.is_available() else "cpu"
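For context, the two flags above configure the toy autoencoder example. A hypothetical sketch of the kind of projection a use_mlp flag like this typically selects (names and sizes invented, not the example's actual wiring):

import torch.nn as nn

use_mlp = True
dim, dim_hidden = 32, 64  # invented sizes for illustration

# a one-hidden-layer MLP with ReLU versus a plain linear projection
if use_mlp:
    proj = nn.Sequential(nn.Linear(dim, dim_hidden), nn.ReLU(), nn.Linear(dim_hidden, dim))
else:
    proj = nn.Linear(dim, dim)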

vector_quantize_pytorch/sim_vq.py

4 additions & 7 deletions

@@ -118,18 +118,15 @@ def forward(

         # commit loss and straight through, as was done in the paper

-        commit_loss = F.mse_loss(x.detach(), quantized)
+        commit_loss = (
+            F.mse_loss(x.detach(), quantized) +
+            F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
+        )

         if self.rotation_trick:
             # rotation trick from @cfifty
             quantized = rotate_to(x, quantized)
         else:
-
-            commit_loss = (
-                commit_loss +
-                F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
-            )
-
             quantized = (quantized - x).detach() + x

         quantized = inverse_pack(quantized)
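On the non-rotation branch, (quantized - x).detach() + x is the straight-through estimator: the forward value equals quantized, while gradients flow to x as if quantization were the identity. A minimal sketch with stand-in tensors (not the library's forward pass):

import torch

x = torch.randn(4, 8, requires_grad=True)  # stand-in encoder output
quantized = torch.randn(4, 8)              # stand-in nearest-code lookup (non-differentiable)

out = (quantized - x).detach() + x
assert torch.allclose(out, quantized)  # forward value is the quantized tensor

out.sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))  # gradient passes straight through to x

With rotation_trick enabled, rotate_to transports gradients through a detached rotation instead, and the straight-through line is skipped. That is why the input-to-quantize loss term previously lived only in the else branch, and why this commit hoists it into the shared commit_loss so both paths train the encoder.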
