
Commit a868492

make sure the codebook at least receives gradients from the sync update

1 parent ae5db01 · commit a868492

File tree

2 files changed: +5 additions, −3 deletions

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.6.9',
+  version = '1.6.10',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 4 additions & 2 deletions

@@ -798,11 +798,13 @@ def forward(
         quantize, embed_ind, distances = self._codebook(x, sample_codebook_temp = sample_codebook_temp)
 
         if self.training:
-            quantize = x + (quantize - x).detach()
+            maybe_sync_update = 0.
 
             if self.sync_update_v > 0.:
                 # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
-                quantize = quantize + self.sync_update_v * (quantize - quantize.detach())
+                maybe_sync_update = self.sync_update_v * (quantize - quantize.detach())
+                print(maybe_sync_update)
+            quantize = x + (quantize - x).detach() + maybe_sync_update
 
         # function for calculating cross entropy loss to distance matrix
         # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
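For context on what the change does: the straight-through term x + (quantize - x).detach() copies gradients from the output onto the encoder output x, so the codebook itself receives nothing from that line. The sync update term, eq. (21) of the vqtorch draft linked above, is exactly zero in the forward pass but routes sync_update_v times the incoming gradient back to the codebook in the backward pass. Below is a minimal, self-contained sketch of that gradient flow; the function name straight_through_with_sync and the toy tensors are hypothetical illustrations, not the library's API.

import torch

# hypothetical sketch (not the library's API) of the gradient routing
# this commit sets up in the forward pass

def straight_through_with_sync(x, quantize, sync_update_v = 0.25):
    # without the sync update, the codebook receives no gradient here
    maybe_sync_update = 0.

    if sync_update_v > 0.:
        # zero in the forward pass, but carries sync_update_v * dL/d(quantize)
        # to the codebook in the backward pass, per eq. (21) of
        # https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
        maybe_sync_update = sync_update_v * (quantize - quantize.detach())

    # forward value is exactly quantize; backward sends gradients to x
    # (straight-through) and, via maybe_sync_update, to the codebook
    return x + (quantize - x).detach() + maybe_sync_update

# toy check that both tensors now receive gradients
x = torch.randn(4, 8, requires_grad = True)       # stand-in encoder output
codes = torch.randn(4, 8, requires_grad = True)   # stand-in quantized codes

straight_through_with_sync(x, codes).sum().backward()
print(x.grad.abs().sum() > 0, codes.grad.abs().sum() > 0)  # tensor(True) tensor(True)

Running the sketch prints tensor(True) tensor(True): with sync_update_v > 0 the codebook stand-in receives a gradient (scaled by sync_update_v) even when no auxiliary codebook loss is active, which is what the commit message means by the codebook at least receiving gradients from the sync update.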
