Skip to content

Commit 72ede73

Browse files
authored
Merge pull request #172 from lucidrains/simvq
SimVQ
2 parents 723ea9f + 97b9a87 commit 72ede73

File tree

9 files changed

+224
-7
lines changed

9 files changed

+224
-7
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,3 +714,12 @@ assert loss.item() >= 0
714714
url = {https://api.semanticscholar.org/CorpusID:273229218}
715715
}
716716
```
717+
718+
```bibtex
719+
@inproceedings{Zhu2024AddressingRC,
720+
title = {Addressing Representation Collapse in Vector Quantized Models with One Linear Layer},
721+
author = {Yongxin Zhu and Bocheng Li and Yifei Xin and Linli Xu},
722+
year = {2024},
723+
url = {https://api.semanticscholar.org/CorpusID:273812459}
724+
}
725+
```

examples/autoencoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ def iterate_dataset(data_loader):
7171
shuffle=True,
7272
)
7373

74-
print("baseline")
7574
torch.random.manual_seed(seed)
7675

7776
model = SimpleVQAutoEncoder(
78-
codebook_size=num_codes,
79-
rotation_trick=rotation_trick
77+
codebook_size = num_codes,
78+
rotation_trick = True,
79+
straight_through = False
8080
).to(device)
8181

8282
opt = torch.optim.AdamW(model.parameters(), lr=lr)

examples/autoencoder_fsq.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ def iterate_dataset(data_loader):
7676
shuffle=True,
7777
)
7878

79-
print("baseline")
8079
torch.random.manual_seed(seed)
8180
model = SimpleFSQAutoEncoder(levels).to(device)
8281
opt = torch.optim.AdamW(model.parameters(), lr=lr)

examples/autoencoder_lfq.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,6 @@ def iterate_dataset(data_loader):
8787
shuffle=True,
8888
)
8989

90-
print("baseline")
91-
9290
torch.random.manual_seed(seed)
9391

9492
model = LFQAutoEncoder(

examples/autoencoder_sim_vq.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# FashionMNIST SimVQ experiment with various settings.
2+
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py
3+
4+
from tqdm.auto import trange
5+
6+
import torch
7+
import torch.nn as nn
8+
from torchvision import datasets, transforms
9+
from torch.utils.data import DataLoader
10+
11+
from vector_quantize_pytorch import SimVQ, Sequential
12+
13+
# Experiment configuration (read as module globals by the code below).

# AdamW learning rate
lr = 3e-4

# number of optimisation steps
train_iter = 10_000

# codebook size; also the denominator of the "active %" readout
num_codes = 256

# seed applied right before the model is built
seed = 1234

# forwarded to SimVQ(rotation_trick=...)
rotation_trick = True

# prefer the GPU when torch can see one
device = "cuda" if torch.cuda.is_available() else "cpu"
19+
20+
def SimVQAutoEncoder(**vq_kwargs):
    """Build a small convolutional autoencoder with a SimVQ bottleneck.

    The encoder is two conv + 2x max-pool stages (1 -> 16 -> 32 channels),
    the bottleneck is a ``SimVQ`` layer operating on the 32-channel feature
    map, and the decoder mirrors the encoder with nearest-neighbour
    upsampling (32 -> 16 -> 1 channels).

    Args:
        **vq_kwargs: extra keyword arguments forwarded to ``SimVQ``
            (``dim=32`` and ``accept_image_fmap=True`` are fixed here).

    Returns:
        The ``Sequential`` wrapper holding all layers in order.
    """
    # NOTE: the three groups are constructed in network order so that
    # parameter initialisation consumes the RNG in a fixed sequence.
    encoder = [
        nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.GELU(),
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]

    quantizer = SimVQ(dim=32, accept_image_fmap = True, **vq_kwargs)

    decoder = [
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
        nn.GELU(),
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
    ]

    return Sequential(*encoder, quantizer, *decoder)
34+
35+
def train(model, train_loader, train_iterations=1000, alpha=10):
    """Run the reconstruction + commitment training loop.

    Relies on the module-level globals ``opt`` (optimizer), ``device`` and
    ``num_codes``.

    Args:
        model: module returning ``(out, indices, cmt_loss)`` for a batch.
        train_loader: iterable of ``(x, y)`` batches; it is cycled forever.
        train_iterations: number of optimisation steps to run.
        alpha: weight of the commitment loss relative to the L1
            reconstruction loss.
    """
    def iterate_dataset(data_loader):
        # endless stream of batches: restart the loader once it is exhausted
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    # Create the batch stream ONCE. Building it inside the loop would spin up
    # a fresh loader iterator every step and only ever consume its first
    # batch, so the dataset would never be traversed epoch by epoch.
    batches = iterate_dataset(train_loader)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(batches)

        out, indices, cmt_loss = model(x)
        out = out.clamp(-1., 1.)

        rec_loss = (out - x).abs().mean()
        (rec_loss + alpha * cmt_loss).backward()

        opt.step()

        # "active %" = share of the codebook used by the current batch
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"cmt loss: {cmt_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )
63+
64+
# --- data -------------------------------------------------------------------

# scale pixels to [-1, 1], matching the clamp applied to the model output
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

fashion_mnist = datasets.FashionMNIST(
    root="~/data/fashion_mnist", train=True, download=True, transform=transform
)

train_dataset = DataLoader(fashion_mnist, batch_size=256, shuffle=True)

# --- model ------------------------------------------------------------------

# seed right before construction so parameter init is reproducible
torch.random.manual_seed(seed)

vq_settings = dict(
    codebook_size = num_codes,
    rotation_trick = rotation_trick,
)

model = SimVQAutoEncoder(**vq_settings).to(device)

# --- optimise ---------------------------------------------------------------

# `opt` is read as a global inside train()
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.19.5"
3+
version = "1.20.0"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
77
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ
88
from vector_quantize_pytorch.latent_quantization import LatentQuantize
9+
from vector_quantize_pytorch.sim_vq import SimVQ
910

1011
from vector_quantize_pytorch.utils import Sequential

vector_quantize_pytorch/sim_vq.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from typing import Callable
2+
3+
import torch
4+
from torch import nn
5+
from torch.nn import Module
6+
import torch.nn.functional as F
7+
8+
from einx import get_at
9+
from einops import rearrange, pack, unpack
10+
11+
from vector_quantize_pytorch.vector_quantize_pytorch import rotate_from_to
12+
13+
# helper functions
14+
15+
def exists(v):
    """Return True unless `v` is None (falsy values such as 0 or '' still count)."""
    return not (v is None)
17+
18+
def identity(t):
    """Return `t` untouched; used as the no-op default for `init_fn`."""
    return t
20+
21+
def default(v, d):
    """Return `v`, falling back to `d` only when `v` is None."""
    if v is None:
        return d
    return v
23+
24+
def pack_one(t, pattern):
    """Pack a single tensor with einops and hand back a matching un-packer.

    Args:
        t: tensor to pack.
        pattern: einops pack pattern, e.g. ``'b * d'``.

    Returns:
        ``(packed, inverse)`` where ``inverse(out, inv_pattern=None)`` unpacks
        ``out`` using the shape recorded while packing; ``inv_pattern``
        overrides ``pattern`` when given.
    """
    flattened, recorded_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        # fall back to the pattern used for packing
        unpack_pattern = pattern if inv_pattern is None else inv_pattern
        (restored,) = unpack(out, recorded_shape, unpack_pattern)
        return restored

    return flattened, inverse
33+
34+
# class
35+
36+
class SimVQ(Module):
    """Vector quantizer with an implicit codebook.

    A fixed random codebook is stored as a buffer (so it is not a trainable
    parameter) and the effective codes are obtained by passing it through a
    single bias-free linear layer, which is the only trainable part.
    """

    def __init__(
        self,
        dim,
        codebook_size,
        init_fn: Callable = identity,
        accept_image_fmap = False,
        rotation_trick = True, # works even better with rotation trick turned on, with no asymmetric commit loss or straight through
        commit_loss_input_to_quantize_weight = 0.25,
    ):
        """
        Args:
            dim: feature dimension of inputs and codes.
            codebook_size: number of codes.
            init_fn: callable applied to the freshly sampled codebook.
            accept_image_fmap: if True, inputs are ``(b, d, h, w)`` feature maps.
            rotation_trick: use ``rotate_from_to`` instead of commit loss +
                straight-through (https://arxiv.org/abs/2410.06424).
            commit_loss_input_to_quantize_weight: weight on the
                input -> quantized term of the commit loss.
        """
        super().__init__()

        # plain configuration flags

        self.accept_image_fmap = accept_image_fmap
        self.rotation_trick = rotation_trick

        # weighing the input -> quantized term a bit less is crucial for it to work

        self.commit_loss_input_to_quantize_weight = commit_loss_input_to_quantize_weight

        # frozen gaussian codebook, scaled by dim ** -0.5, optionally re-initialised

        frozen_codebook = init_fn(torch.randn(codebook_size, dim) * (dim ** -0.5))

        # the actual codes are implicit: linear(frozen codebook)

        self.codebook_to_codes = nn.Linear(dim, dim, bias = False)
        self.register_buffer('codebook', frozen_codebook)

        # constant returned as the commit loss when the rotation trick is active

        self.register_buffer('zero', torch.tensor(0.), persistent = False)

    def forward(
        self,
        x
    ):
        """Quantize `x` against the implicit codebook.

        Args:
            x: ``(b, d, h, w)`` when ``accept_image_fmap`` is set; otherwise
                presumably ``(b, n, d)`` (batched sequence of vectors).

        Returns:
            ``(quantized, indices, commit_loss)``; ``quantized`` has the shape
            of the input and ``commit_loss`` is the zero buffer when the
            rotation trick is used.
        """
        is_fmap = self.accept_image_fmap

        if is_fmap:
            # channels last, then flatten the spatial dims into one axis
            x = rearrange(x, 'b d h w -> b h w d')
            x, inverse_pack = pack_one(x, 'b * d')

        codes = self.codebook_to_codes(self.codebook)

        # nearest code per input vector; no gradient flows through the search

        with torch.no_grad():
            indices = torch.cdist(x, codes).argmin(dim = -1)

        # gather the selected codes (this lookup is differentiable w.r.t. the linear layer)

        quantized = get_at('[c] d, b n -> b n d', codes, indices)

        if not self.rotation_trick:
            # commit loss and straight through, as was done in the paper

            input_to_quantize = F.mse_loss(x, quantized.detach())
            quantize_to_input = F.mse_loss(x.detach(), quantized)

            commit_loss = (
                input_to_quantize * self.commit_loss_input_to_quantize_weight +
                quantize_to_input
            )

            quantized = (quantized - x).detach() + x
        else:
            # rotation trick from @cfifty

            quantized = rotate_from_to(quantized, x)
            commit_loss = self.zero

        if is_fmap:
            # undo the flattening and go back to channels first
            quantized = inverse_pack(quantized)
            quantized = rearrange(quantized, 'b h w d-> b d h w')

            indices = inverse_pack(indices, 'b *')

        return quantized, indices, commit_loss
109+
110+
# main
111+
112+
if __name__ == '__main__':
    # smoke test: quantizing an image feature map must preserve its shape

    x = torch.randn(1, 512, 32, 32)

    quantizer_kwargs = dict(
        dim = 512,
        codebook_size = 1024,
        accept_image_fmap = True
    )

    sim_vq = SimVQ(**quantizer_kwargs)

    quantized, indices, commit_loss = sim_vq(x)

    assert x.shape == quantized.shape

vector_quantize_pytorch/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
1313
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ
1414
from vector_quantize_pytorch.latent_quantization import LatentQuantize
15+
from vector_quantize_pytorch.sim_vq import SimVQ
1516

1617
QUANTIZE_KLASSES = (
1718
VectorQuantize,
@@ -20,6 +21,7 @@
2021
RandomProjectionQuantizer,
2122
FSQ,
2223
LFQ,
24+
SimVQ,
2325
ResidualLFQ,
2426
GroupedResidualLFQ,
2527
ResidualFSQ,

0 commit comments

Comments
 (0)