
Commit d3dd5e5

Merge pull request #164 from cfifty/master
Implement the rotation trick.
2 parents 5fe30db + a2babde commit d3dd5e5

3 files changed: +58 -7 lines changed

README.md

Lines changed: 25 additions & 0 deletions
@@ -122,6 +122,20 @@ quantized, indices, commit_loss = residual_vq(x)
 # (1, 1024, 256), (1, 1024, 4), (1, 4)
 ```

+## Gradient Computation
+
+VQ-VAEs are traditionally trained with the straight-through estimator (STE): during the backward pass, the gradient flows _around_ the VQ layer rather than _through_ it. The <a href="https://arxiv.org/abs/2410.06424">rotation trick paper</a> proposes transforming the gradient _through_ the VQ layer so that the relative angle and magnitude between the input vector and the quantized output are encoded into the gradient. You can enable or disable this feature with `rotation_trick = True / False` in the `VectorQuantize` class.
+
+```python
+from vector_quantize_pytorch import VectorQuantize
+
+vq_layer = VectorQuantize(
+    dim = 256,
+    codebook_size = 256,
+    rotation_trick = True,  # Set to False to use the STE gradient estimator or True to use the rotation trick.
+)
+```
+
 ## Increasing codebook usage

 This repository will contain a few techniques from various papers to combat "dead" codebook entries, which is a common problem when using vector quantizers.
@@ -699,3 +713,14 @@ assert loss.item() >= 0
     url = {https://api.semanticscholar.org/CorpusID:267301189}
 }
 ```
+
+```bibtex
+@article{Fifty2024Restructuring,
+    title   = {Restructuring Vector Quantization with the Rotation Trick},
+    author  = {Christopher Fifty and Ronald G. Junkins and Dennis Duan and Aniketh Iyengar and Jerry W. Liu and Ehsan Amid and Sebastian Thrun and Christopher Ré},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2410.06424},
+    url     = {https://api.semanticscholar.org/CorpusID:273229218}
+}
+```
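
As a usage sketch (not part of this commit), the new option can be exercised end to end roughly as follows; the input shape and the three return values mirror the existing README examples, and the tensor names are illustrative only:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq_layer = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    rotation_trick = True   # rotate gradients through the VQ layer instead of copying them around it
)

x = torch.randn(1, 1024, 256, requires_grad = True)

quantized, indices, commit_loss = vq_layer(x)

# the rotation trick only changes how gradients are propagated;
# quantized keeps the input shape (1, 1024, 256)
quantized.sum().backward()
assert x.grad is not None
```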

tests/test_readme.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def test_residual_vq(

     quantized, indices, commit_loss = residual_vq(x, freeze_codebook = train and not implicit_neural_codebook)
     quantized_out = residual_vq.get_output_from_indices(indices)
-    assert torch.allclose(quantized, quantized_out, atol = 1e-6)
+    assert torch.allclose(quantized, quantized_out, atol = 1e-5)

 def test_residual_vq2():
     from vector_quantize_pytorch import ResidualVQ

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 32 additions & 6 deletions
@@ -811,6 +811,7 @@ def __init__(
         stochastic_sample_codes = False,
         sample_codebook_temp = 1.,
         straight_through = False,
+        rotation_trick = True, # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424.
         reinmax = False, # using reinmax for improved straight-through, assuming straight through helps at all
         sync_codebook = None,
         sync_affine_param = False,
@@ -821,7 +822,7 @@ def __init__(
         manual_in_place_optimizer_update = False,
         affine_param = False,
         affine_param_batch_decay = 0.99,
-        affine_param_codebook_decay = 0.9,
+        affine_param_codebook_decay = 0.9,
         sync_update_v = 0., # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
         return_zeros_for_masked_padding = True
     ):
@@ -863,6 +864,9 @@ def __init__(
         self.codebook_diversity_temperature = codebook_diversity_temperature
         self.codebook_diversity_loss_weight = codebook_diversity_loss_weight

+        assert not (straight_through and rotation_trick)
+        self.rotation_trick = rotation_trick
+
         assert not (ema_update and learnable_codebook), 'learnable codebook not compatible with EMA update'

         assert 0 <= sync_update_v <= 1.
@@ -942,6 +946,13 @@ def codebook(self, codes):

         self._codebook.embed.copy_(codes)

+    @staticmethod
+    def rotation_trick_transform(u, q, e):
+        w = ((u + q) / torch.norm(u + q, dim=1, keepdim=True)).detach()
+        e = e - 2 * torch.bmm(torch.bmm(e, w.unsqueeze(-1)), w.unsqueeze(1)) + 2 * torch.bmm(
+            torch.bmm(e, u.unsqueeze(-1).detach()), q.unsqueeze(1).detach())
+        return e
+
     def get_codes_from_indices(self, indices):
         codebook = self.codebook
         is_multiheaded = codebook.ndim > 2
@@ -1090,11 +1101,26 @@ def forward(
         # determine code to use for commitment loss
         maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity

-        commit_quantize = maybe_detach(quantize)
-
-        # straight through
-
-        quantize = x + (quantize - x).detach()
+        commit_quantize = maybe_detach(quantize)
+
+        # Use the rotation trick (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
+        if self.rotation_trick:
+            init_shape = x.shape
+            x = x.reshape(-1, init_shape[-1])
+            quantize = quantize.reshape(-1, init_shape[-1])
+
+            eps = 1e-6 # For numerical stability if any vector is close to 0 norm.
+            rot_quantize = self.rotation_trick_transform(
+                x / (torch.norm(x, dim=1, keepdim=True) + eps),
+                quantize / (torch.norm(quantize, dim=1, keepdim=True) + eps),
+                x.unsqueeze(1)).squeeze()
+            quantize = rot_quantize * (torch.norm(quantize, dim=1, keepdim=True)
+                                       / (torch.norm(x, dim=1, keepdim=True) + 1e-6)).detach()
+
+            x = x.reshape(init_shape)
+            quantize = quantize.reshape(init_shape)
+        else: # Use STE to get gradients through VQ layer.
+            quantize = x + (quantize - x).detach()

         if self.sync_update_v > 0.:
             # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
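
A standalone sketch (assumed shapes and stand-in tensors, not part of the commit) of what `rotation_trick_transform` computes: with `w = (u + q) / ||u + q||` held constant via `detach`, the update `e - 2(e·w)w + 2(e·u)q` rotates the normalized input direction `u` exactly onto the normalized code direction `q`, so after rescaling by the detached norm ratio the forward output matches the quantized vector while the backward pass carries gradients through the input:

```python
import torch

def rotation_trick_transform(u, q, e):
    # mirrors the new static method: w, u, q enter as detached constants,
    # so gradients reach e only through the two batched matrix products
    w = ((u + q) / torch.norm(u + q, dim = 1, keepdim = True)).detach()
    reflect = e - 2 * torch.bmm(torch.bmm(e, w.unsqueeze(-1)), w.unsqueeze(1))
    return reflect + 2 * torch.bmm(torch.bmm(e, u.unsqueeze(-1).detach()), q.unsqueeze(1).detach())

b, d = 4, 8                                   # illustrative batch and feature sizes
x = torch.randn(b, d, requires_grad = True)   # stand-in for flattened encoder outputs
codes = torch.randn(b, d)                     # stand-in for the selected codebook vectors

u = x / torch.norm(x, dim = 1, keepdim = True)
q = codes / torch.norm(codes, dim = 1, keepdim = True)

out = rotation_trick_transform(u, q, x.unsqueeze(1)).squeeze(1)
out = out * (torch.norm(codes, dim = 1, keepdim = True) / torch.norm(x, dim = 1, keepdim = True)).detach()

# forward pass reproduces the quantized vectors up to numerical precision ...
assert torch.allclose(out, codes, atol = 1e-5)

# ... while the backward pass flows through x rather than bypassing the layer
out.sum().backward()
assert x.grad is not None
```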
