Commit f259324

detach the quantize commit target if not using learnable codebook

1 parent: 1de96c0

File tree: 2 files changed, 13 additions, 4 deletions
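What the change does: when the codebook is not learnable (i.e. it is updated by exponential moving average rather than by gradients), the target of the commitment loss is now detached, so the loss only pulls the encoder output toward the codes; with a learnable codebook the target stays attached and the codebook also receives gradient from the commitment term. A minimal standalone sketch of the pattern, assuming x is the encoder output and quantize the looked-up codes (the commitment_loss helper is hypothetical, not part of the library):

import torch.nn.functional as F

def commitment_loss(x, quantize, learnable_codebook):
    # EMA-updated codebook: detach the target so only the encoder is
    # trained by this term; learnable codebook: keep the gradient path
    # into the codes open
    commit_quantize = quantize if learnable_codebook else quantize.detach()
    return F.mse_loss(commit_quantize, x)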

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.6.16',
+  version = '1.6.17',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 12 additions & 3 deletions

@@ -679,6 +679,8 @@ def __init__(
         self.commitment_weight = commitment_weight
         self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss

+        self.learnable_codebook = learnable_codebook
+
         has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
         self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
         self.orthogonal_reg_weight = orthogonal_reg_weight
@@ -806,7 +808,14 @@ def forward(
         quantize, embed_ind, distances = self._codebook(x, sample_codebook_temp = sample_codebook_temp)

         if self.training:
-            orig_quantize = quantize
+            # determine code to use for commitment loss
+
+            if not self.learnable_codebook:
+                commit_quantize = quantize.detach()
+            else:
+                commit_quantize = quantize
+
+            # straight through

             quantize = x + (quantize - x).detach()

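Ordering matters in this hunk: commit_quantize must be captured before the straight-through line, because after quantize = x + (quantize - x).detach() the name quantize is numerically equal to the codes but its gradient flows entirely through x. A short self-contained sketch of the straight-through estimator (round() stands in for the nearest-code lookup; all names are illustrative):

import torch

x = torch.randn(4, 8, requires_grad = True)  # stand-in for encoder output
quantize = x.round()                         # stand-in for code lookup

# forward value is the quantized tensor, but the detached residual
# means the backward pass sends gradients straight to x
quantize_st = x + (quantize - x).detach()

quantize_st.sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))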
@@ -869,14 +878,14 @@ def calculate_ce_loss(codes):
                 else:
                     if exists(mask):
                         # with variable lengthed sequences
-                        commit_loss = F.mse_loss(orig_quantize, x, reduction = 'none')
+                        commit_loss = F.mse_loss(commit_quantize, x, reduction = 'none')

                         if is_multiheaded:
                             mask = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])

                         commit_loss = commit_loss[mask].mean()
                     else:
-                        commit_loss = F.mse_loss(orig_quantize, x)
+                        commit_loss = F.mse_loss(commit_quantize, x)

                 loss = loss + commit_loss * self.commitment_weight

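From the caller's side the API is unchanged; with the default EMA codebook this commit simply makes the returned loss the classic VQ-VAE commitment loss against a detached target. A usage sketch (dimensions and hyperparameters are illustrative):

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    commitment_weight = 1.  # scales the commit loss shown in the diff above
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)  # commit_loss now backprops into the encoder only

Passing learnable_codebook = True at construction (the flag stored in __init__ above) keeps the attached target, so the same term also updates the codebook.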