Commit e616784

Fix: Corrected embed initialization in EuclideanCodebook forward pass
1 parent f20e474 commit e616784

File tree

1 file changed: +3 −5 lines changed


vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -532,17 +532,15 @@ def forward(
         if self.affine_param:
             self.update_affine(flatten, self.embed, mask = mask)
 
-        # affine params
+        # get maybe learnable codes
+        embed = self.embed if self.learnable_codebook else self.embed.detach()
 
+        # affine params
         if self.affine_param:
             codebook_std = self.codebook_variance.clamp(min = 1e-5).sqrt()
             batch_std = self.batch_variance.clamp(min = 1e-5).sqrt()
             embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
 
-        # get maybe learnable codes
-
-        embed = self.embed if self.learnable_codebook else self.embed.detach()
-
         # handle maybe implicit neural codebook
         # and calculate distance
```
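The reason for the reordering: before this commit, `embed` was reassigned from `self.embed` *after* the affine parameters had been applied, which silently discarded the affine transform. A minimal sketch of the difference, using illustrative scalar stand-ins for the module's buffers (the shapes and statistics below are assumptions, not the real `EuclideanCodebook` state):

```python
import torch

# Illustrative stand-ins, not the module's real buffers
codebook = torch.ones(4, 2, requires_grad=True)  # stand-in for self.embed
codebook_mean, batch_mean = 0.5, 2.0             # stand-in affine statistics
codebook_std, batch_std = 1.0, 2.0

def affine(embed):
    # same expression as in the diff, with scalar stand-ins for the buffers
    return (embed - codebook_mean) * (batch_std / codebook_std) + batch_mean

def buggy_order():
    embed = affine(codebook)   # affine params applied first...
    embed = codebook.detach()  # ...then overwritten: affine transform lost
    return embed

def fixed_order():
    embed = codebook.detach()  # get maybe learnable codes first
    return affine(embed)       # then apply affine params
```

With the fixed ordering, the detach (or the learnable codebook lookup) happens first, so the affine rescaling actually reaches the distance calculation that follows.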
