Skip to content

Commit e0e073d

Browse files
committed
fix orthogonal regularization, addressing #43
1 parent b449efc commit e0e073d

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.1.5',
+  version = '1.1.6',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -592,15 +592,18 @@ def forward(
         if self.orthogonal_reg_weight > 0:
             codebook = self._codebook.embed

+            # only calculate orthogonal loss for the activated codes for this batch
             if self.orthogonal_reg_active_codes_only:
-                # only calculate orthogonal loss for the activated codes for this batch
+                assert not (is_multiheaded and self.separate_codebook_per_head), 'orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet'
                 unique_code_ids = torch.unique(embed_ind)
-                codebook = codebook[unique_code_ids]
+                codebook = codebook[:, unique_code_ids]
+
+            num_codes = codebook.shape[-2]

-            num_codes = codebook.shape[0]
             if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
                 rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes]
-                codebook = codebook[rand_ids]
+                codebook = codebook[:, rand_ids]

             orthogonal_reg_loss = orthogonal_loss_fn(codebook)
             loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

0 commit comments

Comments (0)