
Commit 8d50c02

support multiple codebooks for the token factorization for video tokenizer in magvit2
1 parent 564a11d

File tree: 3 files changed, +58 -9 lines


README.md

Lines changed: 22 additions & 0 deletions
@@ -336,6 +336,28 @@ assert video_feats.shape == quantized.shape
 
 ```
 
+Or support multiple codebooks
+
+```python
+import torch
+from vector_quantize_pytorch import LFQ
+
+quantizer = LFQ(
+    codebook_size = 4096,
+    dim = 16,
+    num_codebooks = 4 # 4 codebooks, total codebook dimension is log2(4096) * 4
+)
+
+image_feats = torch.randn(1, 16, 32, 32)
+
+quantized, indices, entropy_aux_loss = quantizer(image_feats)
+
+# (1, 16, 32, 32), (1, 32, 32, 4), (1,)
+
+assert image_feats.shape == quantized.shape
+assert (quantized == quantizer.indices_to_codes(indices)).all()
+```
+
 ## Citations
 
 ```bibtex
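
A short note on the shapes in the new README example: each codebook contributes log2(codebook_size) bits, so four codebooks give an internal codebook dimension of 48 while the interface stays at dim = 16, and indices carry a trailing axis with one code id per codebook. A quick arithmetic check (plain Python, just restating the example above):

```python
from math import log2

codebook_size, num_codebooks, dim = 4096, 4, 16

codebook_dim = int(log2(codebook_size))        # 12 bits per codebook
codebook_dims = codebook_dim * num_codebooks   # 48 = total internal codebook dimension

assert (codebook_dim, codebook_dims) == (12, 48)

# project_in maps dim -> codebook_dims (16 -> 48) and project_out maps back,
# while indices for a (1, 16, 32, 32) image come out as (1, 32, 32, 4)
```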

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.9.6',
+  version = '1.9.7',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 35 additions & 8 deletions
@@ -54,7 +54,9 @@ def __init__(
         codebook_size = None,
         entropy_loss_weight = 0.1,
         diversity_gamma = 2.5,
-        straight_through_activation = nn.Tanh()
+        straight_through_activation = nn.Tanh(),
+        num_codebooks = 1,
+        keep_num_codebooks_dim = None
     ):
         super().__init__()
 
@@ -66,13 +68,19 @@ def __init__(
         codebook_size = default(codebook_size, lambda: 2 ** dim)
         codebook_dim = int(log2(codebook_size))
 
-        dim = default(dim, codebook_dim)
+        codebook_dims = codebook_dim * num_codebooks
+        dim = default(dim, codebook_dims)
 
-        self.project_in = nn.Linear(dim, codebook_dim) if dim != codebook_dim else nn.Identity()
-        self.project_out = nn.Linear(codebook_dim, dim) if dim != codebook_dim else nn.Identity()
+        self.project_in = nn.Linear(dim, codebook_dims) if dim != codebook_dims else nn.Identity()
+        self.project_out = nn.Linear(codebook_dims, dim) if dim != codebook_dims else nn.Identity()
 
         self.dim = dim
         self.codebook_dim = codebook_dim
+        self.num_codebooks = num_codebooks
+
+        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
+        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
+        self.keep_num_codebooks_dim = keep_num_codebooks_dim
 
         # straight through activation
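
A quick note on the two new constructor arguments: keep_num_codebooks_dim is resolved from num_codebooks, so indices only gain a trailing codebook dimension when more than one codebook is in play. A minimal sketch of that resolution logic, using a stand-in for the library's default helper (its exact definition is not shown in this diff):

```python
# Stand-in for the library's `default` helper, assumed to return the fallback
# when the value is None (the fallback may be a plain value or a lambda).
def default(value, fallback):
    if value is not None:
        return value
    return fallback() if callable(fallback) else fallback

# num_codebooks = 1 (the default): indices keep their old shape, no trailing dim
assert default(None, 1 > 1) is False

# num_codebooks = 4: keep_num_codebooks_dim resolves to True
assert default(None, 4 > 1) is True

# explicitly passing keep_num_codebooks_dim = False together with num_codebooks > 1
# would trip the new assert in __init__, since per-codebook indices could not be
# represented without the extra dimension
```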

@@ -95,11 +103,16 @@ def indices_to_codes(
     ):
         is_img_or_video = indices.ndim >= 3
 
+        if not self.keep_num_codebooks_dim:
+            indices = rearrange(indices, '... -> ... 1')
+
         # indices to codes, which are bits of either -1 or 1
 
         bits = ((indices[..., None].int() & self.mask) != 0).float()
         codes = bits * 2 - 1
 
+        codes = rearrange(codes, '... c d -> ... (c d)')
+
         # whether to project codes out to original dimensions
         # if the input feature dimensions were not log2(codebook size)
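
To make the new rearranges in indices_to_codes concrete, here is a small self-contained sketch (plain torch + einops, not library code) of going from per-codebook indices back to a flat code vector; the power-of-two bit mask is an assumption chosen to mirror the masking pattern above:

```python
import torch
from einops import rearrange

codebook_dim = 12                                     # log2(4096) bits per codebook
mask = 2 ** torch.arange(codebook_dim - 1, -1, -1)    # assumed bit ordering, for illustration

indices = torch.randint(0, 4096, (1, 32, 32, 4))      # (..., c) one code id per codebook

bits = ((indices[..., None].int() & mask.int()) != 0).float()  # (..., c, d) bits in {0, 1}
codes = bits * 2 - 1                                            # bits of either -1 or 1

codes = rearrange(codes, '... c d -> ... (c d)')                # merge codebooks
assert codes.shape == (1, 32, 32, 48)
```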

@@ -123,6 +136,7 @@ def forward(
         b - batch
         n - sequence (or flattened spatial dimensions)
         d - feature dimension, which is also log2(codebook size)
+        c - number of codebook dim
         """
 
         is_img_or_video = x.ndim >= 4
@@ -133,10 +147,14 @@ def forward(
             x = rearrange(x, 'b d ... -> b ... d')
             x, ps = pack_one(x, 'b * d')
 
-        assert x.shape[-1] == self.dim
+        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
 
         x = self.project_in(x)
 
+        # split out number of codebooks
+
+        x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
+
         # quantize by eq 3.
 
         ones = torch.ones_like(x)
@@ -152,7 +170,7 @@ def forward(
 
         # calculate indices
 
-        indices = reduce((x > 0).int() * self.mask.int(), 'b n d -> b n', 'sum')
+        indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
 
         # entropy aux loss
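
The encoding side can be read off the hunks above: the projected features are split into per-codebook groups, sign-quantized to ±1 (eq 3), and each group of bits is collapsed into its own integer index. A self-contained sketch under the same assumed bit mask as before:

```python
import torch
from einops import rearrange, reduce

num_codebooks, codebook_dim = 4, 12
mask = 2 ** torch.arange(codebook_dim - 1, -1, -1)            # assumed bit ordering, for illustration

x = torch.randn(1, 1024, num_codebooks * codebook_dim)        # (b, n, c * d) after project_in
x = rearrange(x, 'b n (c d) -> b n c d', c = num_codebooks)   # split out number of codebooks

ones = torch.ones_like(x)
quantized = torch.where(x > 0, ones, -ones)                   # bits of either -1 or 1

indices = reduce((x > 0).int() * mask.int(), 'b n c d -> b n c', 'sum')
assert indices.shape == (1, 1024, num_codebooks)              # one code index per codebook
assert indices.max() < 4096
```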

@@ -161,7 +179,7 @@ def forward(
 
         bit_entropy = binary_entropy(prob).mean()
 
-        avg_prob = reduce(prob, 'b n d -> b d', 'mean')
+        avg_prob = reduce(prob, 'b n c d -> b c d', 'mean')
         codebook_entropy = binary_entropy(avg_prob).mean()
 
         # 1. entropy will be nudged to be low for each bit, so each scalar commits to one latent binary bit or the other
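
The only change to the entropy aux loss is the einops pattern: per-bit probabilities are still averaged over batch and sequence, but the codebook axis is kept, so codebook-usage statistics are tracked per codebook. A sketch, assuming prob is a per-bit probability tensor of shape (b, n, c, d) and using a stand-in binary_entropy helper (the library's own helper is not shown in this diff):

```python
import torch
from einops import reduce

def binary_entropy(p, eps = 1e-5):
    # entropy of a Bernoulli variable; stand-in for the library's helper
    return -(p * (p + eps).log() + (1 - p) * (1 - p + eps).log())

prob = torch.rand(1, 1024, 4, 12)                      # (b, n, c, d) per-bit probabilities

bit_entropy = binary_entropy(prob).mean()              # nudged low: each bit commits to -1 or 1

avg_prob = reduce(prob, 'b n c d -> b c d', 'mean')    # average usage kept per codebook, per bit
codebook_entropy = binary_entropy(avg_prob).mean()     # codebook-usage entropy, now per codebook
```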
@@ -174,6 +192,10 @@ def forward(
 
         entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight
 
+        # merge back codebook dim
+
+        x = rearrange(x, 'b n c d -> b n (c d)')
+
         # project out to feature dimension if needed
 
         x = self.project_out(x)
@@ -184,6 +206,11 @@ def forward(
             x = unpack_one(x, ps, 'b * d')
             x = rearrange(x, 'b ... d -> b d ...')
 
-            indices = unpack_one(indices, ps, 'b *')
+            indices = unpack_one(indices, ps, 'b * c')
+
+        # whether to remove single codebook dim
+
+        if not self.keep_num_codebooks_dim:
+            indices = rearrange(indices, '... 1 -> ...')
 
         return Return(x, indices, entropy_aux_loss)
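
Since keep_num_codebooks_dim defaults to False for a single codebook, the previous single-codebook behavior should be unchanged: the trailing codebook dim is stripped before indices are returned. A sketch of the expected shapes under that assumption (mirroring the README example above, not a test from this commit):

```python
import torch
from vector_quantize_pytorch import LFQ

# num_codebooks left at its default of 1, so indices come back
# without a trailing codebook dimension, as before this commit
quantizer = LFQ(codebook_size = 4096, dim = 16)

image_feats = torch.randn(1, 16, 32, 32)
quantized, indices, entropy_aux_loss = quantizer(image_feats)

assert quantized.shape == (1, 16, 32, 32)
assert indices.shape == (1, 32, 32)       # no trailing codebook dim
assert (quantized == quantizer.indices_to_codes(indices)).all()
```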
