x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```
### Orthogonal regularization loss

VQ-VAE / VQ-GAN is quickly gaining popularity. A <a href="https://arxiv.org/abs/2112.00384">recent paper</a> proposes that when using vector quantization on images, enforcing the codebook to be orthogonal makes the discretized codes translation equivariant, leading to large improvements in downstream text-to-image generation tasks.

You can use this feature by simply setting `orthogonal_reg_weight` to a value greater than `0`, in which case the orthogonal regularization loss will be added to the auxiliary loss output by the module.
```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    accept_image_fmap = True,                   # set this true to be able to pass in an image feature map
    orthogonal_reg_weight = 10,                 # in paper, they recommended a value of 10
    orthogonal_reg_max_codes = 128,             # this would randomly sample from the codebook for the orthogonal regularization loss, for limiting memory usage
    orthogonal_reg_active_codes_only = False    # set this to True if you have a very large codebook, and would only like to enforce the loss on the activated codes per batch
)

img_fmap = torch.randn(1, 256, 32, 32)          # image feature map of shape (batch, dim, height, width)

quantized, indices, loss = vq(img_fmap)

# loss now contains the orthogonal regularization loss with the weight as assigned
```
### Multi-headed VQ
A number of papers have proposed variants of discrete latent representations with a multi-headed approach (multiple codes per feature). I have decided to offer one variant where the same codebook is used to vector quantize across the input dimension `head` times. A configuration sketch follows below.
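
Below is a minimal sketch of how this variant might be configured. It assumes the constructor exposes a `heads` argument alongside the `separate_codebook_per_head` flag referenced in the implementation excerpt further down; check the actual signature before relying on it.

```python
import torch
from vector_quantize_pytorch import VectorQuantize

# hypothetical configuration: `heads` is assumed to be the constructor argument
# controlling how many codes are drawn per feature
vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    heads = 8,                           # quantize each feature with 8 codes
    separate_codebook_per_head = False   # False = share one codebook across all heads, as described above
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)  # indices would then hold one code id per head
```

With `separate_codebook_per_head = True`, each head would presumably get its own codebook, which is the combination the implementation guards against when restricting orthogonal regularization to active codes (see below).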
Note that restricting the orthogonal regularization loss to only the activated codes is not yet compatible with multi-headed VQ using separate codebooks per head; the implementation guards against this combination:

```python
# only calculate orthogonal loss for the activated codes for this batch

if self.orthogonal_reg_active_codes_only:
    assert not (is_multiheaded and self.separate_codebook_per_head), 'orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet'
```