Skip to content

Commit 0b2fe93

Browse files
authored
Merge pull request #53 from kashif/patch-1
Fix learnable_codebook output
2 parents 6fff51a + f6ab8df commit 0b2fe93

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def forward(
405405
if self.affine_param:
406406
self.update_affine(flatten, self.embed)
407407

408-
embed = self.embed if not self.learnable_codebook else self.embed.detach()
408+
embed = self.embed if self.learnable_codebook else self.embed.detach()
409409

410410
if self.affine_param:
411411
codebook_std = self.codebook_variance.clamp(min = 1e-5).sqrt()
@@ -572,7 +572,7 @@ def forward(
572572

573573
self.init_embed_(flatten)
574574

575-
embed = self.embed if not self.learnable_codebook else self.embed.detach()
575+
embed = self.embed if self.learnable_codebook else self.embed.detach()
576576

577577
dist = einsum('h n d, h c d -> h n c', flatten, embed)
578578

0 commit comments

Comments
 (0)