
Commit 8a5a159

do not use detached quantized tensor for commit loss, to support learnable codebooks correctly

1 parent 4a44819 commit 8a5a159
2 files changed: 5 additions & 5 deletions


setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.6.12',
+  version = '1.6.14',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 4 additions & 4 deletions

@@ -807,6 +807,8 @@ def forward(
         quantize, embed_ind, distances = self._codebook(x, sample_codebook_temp = sample_codebook_temp)

         if self.training:
+            orig_quantize = quantize
+
             quantize = x + (quantize - x).detach()

             if self.sync_update_v > 0.:

@@ -866,18 +868,16 @@ def calculate_ce_loss(codes):

                    commit_loss = calculate_ce_loss(embed_ind)
                else:
-                   detached_quantize = quantize.detach()
-
                    if exists(mask):
                        # with variable lengthed sequences
-                       commit_loss = F.mse_loss(detached_quantize, x, reduction = 'none')
+                       commit_loss = F.mse_loss(orig_quantize, x, reduction = 'none')

                        if is_multiheaded:
                            mask = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])

                        commit_loss = commit_loss[mask].mean()
                    else:
-                       commit_loss = F.mse_loss(detached_quantize, x)
+                       commit_loss = F.mse_loss(orig_quantize, x)

                loss = loss + commit_loss * self.commitment_weight
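Why this matters is easiest to see in isolation. The sketch below is not the library's code; the codebook, the nearest-neighbour lookup, and the tensor shapes are illustrative stand-ins. The point it demonstrates: the straight-through assignment `quantize = x + (quantize - x).detach()` removes the gradient path back to the codebook, so an MSE commitment loss computed against `quantize.detach()` can never update a learnable (non-EMA) codebook, whereas computing it against the original `_codebook` output (the `orig_quantize` saved by the first hunk) can.

import torch
import torch.nn.functional as F

# Hypothetical standalone setup; `codebook` and the lookup below stand in for the
# library's internal `_codebook` call and are not its actual API.
codebook = torch.nn.Parameter(torch.randn(512, 64))   # learnable codes
x = torch.randn(8, 64)                                 # encoder output for 8 vectors

dist = torch.cdist(x, codebook)                        # (8, 512) distances to every code
embed_ind = dist.argmin(dim = -1)                      # index of the nearest code per vector
quantize = codebook[embed_ind]                         # quantized vectors, differentiable w.r.t. the codebook

orig_quantize = quantize                               # keep a handle before the straight-through trick
quantize = x + (quantize - x).detach()                 # straight-through estimator for the decoder path

# Before this commit: detaching cuts every gradient path into the codebook,
# so only the encoder side of the commitment term could receive a gradient.
commit_loss_detached = F.mse_loss(quantize.detach(), x)

# After this commit: the same MSE is taken against the pre-detach output, so a
# learnable (non-EMA) codebook is also pulled toward the encoder outputs.
commit_loss = F.mse_loss(orig_quantize, x)
commit_loss.backward()

print(codebook.grad is not None)                       # True: the selected codes now receive gradients

In the actual forward pass this corresponds exactly to the two hunks above: `orig_quantize` is saved just before the straight-through assignment, then reused in place of `detached_quantize` in both MSE branches.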
