diff --git a/README.md b/README.md index 5f7983d..42ba239 100644 --- a/README.md +++ b/README.md @@ -714,3 +714,12 @@ assert loss.item() >= 0 url = {https://api.semanticscholar.org/CorpusID:273229218} } ``` + +```bibtex +@inproceedings{Zhu2024AddressingRC, + title = {Addressing Representation Collapse in Vector Quantized Models with One Linear Layer}, + author = {Yongxin Zhu and Bocheng Li and Yifei Xin and Linli Xu}, + year = {2024}, + url = {https://api.semanticscholar.org/CorpusID:273812459} +} +``` diff --git a/examples/autoencoder.py b/examples/autoencoder.py index 3b64d58..50ed7fe 100644 --- a/examples/autoencoder.py +++ b/examples/autoencoder.py @@ -71,12 +71,12 @@ def iterate_dataset(data_loader): shuffle=True, ) -print("baseline") torch.random.manual_seed(seed) model = SimpleVQAutoEncoder( - codebook_size=num_codes, - rotation_trick=rotation_trick + codebook_size = num_codes, + rotation_trick = True, + straight_through = False ).to(device) opt = torch.optim.AdamW(model.parameters(), lr=lr) diff --git a/examples/autoencoder_fsq.py b/examples/autoencoder_fsq.py index a508f51..56e3c6e 100644 --- a/examples/autoencoder_fsq.py +++ b/examples/autoencoder_fsq.py @@ -76,7 +76,6 @@ def iterate_dataset(data_loader): shuffle=True, ) -print("baseline") torch.random.manual_seed(seed) model = SimpleFSQAutoEncoder(levels).to(device) opt = torch.optim.AdamW(model.parameters(), lr=lr) diff --git a/examples/autoencoder_lfq.py b/examples/autoencoder_lfq.py index 0a48962..7ffdb7c 100644 --- a/examples/autoencoder_lfq.py +++ b/examples/autoencoder_lfq.py @@ -87,8 +87,6 @@ def iterate_dataset(data_loader): shuffle=True, ) -print("baseline") - torch.random.manual_seed(seed) model = LFQAutoEncoder( diff --git a/examples/autoencoder_sim_vq.py b/examples/autoencoder_sim_vq.py new file mode 100644 index 0000000..bd1c13b --- /dev/null +++ b/examples/autoencoder_sim_vq.py @@ -0,0 +1,84 @@ +# FashionMnist VQ experiment with various settings. 
def train(model, train_loader, train_iterations=1000, alpha=10):
    """Run a fixed number of optimisation steps on batches from `train_loader`.

    Every step reconstructs the batch, clamps the reconstruction to
    [-1, 1], and back-propagates the L1 reconstruction loss plus `alpha`
    times the commitment loss returned by the model. Progress (both losses
    and the percentage of distinct codes used in the batch) is shown on the
    progress bar.

    Relies on the module-level `opt`, `device` and `num_codes`.

    Args:
        model: callable returning `(out, indices, cmt_loss)` for a batch.
        train_loader: iterable of `(x, y)` batches; it is re-iterated from
            the start whenever it runs out.
        train_iterations: number of optimiser steps to perform.
        alpha: weight applied to the commitment loss.
    """
    def iterate_dataset(data_loader):
        # Endless batch stream: restart the loader once it is exhausted.
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    # Build the stream ONCE. Creating a new generator inside the loop calls
    # iter(train_loader) on every step, so only the first batch of a fresh
    # pass is ever drawn, the loader is never advanced, and the restart
    # logic above can never trigger.
    batches = iterate_dataset(train_loader)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(batches)

        out, indices, cmt_loss = model(x)
        out = out.clamp(-1., 1.)

        rec_loss = (out - x).abs().mean()
        (rec_loss + alpha * cmt_loss).backward()

        opt.step()

        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"cmt loss: {cmt_loss.item():.3f} | "
            + f"active %: {indices.unique().numel() / num_codes * 100:.3f}"
        )
# helper functions

def exists(v):
    """Return True for every value except None."""
    if v is None:
        return False
    return True

def identity(t):
    """Return the argument untouched."""
    return t

def default(v, d):
    """Return `v`, falling back to `d` only when `v` is None."""
    if exists(v):
        return v
    return d

def pack_one(t, pattern):
    """Pack the single tensor `t` with einops `pack` using `pattern`.

    Returns the packed tensor together with an `inverse` function that
    unpacks a tensor using the shape recorded here. `inverse` takes an
    optional `inv_pattern`; when it is omitted the original `pattern` is
    reused.
    """
    packed, recorded_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        chosen_pattern = default(inv_pattern, pattern)
        # exactly one tensor went in, so exactly one comes back out
        (restored,) = unpack(out, recorded_shape, chosen_pattern)
        return restored

    return packed, inverse
def forward(
    self,
    x
):
    """Quantize `x` against the implicit codebook.

    The implicit codebook is the frozen `self.codebook` buffer passed
    through the `self.codebook_to_codes` linear layer. Every input vector
    is matched to its nearest implicit code by `torch.cdist` distance.

    When `self.accept_image_fmap` is set, the input is rearranged from
    'b d h w' to channels-last and packed to 'b * d' first, and both
    outputs are restored to the spatial layout before returning.

    Returns:
        A `(quantized, indices, commit_loss)` tuple. `commit_loss` is the
        `self.zero` buffer when `self.rotation_trick` is enabled.
    """
    is_image = self.accept_image_fmap

    if is_image:
        x = rearrange(x, 'b d h w -> b h w d')
        x, inverse_pack = pack_one(x, 'b * d')

    implicit_codebook = self.codebook_to_codes(self.codebook)

    # the nearest-code search is done without tracking gradients
    with torch.no_grad():
        indices = torch.cdist(x, implicit_codebook).argmin(dim = -1)

    # select codes

    quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)

    if not self.rotation_trick:
        # commit loss and straight through, as was done in the paper

        input_to_quantize = F.mse_loss(x, quantized.detach()) * self.commit_loss_input_to_quantize_weight
        quantize_to_input = F.mse_loss(x.detach(), quantized)
        commit_loss = input_to_quantize + quantize_to_input

        quantized = (quantized - x).detach() + x
    else:
        # rotation trick from @cfifty
        # NOTE(review): no commit loss is produced on this path, so any
        # training signal for `codebook_to_codes` has to arrive through
        # `rotate_from_to` - confirm it propagates gradient to `quantized`.

        quantized = rotate_from_to(quantized, x)
        commit_loss = self.zero

    if is_image:
        quantized = inverse_pack(quantized)
        quantized = rearrange(quantized, 'b h w d-> b d h w')

        indices = inverse_pack(indices, 'b *')

    return quantized, indices, commit_loss
import LatentQuantize +from vector_quantize_pytorch.sim_vq import SimVQ QUANTIZE_KLASSES = ( VectorQuantize, @@ -20,6 +21,7 @@ RandomProjectionQuantizer, FSQ, LFQ, + SimVQ, ResidualLFQ, GroupedResidualLFQ, ResidualFSQ,