Skip to content

Commit 7959292

Browse files
committed
simvq
1 parent 723ea9f commit 7959292

File tree

5 files changed

+226
-2
lines changed

5 files changed

+226
-2
lines changed

examples/autoencoder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ def iterate_dataset(data_loader):
7575
torch.random.manual_seed(seed)
7676

7777
model = SimpleVQAutoEncoder(
78-
codebook_size=num_codes,
79-
rotation_trick=rotation_trick
78+
codebook_size = num_codes,
79+
rotation_trick = True,
80+
straight_through = False
8081
).to(device)
8182

8283
opt = torch.optim.AdamW(model.parameters(), lr=lr)

examples/autoencoder_sim_vq.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# FashionMnist VQ experiment with various settings.
2+
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py
3+
4+
from tqdm.auto import trange
5+
6+
import torch
7+
import torch.nn as nn
8+
from torchvision import datasets, transforms
9+
from torch.utils.data import DataLoader
10+
11+
from vector_quantize_pytorch import SimVQ, Sequential
12+
13+
lr = 3e-4
14+
train_iter = 10000
15+
num_codes = 256
16+
seed = 1234
17+
rotation_trick = True
18+
device = "cuda" if torch.cuda.is_available() else "cpu"
19+
20+
def SimVQAutoEncoder(**vq_kwargs):
21+
return Sequential(
22+
nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
23+
nn.MaxPool2d(kernel_size=2, stride=2),
24+
nn.GELU(),
25+
nn.Conv2d(16, 64, kernel_size=3, stride=1, padding=1),
26+
nn.MaxPool2d(kernel_size=2, stride=2),
27+
SimVQ(dim=64, accept_image_fmap = True, **vq_kwargs),
28+
nn.Upsample(scale_factor=2, mode="nearest"),
29+
nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1),
30+
nn.GELU(),
31+
nn.Upsample(scale_factor=2, mode="nearest"),
32+
nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
33+
)
34+
35+
def train(model, train_loader, train_iterations=1000, alpha=10):
36+
def iterate_dataset(data_loader):
37+
data_iter = iter(data_loader)
38+
while True:
39+
try:
40+
x, y = next(data_iter)
41+
except StopIteration:
42+
data_iter = iter(data_loader)
43+
x, y = next(data_iter)
44+
yield x.to(device), y.to(device)
45+
46+
for _ in (pbar := trange(train_iterations)):
47+
opt.zero_grad()
48+
x, _ = next(iterate_dataset(train_loader))
49+
50+
out, indices, cmt_loss = model(x)
51+
out = out.clamp(-1., 1.)
52+
53+
rec_loss = (out - x).abs().mean()
54+
(rec_loss + alpha * cmt_loss).backward()
55+
56+
opt.step()
57+
58+
pbar.set_description(
59+
f"rec loss: {rec_loss.item():.3f} | "
60+
+ f"cmt loss: {cmt_loss.item():.3f} | "
61+
+ f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
62+
)
63+
64+
transform = transforms.Compose(
65+
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
66+
)
67+
68+
train_dataset = DataLoader(
69+
datasets.FashionMNIST(
70+
root="~/data/fashion_mnist", train=True, download=True, transform=transform
71+
),
72+
batch_size=256,
73+
shuffle=True,
74+
)
75+
76+
print("baseline")
77+
torch.random.manual_seed(seed)
78+
79+
model = SimVQAutoEncoder(
80+
codebook_size=num_codes,
81+
).to(device)
82+
83+
opt = torch.optim.AdamW(model.parameters(), lr=lr)
84+
train(model, train_dataset, train_iterations=train_iter)

vector_quantize_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
77
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ
88
from vector_quantize_pytorch.latent_quantization import LatentQuantize
9+
from vector_quantize_pytorch.sim_vq import SimVQ
910

1011
from vector_quantize_pytorch.utils import Sequential

vector_quantize_pytorch/sim_vq.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from typing import Callable
2+
3+
import torch
4+
from torch import nn
5+
from torch.nn import Module
6+
import torch.nn.functional as F
7+
8+
from einx import get_at
9+
from einops import einsum, rearrange, repeat, reduce, pack, unpack
10+
11+
# helper functions
12+
13+
def exists(v):
14+
return v is not None
15+
16+
def identity(t):
17+
return t
18+
19+
def default(v, d):
20+
return v if exists(v) else d
21+
22+
def pack_one(t, pattern):
23+
packed, packed_shape = pack([t], pattern)
24+
25+
def inverse(out, inv_pattern = None):
26+
inv_pattern = default(inv_pattern, pattern)
27+
out, = unpack(out, packed_shape, inv_pattern)
28+
return out
29+
30+
return packed, inverse
31+
32+
def l2norm(t, dim = -1):
33+
return F.normalize(t, dim = dim)
34+
35+
def safe_div(num, den, eps = 1e-6):
36+
return num / den.clamp(min = eps)
37+
38+
def efficient_rotation_trick_transform(u, q, e):
39+
"""
40+
4.2 in https://arxiv.org/abs/2410.06424
41+
"""
42+
e = rearrange(e, 'b d -> b 1 d')
43+
w = l2norm(u + q, dim = 1).detach()
44+
45+
return (
46+
e -
47+
2 * (e @ rearrange(w, 'b d -> b d 1') @ rearrange(w, 'b d -> b 1 d')) +
48+
2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach())
49+
)
50+
51+
# class
52+
53+
class SimVQ(Module):
54+
def __init__(
55+
self,
56+
dim,
57+
codebook_size,
58+
init_fn: Callable = identity,
59+
accept_image_fmap = False
60+
):
61+
super().__init__()
62+
self.accept_image_fmap = accept_image_fmap
63+
64+
codebook = torch.randn(codebook_size, dim)
65+
codebook = init_fn(codebook)
66+
67+
# the codebook is actually implicit from a linear layer from frozen gaussian or uniform
68+
69+
self.codebook_to_codes = nn.Linear(dim, dim, bias = False)
70+
self.register_buffer('codebook', codebook)
71+
72+
def forward(
73+
self,
74+
x
75+
):
76+
if self.accept_image_fmap:
77+
x = rearrange(x, 'b d h w -> b h w d')
78+
x, inverse_pack = pack_one(x, 'b * d')
79+
80+
implicit_codebook = self.codebook_to_codes(self.codebook)
81+
82+
with torch.no_grad():
83+
dist = torch.cdist(x, implicit_codebook)
84+
indices = dist.argmin(dim = -1)
85+
86+
# select codes
87+
88+
quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)
89+
90+
# commit loss
91+
92+
commit_loss = (F.pairwise_distance(x, quantized.detach()) ** 2).mean()
93+
94+
# straight through
95+
96+
x, inverse = pack_one(x, '* d')
97+
quantized, _ = pack_one(quantized, '* d')
98+
99+
norm_x = x.norm(dim = -1, keepdim = True)
100+
norm_quantize = quantized.norm(dim = -1, keepdim = True)
101+
102+
rot_quantize = efficient_rotation_trick_transform(
103+
safe_div(x, norm_x),
104+
safe_div(quantized, norm_quantize),
105+
x
106+
).squeeze()
107+
108+
quantized = rot_quantize * safe_div(norm_quantize, norm_x).detach()
109+
110+
x, quantized = inverse(x), inverse(quantized)
111+
112+
# quantized = (quantized - x).detach() + x
113+
114+
if self.accept_image_fmap:
115+
quantized = inverse_pack(quantized)
116+
quantized = rearrange(quantized, 'b h w d-> b d h w')
117+
118+
indices = inverse_pack(indices, 'b *')
119+
120+
return quantized, indices, commit_loss
121+
122+
# main
123+
124+
if __name__ == '__main__':
125+
126+
x = torch.randn(1, 512, 32, 32)
127+
128+
sim_vq = SimVQ(
129+
dim = 512,
130+
codebook_size = 1024,
131+
accept_image_fmap = True
132+
)
133+
134+
quantized, indices, commit_loss = sim_vq(x)
135+
136+
assert x.shape == quantized.shape

vector_quantize_pytorch/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ
1313
from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ
1414
from vector_quantize_pytorch.latent_quantization import LatentQuantize
15+
from vector_quantize_pytorch.sim_vq import SimVQ
1516

1617
QUANTIZE_KLASSES = (
1718
VectorQuantize,
@@ -20,6 +21,7 @@
2021
RandomProjectionQuantizer,
2122
FSQ,
2223
LFQ,
24+
SimVQ,
2325
ResidualLFQ,
2426
GroupedResidualLFQ,
2527
ResidualFSQ,

0 commit comments

Comments
 (0)