
Commit f47b7db

incorporate the implicit neural codebook into residual vq
1 parent d209796 commit f47b7db

File tree

4 files changed: +221 -31 lines changed


README.md

Lines changed: 11 additions & 0 deletions
@@ -688,3 +688,14 @@ assert loss.item() >= 0
     url = {https://api.semanticscholar.org/CorpusID:256901024}
 }
 ```
+
+```bibtex
+@article{Huijben2024ResidualQW,
+    title   = {Residual Quantization with Implicit Neural Codebooks},
+    author  = {Iris Huijben and Matthijs Douze and Matthew Muckley and Ruud van Sloun and Jakob Verbeek},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2401.14732},
+    url     = {https://api.semanticscholar.org/CorpusID:267301189}
+}
+```

tests/test_readme.py

Lines changed: 18 additions & 4 deletions
@@ -62,20 +62,35 @@ def test_vq_mask():
     assert (mask_quantized[:, 512:] == 0.).all()
     assert (mask_indices[:, 512:] == -1).all()
 
-def test_residual_vq():
+@pytest.mark.parametrize('implicit_neural_codebook', (True, False))
+@pytest.mark.parametrize('use_cosine_sim', (True, False))
+def test_residual_vq(
+    implicit_neural_codebook,
+    use_cosine_sim
+):
     from vector_quantize_pytorch import ResidualVQ
 
     residual_vq = ResidualVQ(
         dim = 256,
-        num_quantizers = 8,      # specify number of quantizers
-        codebook_size = 1024,    # codebook size
+        num_quantizers = 8,
+        codebook_size = 1024,
+        implicit_neural_codebook = implicit_neural_codebook,
+        use_cosine_sim = use_cosine_sim,
     )
 
     x = torch.randn(1, 1024, 256)
 
     quantized, indices, commit_loss = residual_vq(x)
     quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)
 
+    # test eval mode and `get_output_from_indices`
+
+    residual_vq.eval()
+    quantized, indices, commit_loss = residual_vq(x)
+
+    quantized_out = residual_vq.get_output_from_indices(indices)
+    assert torch.allclose(quantized, quantized_out, atol = 1e-6)
+
 def test_residual_vq2():
     from vector_quantize_pytorch import ResidualVQ
 
@@ -91,7 +106,6 @@ def test_residual_vq2():
     x = torch.randn(1, 1024, 256)
     quantized, indices, commit_loss = residual_vq(x)
 
-
 def test_grouped_residual_vq():
     from vector_quantize_pytorch import GroupedResidualVQ
 
vector_quantize_pytorch/residual_vq.py

Lines changed: 123 additions & 6 deletions
@@ -22,6 +22,9 @@
 def exists(val):
     return val is not None
 
+def first(it):
+    return it[0]
+
 def default(val, d):
     return val if exists(val) else d
 
@@ -34,6 +37,64 @@ def round_up_multiple(num, mult):
 def is_distributed():
     return dist.is_initialized() and dist.get_world_size() > 1
 
+# the mlp for generating the implicit neural codebook
+# from Huijben et al. https://arxiv.org/abs/2401.14732
+
+class MLP(Module):
+    def __init__(
+        self,
+        dim,
+        dim_hidden = None,
+        depth = 4,              # they used 4 layers in the paper
+        l2norm_output = False
+    ):
+        super().__init__()
+        dim_hidden = default(dim_hidden, dim)
+
+        self.proj_in = nn.Linear(2 * dim, dim)
+
+        layers = ModuleList([])
+
+        for _ in range(depth):
+            layers.append(nn.Sequential(
+                nn.Linear(dim, dim_hidden),
+                nn.SiLU(),
+                nn.Linear(dim_hidden, dim)
+            ))
+
+        self.layers = layers
+        self.l2norm_output = l2norm_output
+
+    def forward(
+        self,
+        codes,
+        *,
+        condition
+    ):
+        one_headed = codes.ndim == 2
+
+        if one_headed:
+            codes = rearrange(codes, 'c d -> 1 c d')
+
+        heads, num_codes, batch, seq_len = codes.shape[0], codes.shape[-2], condition.shape[0], condition.shape[-2]
+
+        codes = repeat(codes, 'h c d -> h b n c d', n = seq_len, b = batch)
+        condition = repeat(condition, 'b n d -> h b n c d', c = num_codes, h = heads)
+
+        x = torch.cat((condition, codes), dim = -1)
+        x = self.proj_in(x)
+
+        for layer in self.layers:
+            x = layer(x) + x
+
+        if self.l2norm_output:
+            x = F.normalize(x, dim = -1)
+
+        if not one_headed:
+            return x
+
+        return rearrange(x, '1 ... -> ...')
+
 # main class
 
 class ResidualVQ(Module):
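
As a point of reference, the MLP's contract can be summarized with a small shape check. This is a hypothetical standalone snippet, not part of the commit; it assumes the class is importable from `vector_quantize_pytorch.residual_vq` and uses toy sizes:

```python
# hypothetical usage sketch of the MLP above (toy sizes, assumed import path)
import torch
from vector_quantize_pytorch.residual_vq import MLP

mlp = MLP(dim = 64)

codes = torch.randn(512, 64)         # one codebook: (num codes, dim)
condition = torch.randn(2, 16, 64)   # running quantized output: (batch, seq, dim)

# every token gets its own transformed codebook, conditioned on what the
# previous quantizer stages have already reconstructed
per_token_codes = mlp(codes, condition = condition)
assert per_token_codes.shape == (2, 16, 512, 64)
```
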
@@ -50,7 +111,9 @@ def __init__(
         quantize_dropout_cutoff_index = 0,
         quantize_dropout_multiple_of = 1,
         accept_image_fmap = False,
-        **kwargs
+        implicit_neural_codebook = False,   # QINCo from https://arxiv.org/abs/2401.14732
+        mlp_kwargs: dict = dict(),
+        **vq_kwargs
     ):
         super().__init__()
         assert heads == 1, 'residual vq is not compatible with multi-headed codes'
@@ -65,7 +128,16 @@
         self.num_quantizers = num_quantizers
 
         self.accept_image_fmap = accept_image_fmap
-        self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)])
+
+        self.implicit_neural_codebook = implicit_neural_codebook
+
+        if implicit_neural_codebook:
+            vq_kwargs.update(
+                learnable_codebook = True,
+                ema_update = False
+            )
+
+        self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for _ in range(num_quantizers)])
 
         assert all([not vq.has_projections for vq in self.layers])
 
@@ -76,6 +148,12 @@
         self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
         self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4
 
+        # setting up the MLPs for implicit neural codebooks
+
+        self.mlps = ModuleList([MLP(dim = codebook_dim, l2norm_output = first(self.layers).use_cosine_sim, **mlp_kwargs) for _ in range(num_quantizers - 1)])
+
+        # sharing codebook logic
+
         if not shared_codebook:
             return
 
@@ -120,7 +198,31 @@ def get_codes_from_indices(self, indices):
         mask = indices == -1.
         indices = indices.masked_fill(mask, 0)  # have it fetch a dummy code to be masked out later
 
-        all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)
+        if not self.implicit_neural_codebook:
+            # gather all the codes
+
+            all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)
+
+        else:
+            # else if using implicit neural codebook, codes will need to be derived layer by layer
+
+            code_transform_mlps = (None, *self.mlps)
+
+            all_codes = []
+            quantized_out = 0.
+
+            for codes, indices, maybe_transform_mlp in zip(self.codebooks, indices.unbind(dim = -1), code_transform_mlps):
+
+                if exists(maybe_transform_mlp):
+                    codes = maybe_transform_mlp(codes, condition = quantized_out)
+                    layer_codes = get_at('b n [c] d, b n -> b n d', codes, indices)
+                else:
+                    layer_codes = get_at('[c] d, b n -> b n d', codes, indices)
+
+                all_codes.append(layer_codes)
+                quantized_out += layer_codes
+
+            all_codes = torch.stack(all_codes)
 
         # mask out any codes that were dropout-ed
 
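
Because each later codebook is now a function of the partially reconstructed output, codes can no longer be gathered for all quantizers in parallel; the branch above rebuilds the running sum stage by stage before each lookup. A hedged usage sketch of what this enables (mirroring the new test, with illustrative sizes):

```python
# sketch: reconstructing purely from indices with implicit_neural_codebook = True
import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 4,
    codebook_size = 512,
    implicit_neural_codebook = True
)

residual_vq.eval()

x = torch.randn(1, 128, 256)
quantized, indices, commit_loss = residual_vq(x)

# each stage's codebook is re-derived from the sum of the previous stages
# before its index is looked up, then everything is summed back together
recon = residual_vq.get_output_from_indices(indices)
assert torch.allclose(quantized, recon, atol = 1e-6)
```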

@@ -195,9 +297,16 @@
         null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
         null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)
 
+        # setup the mlps for implicit neural codebook
+
+        maybe_code_transforms = (None,) * len(self.layers)
+
+        if self.implicit_neural_codebook:
+            maybe_code_transforms = (None, *self.mlps)
+
         # go through the layers
 
-        for quantizer_index, layer in enumerate(self.layers):
+        for quantizer_index, (vq, maybe_mlp) in enumerate(zip(self.layers, maybe_code_transforms)):
 
             if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                 all_indices.append(null_indices)
@@ -208,12 +317,20 @@
             if return_loss:
                 layer_indices = indices[..., quantizer_index]
 
-            quantized, *rest = layer(
+            # setup the transform code function to be passed into VectorQuantize forward
+
+            if exists(maybe_mlp):
+                maybe_mlp = partial(maybe_mlp, condition = quantized_out)
+
+            # vector quantize forward
+
+            quantized, *rest = vq(
                 residual,
                 mask = mask,
                 indices = layer_indices,
                 sample_codebook_temp = sample_codebook_temp,
-                freeze_codebook = freeze_codebook
+                freeze_codebook = freeze_codebook,
+                codebook_transform_fn = maybe_mlp
             )
 
             residual = residual - quantized.detach()
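
One construction detail worth noting: since the transformed codebooks must be trained by gradient rather than by exponential moving average, turning the flag on forces `learnable_codebook = True` and `ema_update = False` on every `VectorQuantize` layer, and one conditioning MLP is built per quantizer after the first (the first stage looks up its raw codebook). A small sketch of that bookkeeping, with assumed toy sizes:

```python
# sketch of what enabling the flag sets up (toy sizes, not from the commit)
from vector_quantize_pytorch import ResidualVQ

rvq = ResidualVQ(
    dim = 256,
    num_quantizers = 4,
    codebook_size = 512,
    implicit_neural_codebook = True   # layers become learnable_codebook = True, ema_update = False
)

assert len(rvq.layers) == 4   # one VectorQuantize per quantizer
assert len(rvq.mlps) == 3     # one codebook-transform MLP per quantizer after the first
```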
