
Commit ae5db01

reintroduce orthogonal regularization, due to bug that @kashif found
1 parent d23d1fd commit ae5db01

3 files changed: 72 additions, 2 deletions

README.md

Lines changed: 33 additions & 0 deletions
@@ -182,6 +182,28 @@ x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x)
 ```
 
+### Orthogonal regularization loss
+
+VQ-VAE / VQ-GAN is quickly gaining popularity. A <a href="https://arxiv.org/abs/2112.00384">recent paper</a> proposes that when using vector quantization on images, enforcing the codebook to be orthogonal makes the discretized codes translation-equivariant, which in turn yields large improvements in downstream text-to-image generation tasks.
+
+You can use this feature by simply setting `orthogonal_reg_weight` to a value greater than `0`, in which case the orthogonal regularization will be added to the auxiliary loss returned by the module.
+
+```python
+import torch
+from vector_quantize_pytorch import VectorQuantize
+vq = VectorQuantize(
+    dim = 256,
+    codebook_size = 256,
+    accept_image_fmap = True,                   # set this to True to be able to pass in an image feature map
+    orthogonal_reg_weight = 10,                 # in the paper, a value of 10 is recommended
+    orthogonal_reg_max_codes = 128,             # this would randomly sample from the codebook for the orthogonal regularization loss, to limit memory usage
+    orthogonal_reg_active_codes_only = False    # set this to True if you have a very large codebook, and would only like to enforce the loss on the codes activated in each batch
+)
+img_fmap = torch.randn(1, 256, 32, 32)
+quantized, indices, loss = vq(img_fmap)         # (1, 256, 32, 32), (1, 32, 32), (1,)
+# loss now contains the orthogonal regularization loss with the weight as assigned
+```
+
 ### Multi-headed VQ
 
 There have been a number of papers proposing variants of discrete latent representations with a multi-headed approach (multiple codes per feature). I have decided to offer one variant where the same codebook is used to vector quantize across the input dimension `head` times.
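As an aside, the regularizer this commit wires up is read off the `orthogonal_loss_fn` added further down in this diff (its comment ties it to equation (2) of the linked paper): it penalizes the squared cosine similarity between every pair of L2-normalized codes. A sketch in math form, where h is the number of codebooks and n the number of codes per codebook:

```latex
% orthogonal regularization, as read off orthogonal_loss_fn below
% \tilde{c}^{(k)}_i is the i-th L2-normalized code of codebook k
\mathcal{L}_{\mathrm{ortho}}
  = \frac{1}{h \, n^{2}} \sum_{k=1}^{h} \sum_{i=1}^{n} \sum_{j=1}^{n}
    \left\langle \tilde{c}^{(k)}_{i}, \, \tilde{c}^{(k)}_{j} \right\rangle^{2}
  \; - \; \frac{1}{n}
```

The diagonal terms contribute exactly the 1/n that is subtracted, so the loss bottoms out at zero precisely when the codes are mutually orthogonal.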
@@ -399,3 +421,14 @@ if __name__ == '__main__':
     url = {https://openreview.net/forum?id=oapKSVM2bcj}
 }
 ```
+
+```bibtex
+@misc{shin2021translationequivariant,
+    title = {Translation-equivariant Image Quantizer for Bi-directional Image-Text Generation},
+    author = {Woncheol Shin and Gyubok Lee and Jiyoung Lee and Joonseok Lee and Edward Choi},
+    year = {2021},
+    eprint = {2112.00384},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.6.7',
+  version = '1.6.9',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 38 additions & 1 deletion
@@ -211,6 +211,15 @@ def batched_embedding(indices, embeds):
     embeds = repeat(embeds, 'h c d -> h b c d', b = batch)
     return embeds.gather(2, indices)
 
+# regularization losses
+
+def orthogonal_loss_fn(t):
+    # eq (2) from https://arxiv.org/abs/2112.00384
+    h, n = t.shape[:2]
+    normed_codes = l2norm(t)
+    cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes)
+    return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n)
+
 # distance types
 
 class EuclideanCodebook(nn.Module):
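A quick sanity check of the new `orthogonal_loss_fn`, as a standalone sketch separate from the commit (it re-implements the same formula with plain `torch` and `torch.nn.functional` calls, since the library's `l2norm` and `einsum` helpers are not imported here): the penalty is exactly zero for a perfectly orthogonal codebook and positive otherwise.

```python
import torch
import torch.nn.functional as F

def orthogonal_loss_fn(t):
    # same formula as above: mean squared cosine similarity minus the 1/n diagonal contribution
    h, n = t.shape[:2]
    normed_codes = F.normalize(t, dim = -1)
    cosine_sim = torch.einsum('hid,hjd->hij', normed_codes, normed_codes)
    return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n)

orthogonal_codebook = torch.eye(8).unsqueeze(0)   # (1, 8, 8): 8 mutually orthogonal codes
random_codebook = torch.randn(1, 8, 8)            # random codes are generally correlated

print(orthogonal_loss_fn(orthogonal_codebook))    # tensor(0.)
print(orthogonal_loss_fn(random_codebook))        # some positive value
```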
@@ -630,6 +639,9 @@ def __init__(
         accept_image_fmap = False,
         commitment_weight = 1.,
         commitment_use_cross_entropy_loss = False,
+        orthogonal_reg_weight = 0.,
+        orthogonal_reg_active_codes_only = False,
+        orthogonal_reg_max_codes = None,
         stochastic_sample_codes = False,
         sample_codebook_temp = 1.,
         straight_through = False,
@@ -659,6 +671,12 @@ def __init__(
         self.commitment_weight = commitment_weight
         self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss
 
+        has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
+        self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
+        self.orthogonal_reg_weight = orthogonal_reg_weight
+        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+
         assert not (ema_update and learnable_codebook), 'learnable codebook not compatible with EMA update'
 
         assert 0 <= sync_update_v <= 1.
@@ -686,7 +704,7 @@ def __init__(
             eps = eps,
             threshold_ema_dead_code = threshold_ema_dead_code,
             use_ddp = sync_codebook,
-            learnable_codebook = learnable_codebook,
+            learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
             sample_codebook_temp = sample_codebook_temp,
             gumbel_sample = gumbel_sample_fn,
             ema_update = ema_update
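The one-line change to `learnable_codebook` matters because a gradient-based penalty can only shape the codebook if the codebook participates in autograd; a pure EMA-updated codebook would never be moved by the orthogonal loss. A simplified, standalone illustration (a toy penalty and plain tensors, not the library's actual classes):

```python
import torch
import torch.nn.functional as F

# stand-in penalty: mean squared cosine similarity between codes (same spirit as orthogonal_loss_fn)
def toy_orthogonal_penalty(codebook):
    normed = F.normalize(codebook, dim = -1)
    return (normed @ normed.t()).pow(2).mean()

frozen_codebook = torch.randn(8, 4)                           # like an EMA-updated buffer: no gradients
learnable_codebook = torch.randn(8, 4, requires_grad = True)  # like a learnable parameter

toy_orthogonal_penalty(learnable_codebook).backward()
print(learnable_codebook.grad.abs().sum() > 0)   # tensor(True): the penalty can push these codes apart
print(frozen_codebook.grad)                      # None: the same penalty could never update this buffer
```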
@@ -854,6 +872,25 @@ def calculate_ce_loss(codes):
 
             loss = loss + commit_loss * self.commitment_weight
 
+            if self.has_codebook_orthogonal_loss:
+                codebook = self._codebook.embed
+
+                # only calculate orthogonal loss for the activated codes for this batch
+
+                if self.orthogonal_reg_active_codes_only:
+                    assert not (is_multiheaded and self.separate_codebook_per_head), 'orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet'
+                    unique_code_ids = torch.unique(embed_ind)
+                    codebook = codebook[:, unique_code_ids]
+
+                num_codes = codebook.shape[-2]
+
+                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
+                    rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes]
+                    codebook = codebook[:, rand_ids]
+
+                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
+                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
+
         # handle multi-headed quantized embeddings
 
         if is_multiheaded:
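Finally, a standalone sketch of the two guards in the block above: `orthogonal_reg_active_codes_only` restricts the penalty to codes actually selected in the current batch, and `orthogonal_reg_max_codes` caps the count by random subsampling, which keeps the pairwise cosine-similarity matrix small for large codebooks (the library's `exists` helper is replaced by a plain `is not None` check here).

```python
import torch

codebook = torch.randn(1, 1024, 256)        # (codebooks, num_codes, dim)
embed_ind = torch.randint(0, 1024, (64,))   # code indices used by the current batch
orthogonal_reg_active_codes_only = True
orthogonal_reg_max_codes = 128

if orthogonal_reg_active_codes_only:
    # keep only the codes that were actually selected in this batch
    unique_code_ids = torch.unique(embed_ind)
    codebook = codebook[:, unique_code_ids]

num_codes = codebook.shape[-2]

if orthogonal_reg_max_codes is not None and num_codes > orthogonal_reg_max_codes:
    # cap the number of codes by random subsampling
    rand_ids = torch.randperm(num_codes)[:orthogonal_reg_max_codes]
    codebook = codebook[:, rand_ids]

print(codebook.shape)   # at most (1, 128, 256); the penalty is then computed on this subset only
```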
