From 7959292f8ea1655eefe9dd0ecf99902be134b3df Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 11 Nov 2024 06:17:48 -0800 Subject: [PATCH 1/5] simvq --- examples/autoencoder.py | 5 +- examples/autoencoder_sim_vq.py | 84 +++++++++++++++++ vector_quantize_pytorch/__init__.py | 1 + vector_quantize_pytorch/sim_vq.py | 136 ++++++++++++++++++++++++++++ vector_quantize_pytorch/utils.py | 2 + 5 files changed, 226 insertions(+), 2 deletions(-) create mode 100644 examples/autoencoder_sim_vq.py create mode 100644 vector_quantize_pytorch/sim_vq.py diff --git a/examples/autoencoder.py b/examples/autoencoder.py index 3b64d58..df79eb4 100644 --- a/examples/autoencoder.py +++ b/examples/autoencoder.py @@ -75,8 +75,9 @@ def iterate_dataset(data_loader): torch.random.manual_seed(seed) model = SimpleVQAutoEncoder( - codebook_size=num_codes, - rotation_trick=rotation_trick + codebook_size = num_codes, + rotation_trick = True, + straight_through = False ).to(device) opt = torch.optim.AdamW(model.parameters(), lr=lr) diff --git a/examples/autoencoder_sim_vq.py b/examples/autoencoder_sim_vq.py new file mode 100644 index 0000000..543543e --- /dev/null +++ b/examples/autoencoder_sim_vq.py @@ -0,0 +1,84 @@ +# FashionMnist VQ experiment with various settings. +# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py + +from tqdm.auto import trange + +import torch +import torch.nn as nn +from torchvision import datasets, transforms +from torch.utils.data import DataLoader + +from vector_quantize_pytorch import SimVQ, Sequential + +lr = 3e-4 +train_iter = 10000 +num_codes = 256 +seed = 1234 +rotation_trick = True +device = "cuda" if torch.cuda.is_available() else "cpu" + +def SimVQAutoEncoder(**vq_kwargs): + return Sequential( + nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(16, 64, kernel_size=3, stride=1, padding=1), + nn.MaxPool2d(kernel_size=2, stride=2), + SimVQ(dim=64, accept_image_fmap = True, **vq_kwargs), + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1), + nn.GELU(), + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), + ) + +def train(model, train_loader, train_iterations=1000, alpha=10): + def iterate_dataset(data_loader): + data_iter = iter(data_loader) + while True: + try: + x, y = next(data_iter) + except StopIteration: + data_iter = iter(data_loader) + x, y = next(data_iter) + yield x.to(device), y.to(device) + + for _ in (pbar := trange(train_iterations)): + opt.zero_grad() + x, _ = next(iterate_dataset(train_loader)) + + out, indices, cmt_loss = model(x) + out = out.clamp(-1., 1.) + + rec_loss = (out - x).abs().mean() + (rec_loss + alpha * cmt_loss).backward() + + opt.step() + + pbar.set_description( + f"rec loss: {rec_loss.item():.3f} | " + + f"cmt loss: {cmt_loss.item():.3f} | " + + f"active %: {indices.unique().numel() / num_codes * 100:.3f}" + ) + +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] +) + +train_dataset = DataLoader( + datasets.FashionMNIST( + root="~/data/fashion_mnist", train=True, download=True, transform=transform + ), + batch_size=256, + shuffle=True, +) + +print("baseline") +torch.random.manual_seed(seed) + +model = SimVQAutoEncoder( + codebook_size=num_codes, +).to(device) + +opt = torch.optim.AdamW(model.parameters(), lr=lr) +train(model, train_dataset, train_iterations=train_iter) diff --git a/vector_quantize_pytorch/__init__.py b/vector_quantize_pytorch/__init__.py index 06a0a24..9e7ce9b 100644 --- a/vector_quantize_pytorch/__init__.py +++ b/vector_quantize_pytorch/__init__.py @@ -6,5 +6,6 @@ from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ from vector_quantize_pytorch.latent_quantization import LatentQuantize +from vector_quantize_pytorch.sim_vq import SimVQ from vector_quantize_pytorch.utils import Sequential diff --git a/vector_quantize_pytorch/sim_vq.py b/vector_quantize_pytorch/sim_vq.py new file mode 100644 index 0000000..25b099e --- /dev/null +++ b/vector_quantize_pytorch/sim_vq.py @@ -0,0 +1,136 @@ +from typing import Callable + +import torch +from torch import nn +from torch.nn import Module +import torch.nn.functional as F + +from einx import get_at +from einops import einsum, rearrange, repeat, reduce, pack, unpack + +# helper functions + +def exists(v): + return v is not None + +def identity(t): + return t + +def default(v, d): + return v if exists(v) else d + +def pack_one(t, pattern): + packed, packed_shape = pack([t], pattern) + + def inverse(out, inv_pattern = None): + inv_pattern = default(inv_pattern, pattern) + out, = unpack(out, packed_shape, inv_pattern) + return out + + return packed, inverse + +def l2norm(t, dim = -1): + return F.normalize(t, dim = dim) + +def safe_div(num, den, eps = 1e-6): + return num / den.clamp(min = eps) + +def efficient_rotation_trick_transform(u, q, e): + """ + 4.2 in https://arxiv.org/abs/2410.06424 + """ + e = rearrange(e, 'b d -> b 1 d') + w = l2norm(u + q, dim = 1).detach() + + return ( + e - + 2 * (e @ rearrange(w, 'b d -> b d 1') @ rearrange(w, 'b d -> b 1 d')) + + 2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach()) + ) + +# class + +class SimVQ(Module): + def __init__( + self, + dim, + codebook_size, + init_fn: Callable = identity, + accept_image_fmap = False + ): + super().__init__() + self.accept_image_fmap = accept_image_fmap + + codebook = torch.randn(codebook_size, dim) + codebook = init_fn(codebook) + + # the codebook is actually implicit from a linear layer from frozen gaussian or uniform + + self.codebook_to_codes = nn.Linear(dim, dim, bias = False) + self.register_buffer('codebook', codebook) + + def forward( + self, + x + ): + if self.accept_image_fmap: + x = rearrange(x, 'b d h w -> b h w d') + x, inverse_pack = pack_one(x, 'b * d') + + implicit_codebook = self.codebook_to_codes(self.codebook) + + with torch.no_grad(): + dist = torch.cdist(x, implicit_codebook) + indices = dist.argmin(dim = -1) + + # select codes + + quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices) + + # commit loss + + commit_loss = (F.pairwise_distance(x, quantized.detach()) ** 2).mean() + + # straight through + + x, inverse = pack_one(x, '* d') + quantized, _ = pack_one(quantized, '* d') + + norm_x = x.norm(dim = -1, keepdim = True) + norm_quantize = quantized.norm(dim = -1, keepdim = True) + + rot_quantize = efficient_rotation_trick_transform( + safe_div(x, norm_x), + safe_div(quantized, norm_quantize), + x + ).squeeze() + + quantized = rot_quantize * safe_div(norm_quantize, norm_x).detach() + + x, quantized = inverse(x), inverse(quantized) + + # quantized = (quantized - x).detach() + x + + if self.accept_image_fmap: + quantized = inverse_pack(quantized) + quantized = rearrange(quantized, 'b h w d-> b d h w') + + indices = inverse_pack(indices, 'b *') + + return quantized, indices, commit_loss + +# main + +if __name__ == '__main__': + + x = torch.randn(1, 512, 32, 32) + + sim_vq = SimVQ( + dim = 512, + codebook_size = 1024, + accept_image_fmap = True + ) + + quantized, indices, commit_loss = sim_vq(x) + + assert x.shape == quantized.shape diff --git a/vector_quantize_pytorch/utils.py b/vector_quantize_pytorch/utils.py index bdb4386..d591a09 100644 --- a/vector_quantize_pytorch/utils.py +++ b/vector_quantize_pytorch/utils.py @@ -12,6 +12,7 @@ from vector_quantize_pytorch.residual_lfq import ResidualLFQ, GroupedResidualLFQ from vector_quantize_pytorch.residual_fsq import ResidualFSQ, GroupedResidualFSQ from vector_quantize_pytorch.latent_quantization import LatentQuantize +from vector_quantize_pytorch.sim_vq import SimVQ QUANTIZE_KLASSES = ( VectorQuantize, @@ -20,6 +21,7 @@ RandomProjectionQuantizer, FSQ, LFQ, + SimVQ, ResidualLFQ, GroupedResidualLFQ, ResidualFSQ, From 007209dd7f9c3748ba2fa57faedb1377c48ab3bb Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 11 Nov 2024 06:43:26 -0800 Subject: [PATCH 2/5] update init --- vector_quantize_pytorch/sim_vq.py | 41 ++----------------------------- 1 file changed, 2 insertions(+), 39 deletions(-) diff --git a/vector_quantize_pytorch/sim_vq.py b/vector_quantize_pytorch/sim_vq.py index 25b099e..a247d1f 100644 --- a/vector_quantize_pytorch/sim_vq.py +++ b/vector_quantize_pytorch/sim_vq.py @@ -29,25 +29,6 @@ def inverse(out, inv_pattern = None): return packed, inverse -def l2norm(t, dim = -1): - return F.normalize(t, dim = dim) - -def safe_div(num, den, eps = 1e-6): - return num / den.clamp(min = eps) - -def efficient_rotation_trick_transform(u, q, e): - """ - 4.2 in https://arxiv.org/abs/2410.06424 - """ - e = rearrange(e, 'b d -> b 1 d') - w = l2norm(u + q, dim = 1).detach() - - return ( - e - - 2 * (e @ rearrange(w, 'b d -> b d 1') @ rearrange(w, 'b d -> b 1 d')) + - 2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach()) - ) - # class class SimVQ(Module): @@ -61,7 +42,7 @@ def __init__( super().__init__() self.accept_image_fmap = accept_image_fmap - codebook = torch.randn(codebook_size, dim) + codebook = torch.randn(codebook_size, dim) * (dim ** -0.5) codebook = init_fn(codebook) # the codebook is actually implicit from a linear layer from frozen gaussian or uniform @@ -89,25 +70,7 @@ def forward( # commit loss - commit_loss = (F.pairwise_distance(x, quantized.detach()) ** 2).mean() - - # straight through - - x, inverse = pack_one(x, '* d') - quantized, _ = pack_one(quantized, '* d') - - norm_x = x.norm(dim = -1, keepdim = True) - norm_quantize = quantized.norm(dim = -1, keepdim = True) - - rot_quantize = efficient_rotation_trick_transform( - safe_div(x, norm_x), - safe_div(quantized, norm_quantize), - x - ).squeeze() - - quantized = rot_quantize * safe_div(norm_quantize, norm_x).detach() - - x, quantized = inverse(x), inverse(quantized) + commit_loss = (F.pairwise_distance(x, quantized) ** 2).mean() # quantized = (quantized - x).detach() + x From e11a9665e2ba778f8665b4551dc8a20557c4b1f5 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 11 Nov 2024 06:52:51 -0800 Subject: [PATCH 3/5] uncomment st --- vector_quantize_pytorch/sim_vq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vector_quantize_pytorch/sim_vq.py b/vector_quantize_pytorch/sim_vq.py index a247d1f..eb06d72 100644 --- a/vector_quantize_pytorch/sim_vq.py +++ b/vector_quantize_pytorch/sim_vq.py @@ -72,7 +72,7 @@ def forward( commit_loss = (F.pairwise_distance(x, quantized) ** 2).mean() - # quantized = (quantized - x).detach() + x + quantized = (quantized - x).detach() + x if self.accept_image_fmap: quantized = inverse_pack(quantized) From 01b45eb59201eb3d8e64f8fcfac2d2e3e5a76b03 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 11 Nov 2024 08:22:58 -0800 Subject: [PATCH 4/5] fix commit loss --- vector_quantize_pytorch/sim_vq.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vector_quantize_pytorch/sim_vq.py b/vector_quantize_pytorch/sim_vq.py index eb06d72..f99ad3a 100644 --- a/vector_quantize_pytorch/sim_vq.py +++ b/vector_quantize_pytorch/sim_vq.py @@ -70,7 +70,10 @@ def forward( # commit loss - commit_loss = (F.pairwise_distance(x, quantized) ** 2).mean() + commit_loss = ( + 0.25 * F.mse_loss(x, quantized.detach()) + + F.mse_loss(x.detach(), quantized) + ) quantized = (quantized - x).detach() + x From 97b9a87afb3ddadeb4dddc38ad24986b86f7e1c8 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 11 Nov 2024 09:09:07 -0800 Subject: [PATCH 5/5] add SimVQ with or without rotation trick https://arxiv.org/abs/2411.02038 --- README.md | 9 ++++++++ examples/autoencoder.py | 1 - examples/autoencoder_fsq.py | 1 - examples/autoencoder_lfq.py | 2 -- examples/autoencoder_sim_vq.py | 10 ++++---- pyproject.toml | 2 +- vector_quantize_pytorch/sim_vq.py | 38 ++++++++++++++++++++++++------- 7 files changed, 45 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 5f7983d..42ba239 100644 --- a/README.md +++ b/README.md @@ -714,3 +714,12 @@ assert loss.item() >= 0 url = {https://api.semanticscholar.org/CorpusID:273229218} } ``` + +```bibtex +@inproceedings{Zhu2024AddressingRC, + title = {Addressing Representation Collapse in Vector Quantized Models with One Linear Layer}, + author = {Yongxin Zhu and Bocheng Li and Yifei Xin and Linli Xu}, + year = {2024}, + url = {https://api.semanticscholar.org/CorpusID:273812459} +} +``` diff --git a/examples/autoencoder.py b/examples/autoencoder.py index df79eb4..50ed7fe 100644 --- a/examples/autoencoder.py +++ b/examples/autoencoder.py @@ -71,7 +71,6 @@ def iterate_dataset(data_loader): shuffle=True, ) -print("baseline") torch.random.manual_seed(seed) model = SimpleVQAutoEncoder( diff --git a/examples/autoencoder_fsq.py b/examples/autoencoder_fsq.py index a508f51..56e3c6e 100644 --- a/examples/autoencoder_fsq.py +++ b/examples/autoencoder_fsq.py @@ -76,7 +76,6 @@ def iterate_dataset(data_loader): shuffle=True, ) -print("baseline") torch.random.manual_seed(seed) model = SimpleFSQAutoEncoder(levels).to(device) opt = torch.optim.AdamW(model.parameters(), lr=lr) diff --git a/examples/autoencoder_lfq.py b/examples/autoencoder_lfq.py index 0a48962..7ffdb7c 100644 --- a/examples/autoencoder_lfq.py +++ b/examples/autoencoder_lfq.py @@ -87,8 +87,6 @@ def iterate_dataset(data_loader): shuffle=True, ) -print("baseline") - torch.random.manual_seed(seed) model = LFQAutoEncoder( diff --git a/examples/autoencoder_sim_vq.py b/examples/autoencoder_sim_vq.py index 543543e..bd1c13b 100644 --- a/examples/autoencoder_sim_vq.py +++ b/examples/autoencoder_sim_vq.py @@ -22,11 +22,11 @@ def SimVQAutoEncoder(**vq_kwargs): nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), nn.MaxPool2d(kernel_size=2, stride=2), nn.GELU(), - nn.Conv2d(16, 64, kernel_size=3, stride=1, padding=1), + nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), nn.MaxPool2d(kernel_size=2, stride=2), - SimVQ(dim=64, accept_image_fmap = True, **vq_kwargs), + SimVQ(dim=32, accept_image_fmap = True, **vq_kwargs), nn.Upsample(scale_factor=2, mode="nearest"), - nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1), + nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1), nn.GELU(), nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1), @@ -73,11 +73,11 @@ def iterate_dataset(data_loader): shuffle=True, ) -print("baseline") torch.random.manual_seed(seed) model = SimVQAutoEncoder( - codebook_size=num_codes, + codebook_size = num_codes, + rotation_trick = rotation_trick ).to(device) opt = torch.optim.AdamW(model.parameters(), lr=lr) diff --git a/pyproject.toml b/pyproject.toml index cff581b..d65aa3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vector-quantize-pytorch" -version = "1.19.5" +version = "1.20.0" description = "Vector Quantization - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/vector_quantize_pytorch/sim_vq.py b/vector_quantize_pytorch/sim_vq.py index f99ad3a..7e16da1 100644 --- a/vector_quantize_pytorch/sim_vq.py +++ b/vector_quantize_pytorch/sim_vq.py @@ -6,7 +6,9 @@ import torch.nn.functional as F from einx import get_at -from einops import einsum, rearrange, repeat, reduce, pack, unpack +from einops import rearrange, pack, unpack + +from vector_quantize_pytorch.vector_quantize_pytorch import rotate_from_to # helper functions @@ -37,7 +39,9 @@ def __init__( dim, codebook_size, init_fn: Callable = identity, - accept_image_fmap = False + accept_image_fmap = False, + rotation_trick = True, # works even better with rotation trick turned on, with no asymmetric commit loss or straight through + commit_loss_input_to_quantize_weight = 0.25, ): super().__init__() self.accept_image_fmap = accept_image_fmap @@ -50,6 +54,17 @@ def __init__( self.codebook_to_codes = nn.Linear(dim, dim, bias = False) self.register_buffer('codebook', codebook) + + # whether to use rotation trick from Fifty et al. + # https://arxiv.org/abs/2410.06424 + + self.rotation_trick = rotation_trick + self.register_buffer('zero', torch.tensor(0.), persistent = False) + + # commit loss weighting - weighing input to quantize a bit less is crucial for it to work + + self.commit_loss_input_to_quantize_weight = commit_loss_input_to_quantize_weight + def forward( self, x @@ -68,14 +83,21 @@ def forward( quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices) - # commit loss + if self.rotation_trick: + # rotation trick from @cfifty + + quantized = rotate_from_to(quantized, x) + + commit_loss = self.zero + else: + # commit loss and straight through, as was done in the paper - commit_loss = ( - 0.25 * F.mse_loss(x, quantized.detach()) + - F.mse_loss(x.detach(), quantized) - ) + commit_loss = ( + F.mse_loss(x, quantized.detach()) * self.commit_loss_input_to_quantize_weight + + F.mse_loss(x.detach(), quantized) + ) - quantized = (quantized - x).detach() + x + quantized = (quantized - x).detach() + x if self.accept_image_fmap: quantized = inverse_pack(quantized)