Skip to content

Commit 01b45eb

Browse files
committed
fix commit loss
1 parent e11a966 commit 01b45eb

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

vector_quantize_pytorch/sim_vq.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ def forward(
7070

7171
# commit loss
7272

73-
commit_loss = (F.pairwise_distance(x, quantized) ** 2).mean()
73+
commit_loss = (
74+
0.25 * F.mse_loss(x, quantized.detach()) +
75+
F.mse_loss(x.detach(), quantized)
76+
)
7477

7578
quantized = (quantized - x).detach() + x
7679

0 commit comments

Comments
 (0)