
Commit 3debe66

Merge pull request #77 from lucidrains/LFQ
begin work on the proposed "lookup free quantization"
2 parents 6fe65cf + 000c3cb commit 3debe66

File tree

6 files changed · +355 -7 lines changed


README.md

Lines changed: 42 additions & 4 deletions
@@ -284,12 +284,39 @@ assert xhat.shape == x.shape
 assert torch.all(xhat == quantizer.indices_to_codes(indices))
 ```
 
+### Lookup Free Quantization
 
-## Todo
+<img src="./lfq.png" width="450px"></img>
 
-- [x] allow for multi-headed codebooks
-- [x] support masking
-- [x] make sure affine param works with (`sync_affine_param` set to `True`)
+The research team behind <a href="https://arxiv.org/abs/2212.05199">MagViT</a> has released new SOTA results for generative video modeling. The core change between v1 and v2 of their architecture is a new type of quantization: essentially <a href="https://arxiv.org/abs/2309.15505">Finite Scalar Quantization</a> with 2 levels (binary latents), of which FSQ is a generalization. In addition, the team uses extra entropy regularizations to promote codebook usage.
+
+Finite scalar quantization and its follow-up papers will likely lead to further game-changing results in generative modeling.
+
+You can use it simply as follows. It will be dogfooded at the <a href="https://github.com/lucidrains/magvit2-pytorch">MagViT2 pytorch port</a>.
+
+```python
+import torch
+from vector_quantize_pytorch import LFQ
+
+# you can specify either dim or codebook_size
+# if both are specified, they will be validated against each other
+
+quantizer = LFQ(
+    dim = 16,                   # the input feature dimension, which is also log2(codebook_size)
+    # codebook_size = 2 ** 16,  # correspondingly, this would be 2 ** dim, since each scalar in the feature dimension is a binary latent
+    entropy_loss_weight = 0.1,  # how much weight to place on the entropy loss
+    diversity_gamma = 1.        # within the entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
+)
+
+image_feats = torch.randn(1, 16, 32, 32)
+
+quantized, indices, entropy_aux_loss = quantizer(image_feats)
+
+# (1, 16, 32, 32), (1, 32, 32), (1,)
+
+assert image_feats.shape == quantized.shape
+assert (quantized == quantizer.indices_to_codes(indices)).all()
+```
 
 ## Citations
 
@@ -429,3 +456,14 @@ assert torch.all(xhat == quantizer.indices_to_codes(indices))
     primaryClass = {cs.CV}
 }
 ```
+
+```bibtex
+@misc{yu2023language,
+    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
+    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
+    year    = {2023},
+    eprint  = {2310.05737},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
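
The README snippet above describes LFQ as FSQ with two levels: each scalar of the feature vector becomes a binary latent, and the resulting bit pattern doubles as the code index, so no explicit codebook lookup is needed. Here is a minimal standalone sketch of that idea (not part of this commit; the MSB-first bit ordering is an illustrative assumption, and the library's `indices_to_codes` remains the source of truth):

```python
import torch

dim = 4                                    # feature dimension; codebook size is 2 ** dim = 16

x = torch.tensor([0.3, -1.2, 0.7, -0.1])   # one feature vector coming out of an encoder

# each scalar is quantized to its sign, giving one binary latent per dimension
code = torch.where(x > 0, torch.ones_like(x), -torch.ones_like(x))   # tensor([ 1., -1.,  1., -1.])

# read the bits (+1 -> 1, -1 -> 0) as a base-2 integer, MSB first (illustrative ordering)
bits = (code > 0).long()                    # tensor([1, 0, 1, 0])
mask = 2 ** torch.arange(dim - 1, -1, -1)   # tensor([8, 4, 2, 1])
index = (bits * mask).sum()                 # tensor(10) - the code index, no lookup table involved

# and back again: unpack the index into bits and map {0, 1} -> {-1, +1}
recovered = ((index & mask) != 0).float() * 2 - 1

assert torch.equal(recovered, code)
print(index.item())   # 10
```

With `dim = 16` as in the README example, the same mapping yields indices in `[0, 2 ** 16)`.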

examples/autoencoder_lfq.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# FashionMnist VQ experiment with various settings.
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py

from tqdm.auto import trange
from math import log2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from vector_quantize_pytorch import LFQ

lr = 3e-4
train_iter = 10000
seed = 1234
codebook_size = 2 ** 8
diversity_gamma = 10.
device = "cuda" if torch.cuda.is_available() else "cpu"

class LFQAutoEncoder(nn.Module):
    def __init__(
        self,
        codebook_size,
        **vq_kwargs
    ):
        super().__init__()
        assert log2(codebook_size).is_integer()
        quantize_dim = int(log2(codebook_size))

        self.encode = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv2d(16, quantize_dim, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.quantize = LFQ(dim=quantize_dim, **vq_kwargs)

        self.decode = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(quantize_dim, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1)
        )
        return

    def forward(self, x):
        x = self.encode(x)
        x, indices, entropy_aux_loss = self.quantize(x)
        x = self.decode(x)
        return x.clamp(-1, 1), indices, entropy_aux_loss


def train(model, train_loader, train_iterations=1000):
    def iterate_dataset(data_loader):
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(iterate_dataset(train_loader))
        out, indices, entropy_aux_loss = model(x)

        rec_loss = F.l1_loss(out, x)
        (rec_loss + entropy_aux_loss).backward()

        opt.step()
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"entropy aux loss: {entropy_aux_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / codebook_size * 100:.3f}"
        )
    return

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

train_dataset = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=True, download=True, transform=transform
    ),
    batch_size=256,
    shuffle=True,
)

print("baseline")

torch.random.manual_seed(seed)

model = LFQAutoEncoder(
    codebook_size = codebook_size,
    diversity_gamma = diversity_gamma
).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=lr)

train(model, train_dataset, train_iterations=train_iter)

lfq.png

92.8 KB

setup.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.8.1',
+  version = '1.9.0',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',
@@ -17,7 +17,7 @@
     'quantization'
   ],
   install_requires=[
-    'einops>=0.6.1',
+    'einops>=0.7.0',
    'torch'
  ],
  classifiers=[

vector_quantize_pytorch/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
 from vector_quantize_pytorch.residual_vq import ResidualVQ, GroupedResidualVQ
 from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
-from vector_quantize_pytorch.finite_scalar_quantization import FSQ
+from vector_quantize_pytorch.finite_scalar_quantization import FSQ
+from vector_quantize_pytorch.lookup_free_quantization import LFQ
vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
"""
Lookup Free Quantization
Proposed in https://arxiv.org/abs/2310.05737

basically a 2-level FSQ (Finite Scalar Quantization) with entropy loss
https://arxiv.org/abs/2309.15505
"""

from math import log2
from collections import namedtuple

import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList

from einops import rearrange, reduce, pack, unpack

# constants

Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])

# helper functions

def exists(v):
    return v is not None

def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# entropy

def binary_entropy(prob):
    return -prob * log(prob) - (1 - prob) * log(1 - prob)

# tensor helpers

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

# convert to bit representations and back

def decimal_to_bits(x, bits):
    device = x.device

    x = x.int()

    mask = 2 ** torch.arange(bits - 1, -1, -1, device = device)
    x = rearrange(x, 'b n -> b n 1')

    bits = ((x & mask) != 0).float()
    bits = rearrange(bits, 'b n d -> b n d')
    return bits * 2 - 1

def bits_to_decimal(x, bits):
    device, dtype = x.device, x.dtype

    x = (x > 0).int()

    mask = 2 ** torch.arange(bits - 1, -1, -1, device = device, dtype = torch.int32)
    dec = reduce(x * mask, 'b n d -> b n', 'sum')
    return dec

# class

class LFQ(Module):
    def __init__(
        self,
        *,
        dim = None,
        codebook_size = None,
        entropy_loss_weight = 0.1,
        diversity_gamma = 2.5
    ):
        super().__init__()

        # some assert validations

        assert exists(dim) or exists(codebook_size)
        assert not exists(codebook_size) or log2(codebook_size).is_integer()

        codebook_size = default(codebook_size, 2 ** dim)
        dim = default(dim, int(log2(codebook_size)))

        assert (2 ** dim) == codebook_size, f'2 ^ dimension ({dim}) must be equal to the codebook size ({codebook_size})'

        self.dim = dim

        # entropy aux loss related weights

        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

        # for no auxiliary loss, during inference

        self.register_buffer('zero', torch.zeros(1,), persistent = False)

    def indices_to_codes(self, indices):
        is_img_or_video = indices.ndim >= 3

        # rearrange if image or video into (batch, seq, dimension)

        if is_img_or_video:
            indices, ps = pack_one(indices, 'b *')

        # indices to codes, which are bits of either -1 or 1

        codes = decimal_to_bits(indices, self.dim)

        # rearrange codes back to original shape

        if is_img_or_video:
            codes = unpack_one(codes, ps, 'b * d')
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    def forward(
        self,
        x,
        inv_temperature = 1.
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        """

        is_img_or_video = x.ndim >= 4

        # rearrange if image or video into (batch, seq, dimension)

        if is_img_or_video:
            x = rearrange(x, 'b d ... -> b ... d')
            x, ps = pack_one(x, 'b * d')

        assert x.shape[-1] == self.dim

        # quantize by eq 3.

        greater_than_zero = x > 0
        ones = torch.ones_like(x)

        quantized = torch.where(greater_than_zero, ones, -ones)

        # use straight-through gradients with tanh if training

        if self.training:
            x = torch.tanh(x * inv_temperature)
            x = x - x.detach() + quantized
        else:
            x = quantized

        # calculate indices

        indices = bits_to_decimal(x, self.dim)

        # entropy aux loss (todo)

        if self.training:
            prob = (x * inv_temperature).sigmoid()

            bit_entropy = binary_entropy(prob).mean()

            avg_prob = reduce(prob, 'b n d -> b d', 'mean')
            codebook_entropy = binary_entropy(avg_prob).mean()

            # 1. entropy will be nudged to be low for each bit, so each scalar commits to one latent binary bit or the other
            # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used

            entropy_aux_loss = bit_entropy - self.diversity_gamma * codebook_entropy
        else:
            # if not training, just return dummy 0
            entropy_aux_loss = self.zero

        entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight

        # reconstitute image or video dimensions

        if is_img_or_video:
            x = unpack_one(x, ps, 'b * d')
            x = rearrange(x, 'b ... d -> b d ...')

            indices = unpack_one(indices, ps, 'b *')

        # bits to decimal for the codebook indices

        return Return(x, indices, entropy_aux_loss)
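
The comments in `forward` describe two opposing entropy pressures: per-bit entropy is pushed down so each scalar commits to -1 or +1, while the entropy of the averaged bit probabilities is pushed up so codes are, on average, used uniformly. Here is a small standalone sketch of those two terms (not part of this commit; the toy shapes and `diversity_gamma` value are illustrative assumptions):

```python
import torch

def binary_entropy(p, eps = 1e-20):
    # entropy of a Bernoulli variable, where p is the probability of the bit being +1
    p = p.clamp(min = eps, max = 1 - eps)
    return -p * p.log() - (1 - p) * (1 - p).log()

torch.manual_seed(0)

logits = torch.randn(8, 32, 4)   # (batch, sequence, bits) pre-quantization features
prob = logits.sigmoid()          # probability of each bit being +1

# 1. mean per-bit entropy - nudged low so each scalar commits to one binary latent or the other
bit_entropy = binary_entropy(prob).mean()

# 2. entropy of the sequence-averaged bit probabilities - nudged high so all codes get used
avg_prob = prob.mean(dim = 1)    # (batch, bits)
codebook_entropy = binary_entropy(avg_prob).mean()

diversity_gamma = 1.             # illustrative value
entropy_aux_loss = bit_entropy - diversity_gamma * codebook_entropy

print(f'bit entropy: {bit_entropy.item():.3f}')
print(f'codebook entropy: {codebook_entropy.item():.3f}')
print(f'entropy aux loss: {entropy_aux_loss.item():.3f}')
```

A low `bit_entropy` together with a high `codebook_entropy` drives `entropy_aux_loss` down, matching the two numbered comments in the module above.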

0 commit comments
