
Commit 8417aec

should be using l2norm of input for cosine sim
1 parent d2ec399

File tree

2 files changed (+8, -4 lines)


setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.5.4',
+  version = '1.5.5',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',
```

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -312,6 +312,9 @@ def forward(self, x):
 
         quantize = batched_embedding(embed_ind, self.embed)
 
+        if self.training:
+            quantize = x + (quantize - x).detach()
+
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
 
@@ -456,6 +459,10 @@ def forward(self, x):
 
         quantize = batched_embedding(embed_ind, self.embed)
 
+        if self.training:
+            l2norm_x = l2norm(x)
+            quantize = l2norm_x + (quantize - l2norm_x).detach()
+
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
 
@@ -587,9 +594,6 @@ def forward(
 
         quantize, embed_ind, distances = self._codebook(x)
 
-        if self.training:
-            quantize = x + (quantize - x).detach()
-
         if return_loss:
             if not is_multiheaded:
                 dist_einops_eq = '1 b n l -> b l n'
```
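
In effect, the commit moves the straight-through estimator from the top-level forward (last hunk) down into each codebook's forward, so that the cosine-similarity codebook can take the residual against the l2-normalized input rather than the raw input: its codes are unit vectors, so `quantize - x` is only meaningful in the normalized space. Below is a minimal sketch of the trick, with hypothetical shapes and a random tensor standing in for the actual nearest-neighbor codebook lookup:

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    # unit-normalize along the last dimension, as the cosine-sim codebook does
    return F.normalize(t, p = 2, dim = -1)

def straight_through(x, quantize):
    # forward pass yields `quantize`; backward pass copies gradients onto `x`
    # unchanged, as if quantization were the identity function
    return x + (quantize - x).detach()

# hypothetical shapes: batch 2, sequence length 4, feature dim 8
x = torch.randn(2, 4, 8, requires_grad = True)
codes = l2norm(torch.randn(2, 4, 8))  # stand-in for the looked-up codebook vectors

# with cosine similarity the codes live on the unit sphere, so the
# straight-through residual must be taken against the normalized input
x_normed = l2norm(x)
quantize = straight_through(x_normed, codes)

quantize.sum().backward()
assert x.grad is not None  # gradients reach the raw input through l2norm
```

The forward value is exactly the quantized output, while the backward pass routes gradients straight to the encoder output, which is what lets the encoder train through the non-differentiable code lookup.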
