Skip to content

Commit a4bef4d

Browse files
committed
add compatibility for the residual VQ proposed in TIGER https://arxiv.org/abs/2305.05065, for building recommendation systems
1 parent 35a8a41 commit a4bef4d

File tree

4 files changed

+93
-13
lines changed

4 files changed

+93
-13
lines changed

README.md

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,29 @@ quantized, indices, commit_loss = residual_vq(x)
100100
# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
101101
```
102102

103+
104+
<a href="https://arxiv.org/abs/2305.05065">This paper</a> out of Google DeepMind claims that residual vector quantization can induce hierarchical semantic IDs for building a recommender system. In their scheme, they use an increasing number of codes across depth for it to work. This repository supports that scheme as follows:
105+
106+
```python
107+
import torch
108+
from vector_quantize_pytorch import ResidualVQ
109+
110+
residual_vq = ResidualVQ(
111+
dim = 2,
112+
codebook_size = (5, 128, 256), # from top most hierarchy to lowest, 5 codes, 128 codes, then 256 codes
113+
)
114+
115+
x = torch.randn(2, 2, 2)
116+
117+
residual_vq.train()
118+
119+
quantized, indices, commit_loss = residual_vq(x, freeze_codebook = True)
120+
121+
quantized_out = residual_vq.get_output_from_indices(indices)
122+
123+
assert torch.allclose(quantized, quantized_out, atol = 1e-5)
124+
```
125+
103126
## Initialization
104127

105128
The SoundStream paper proposes that the codebook should be initialized by the kmeans centroids of the first batch. You can easily turn on this feature with one flag `kmeans_init = True`, for either `VectorQuantize` or `ResidualVQ` class
@@ -713,4 +736,15 @@ assert loss.item() >= 0
713736
volume = {abs/2410.06424},
714737
url = {https://api.semanticscholar.org/CorpusID:273229218}
715738
}
716-
```
739+
```
740+
741+
```bibtex
742+
@article{Rajput2023RecommenderSW,
743+
title = {Recommender Systems with Generative Retrieval},
744+
author = {Shashank Rajput and Nikhil Mehta and Anima Singh and Raghunandan H. Keshavan and Trung Hieu Vu and Lukasz Heldt and Lichan Hong and Yi Tay and Vinh Q. Tran and Jonah Samost and Maciej Kula and Ed H. Chi and Maheswaran Sathiamoorthy},
745+
journal = {ArXiv},
746+
year = {2023},
747+
volume = {abs/2305.05065},
748+
url = {https://api.semanticscholar.org/CorpusID:258564854}
749+
}
750+
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.18.8"
3+
version = "1.19.0"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_readme.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,24 @@ def test_rq():
201201
x = torch.randn(1, 1024, 512)
202202
indices = quantizer(x)
203203

204+
def test_tiger():
    """Round-trip check for ResidualVQ with a different codebook size per layer.

    Quantizes a small random input, rebuilds the output from the returned
    indices, and asserts that both agree within tolerance.
    """
    from vector_quantize_pytorch import ResidualVQ

    # one codebook size per quantizer layer (scheme from https://arxiv.org/abs/2305.05065)
    codebook_sizes = (5, 128, 256)

    residual_vq = ResidualVQ(dim = 2, codebook_size = codebook_sizes)
    residual_vq.train()

    x = torch.randn(2, 2, 2)

    quantized, indices, commit_loss = residual_vq(x, freeze_codebook = True)

    # pass your indices into here, but the indices must come during .eval(),
    # as during training some of the indices are dropped out (-1)
    quantized_out = residual_vq.get_output_from_indices(indices)

    assert torch.allclose(quantized, quantized_out, atol = 1e-5)
221+
204222
def test_fsq():
205223
from vector_quantize_pytorch import FSQ
206224

vector_quantize_pytorch/residual_vq.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from __future__ import annotations
2-
from typing import List
32

43
import random
54
from math import ceil
@@ -28,6 +27,12 @@ def first(it):
2827
def default(val, d):
    """Return `val` when `exists(val)` is truthy, otherwise the fallback `d`."""
    # `exists` is a helper defined elsewhere in this module (not shown in this hunk)
    if exists(val):
        return val
    return d
3029

30+
def cast_tuple(t, length = 1):
    """Normalize `t` into a tuple.

    - a tuple is returned unchanged
    - a list is converted to a tuple element by element, so `[5, 128, 256]`
      behaves the same as `(5, 128, 256)` instead of being repeated as a
      single nested element
    - any other value is repeated `length` times
    """
    if isinstance(t, tuple):
        return t

    if isinstance(t, list):
        return tuple(t)

    return (t,) * length
32+
33+
def unique(arr):
    """Return the distinct elements of `arr` as a list, in first-seen order.

    Elements must be hashable. `dict.fromkeys` is used instead of a set so the
    result order is deterministic rather than depending on set iteration order.
    """
    return list(dict.fromkeys(arr))
35+
3136
def round_up_multiple(num, mult):
    """Round `num` up to the nearest multiple of `mult`.

    Integer inputs use exact integer ceiling division, so very large values
    are not distorted by the float conversion in `num / mult`. Any other
    numeric input keeps the float-based computation.

    Raises ZeroDivisionError when `mult` is 0.
    """
    if isinstance(num, int) and isinstance(mult, int):
        # -(-a // b) is the exact ceiling of a / b for integers
        return -(-num // mult) * mult

    return ceil(num / mult) * mult
3338

@@ -110,7 +115,8 @@ def __init__(
110115
self,
111116
*,
112117
dim,
113-
num_quantizers,
118+
num_quantizers: int | None = None,
119+
codebook_size: int | tuple[int, ...],
114120
codebook_dim = None,
115121
shared_codebook = False,
116122
heads = 1,
@@ -124,6 +130,8 @@ def __init__(
124130
):
125131
super().__init__()
126132
assert heads == 1, 'residual vq is not compatible with multi-headed codes'
133+
assert exists(num_quantizers) or isinstance(codebook_size, tuple)
134+
127135
codebook_dim = default(codebook_dim, dim)
128136
codebook_input_dim = codebook_dim * heads
129137

@@ -132,8 +140,6 @@ def __init__(
132140
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
133141
self.has_projections = requires_projection
134142

135-
self.num_quantizers = num_quantizers
136-
137143
self.accept_image_fmap = accept_image_fmap
138144

139145
self.implicit_neural_codebook = implicit_neural_codebook
@@ -150,7 +156,21 @@ def __init__(
150156
manual_in_place_optimizer_update = True
151157
)
152158

153-
self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for _ in range(num_quantizers)])
159+
# take care of maybe different codebook sizes across depth, used in TIGER paper https://arxiv.org/abs/2305.05065
160+
161+
codebook_sizes = cast_tuple(codebook_size, num_quantizers)
162+
163+
num_quantizers = len(codebook_sizes)
164+
self.num_quantizers = num_quantizers
165+
166+
assert len(codebook_sizes) == num_quantizers
167+
168+
self.codebook_sizes = codebook_sizes
169+
self.uniform_codebook_size = len(unique(codebook_sizes)) == 1
170+
171+
# define vq across layers
172+
173+
self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_size = layer_codebook_size, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for layer_codebook_size in codebook_sizes])
154174

155175
assert all([not vq.has_projections for vq in self.layers])
156176

@@ -167,6 +187,8 @@ def __init__(
167187

168188
if implicit_neural_codebook:
169189
self.mlps = ModuleList([MLP(dim = codebook_dim, l2norm_output = first(self.layers).use_cosine_sim, **mlp_kwargs) for _ in range(num_quantizers - 1)])
190+
else:
191+
self.mlps = (None,) * (num_quantizers - 1)
170192

171193
# sharing codebook logic
172194

@@ -175,6 +197,8 @@ def __init__(
175197
if not shared_codebook:
176198
return
177199

200+
assert self.uniform_codebook_size
201+
178202
first_vq, *rest_vq = self.layers
179203
codebook = first_vq._codebook
180204

@@ -192,8 +216,13 @@ def codebook_dim(self):
192216
@property
def codebooks(self):
    """Codebook embeddings of every quantizer layer.

    Returns a single tensor stacked along a new leading (layer) dimension when
    all layers share one codebook size, otherwise a tuple with one tensor per
    layer, since differently sized codebooks cannot be stacked.
    """
    # the '1 ... -> ...' pattern drops the leading singleton axis of each layer's embed
    per_layer = tuple(
        rearrange(layer._codebook.embed, '1 ... -> ...')
        for layer in self.layers
    )

    if self.uniform_codebook_size:
        return torch.stack(per_layer)

    return per_layer
198227

199228
def get_codes_from_indices(self, indices):
@@ -216,13 +245,12 @@ def get_codes_from_indices(self, indices):
216245
mask = indices == -1.
217246
indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later
218247

219-
if not self.implicit_neural_codebook:
220-
# gather all the codes
248+
if not self.implicit_neural_codebook and self.uniform_codebook_size:
221249

222250
all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)
223251

224252
else:
225-
# else if using implicit neural codebook, codes will need to be derived layer by layer
253+
# else if using implicit neural codebook, or non uniform codebook sizes, codes will need to be derived layer by layer
226254

227255
code_transform_mlps = (None, *self.mlps)
228256

@@ -261,7 +289,7 @@ def forward(
261289
self,
262290
x,
263291
mask = None,
264-
indices: Tensor | List[Tensor] | None = None,
292+
indices: Tensor | list[Tensor] | None = None,
265293
return_all_codes = False,
266294
sample_codebook_temp = None,
267295
freeze_codebook = False,

0 commit comments

Comments
 (0)