Skip to content

Commit 826c620

Browse files
committed
fix residual VQ in eval mode
1 parent 2352c87 commit 826c620

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '0.3.8',
6+
version = '0.3.9',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
author = 'Phil Wang',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,11 +288,11 @@ def forward(self, x):
288288

289289
quantize, embed_ind = self._codebook(x)
290290

291-
commit_loss = 0.
292-
293291
if self.training:
294292
commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
295293
quantize = x + (quantize - x).detach()
294+
else:
295+
commit_loss = torch.tensor([0.], device = x.device)
296296

297297
quantize = self.project_out(quantize)
298298

0 commit comments

Comments
 (0)