
Commit c509a41

add ability to set weight of ema update on forward with ema_update_weight for #210
1 parent 078aed4 commit c509a41
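
In short, VectorQuantize.forward can now take an ema_update_weight that scales, per code, how strongly the exponential-moving-average (EMA) codebook statistics are updated on that pass. It may be a tensor of per-code weights or a callable that receives the accumulated embed_sum and cluster_size and returns such weights. A minimal usage sketch; the hyperparameters and both weighting rules below are illustrative, not taken from this commit:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512, decay = 0.8)

x = torch.randn(1, 1024, 256)

# per-code weights in [0, 1]: 0. leaves a code's EMA untouched, 1. applies the usual (1 - decay) step
weights = torch.rand(512)

quantized, indices, commit_loss = vq(x, ema_update_weight = weights)

# or derive the weights from the current batch statistics via a callable
quantized, indices, commit_loss = vq(
    x,
    ema_update_weight = lambda embed_sum, cluster_size: (cluster_size > 0).float()
)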

3 files changed: +69, -20 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.22.11"
+version = "1.22.12"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_readme.py

Lines changed: 31 additions & 0 deletions
@@ -401,3 +401,34 @@ def test_residual_sim_vq():
 
     assert x.shape == quantized.shape
     assert torch.allclose(quantized, residual_sim_vq.get_output_from_indices(indices), atol = 1e-5)
+
+@pytest.mark.parametrize('use_cosine_sim', (False, True))
+def test_vq_custom_ema_update_weighting(
+    use_cosine_sim
+):
+    from vector_quantize_pytorch import VectorQuantize
+
+    vq = VectorQuantize(
+        dim = 256,
+        use_cosine_sim = use_cosine_sim,
+        codebook_dim = 128,
+        codebook_size = 8,
+        decay = 0.8,
+    )
+
+    x = torch.randn(16, 1024, 256)
+
+    codebook_before = vq.codebook.clone()
+
+    weights = torch.randint(0, 2, (8,)).float()
+    update_weights_callable = lambda embed_sum, cluster_size: weights
+
+    quantized, indices, loss = vq(x, ema_update_weight = update_weights_callable)
+
+    codebook_after = vq.codebook
+
+    did_update = weights.bool()
+    did_not_update = ~did_update
+
+    assert torch.allclose(codebook_before[did_not_update], codebook_after[did_not_update], atol = 1e-6)
+    assert (codebook_before[did_update] != codebook_after[did_update]).all()

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 37 additions & 19 deletions
@@ -5,7 +5,7 @@
 
 import torch
 from torch.nn import Module
-from torch import nn, einsum, Tensor
+from torch import nn, einsum, is_tensor, Tensor
 import torch.nn.functional as F
 import torch.distributed as distributed
 from torch.optim import Optimizer
@@ -34,6 +34,12 @@ def l2norm(t, dim = -1, eps = 1e-6):
 def safe_div(num, den, eps = 1e-6):
     return num / den.clamp(min = eps)
 
+def append_dims_to(t, ndims):
+    assert t.ndim <= ndims
+    append_ndims = ndims - t.ndim
+    shape = t.shape
+    return t.reshape(*shape, *((1,) * append_ndims))
+
 def Sequential(*modules):
     modules = [*filter(exists, modules)]
     if len(modules) == 0:
@@ -55,13 +61,17 @@ def log(t, eps = 1e-20):
 def entropy(prob, eps = 1e-5):
     return (-prob * log(prob, eps = eps)).sum(dim = -1)
 
-def ema_inplace(old, new, decay):
-    is_mps = str(old.device).startswith('mps:')
+def ema_inplace(old, new, decay, weight = None):
+    weight = default(weight, 1.)
 
-    if not is_mps:
-        old.lerp_(new, 1 - decay)
-    else:
-        old.mul_(decay).add_(new * (1 - decay))
+    if is_tensor(weight):
+        if weight.ndim == 1:
+            weight = rearrange(weight, 'c -> 1 c')
+
+        assert weight.ndim == 2 and weight.shape == old.shape[:2]
+        weight = append_dims_to(weight, old.ndim)
+
+    old.lerp_(new, (1. - decay) * weight)
 
 def pack_one(t, pattern):
     packed, ps = pack([t], pattern)
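
ema_inplace now takes an optional per-code weight, and the old MPS-specific branch collapses into a single in-place lerp. With weight = 1. the statistics move toward the new values by 1 - decay exactly as before; with weight = 0. a code is left untouched. A standalone sketch of just that arithmetic, using toy tensors rather than a call into the library:

import torch

decay = 0.8
old = torch.tensor([1., 1.])       # running statistics for two codes
new = torch.tensor([2., 2.])       # freshly accumulated statistics
weight = torch.tensor([1., 0.])    # per-code update weight

# lerp_ computes old + (new - old) * (1 - decay) * weight, in place
old.lerp_(new, (1. - decay) * weight)

print(old)  # tensor([1.2000, 1.0000]): the weighted code moved, the zero-weighted code stayed frozen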
@@ -392,9 +402,9 @@ def init_embed_(self, data, mask = None):
 
         embed_sum = embed * rearrange(cluster_size, '... -> ... 1')
 
-        self.embed.data.copy_(embed)
         self.embed_avg.data.copy_(embed_sum)
         self.cluster_size.data.copy_(cluster_size)
+        self.update_ema()
         self.initted.data.copy_(torch.Tensor([True]))
 
     @torch.jit.ignore
@@ -500,7 +510,8 @@ def forward(
         sample_codebook_temp = None,
         mask = None,
         freeze_codebook = False,
-        codebook_transform_fn: Callable | None = None
+        codebook_transform_fn: Callable | None = None,
+        ema_update_weight: Tensor | Callable | None = None
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -585,15 +596,17 @@ def forward(
                 embed_onehot[~mask] = 0.
 
             cluster_size = embed_onehot.sum(dim = 1)
-
             self.all_reduce_fn(cluster_size)
-            ema_inplace(self.cluster_size.data, cluster_size, self.decay)
 
             embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
             embed_sum = embed_sum.contiguous()
             self.all_reduce_fn(embed_sum)
 
-            ema_inplace(self.embed_avg.data, embed_sum, self.decay)
+            if callable(ema_update_weight):
+                ema_update_weight = ema_update_weight(embed_sum, cluster_size)
+
+            ema_inplace(self.cluster_size.data, cluster_size, self.decay, ema_update_weight)
+            ema_inplace(self.embed_avg.data, embed_sum, self.decay, ema_update_weight)
 
             if not self.manual_ema_update:
                 self.update_ema()
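
When ema_update_weight is a callable, the codebook invokes it with the all-reduced embed_sum and cluster_size for the current batch, and the returned weights are applied to both EMA buffers. To satisfy the shape check in ema_inplace they should come back shaped (codebook_size,) or (num_codebooks, codebook_size). A hypothetical weighting function, purely for illustration; this count-based rule is not part of the commit:

import torch

def count_weighted_ema(embed_sum, cluster_size, min_count = 4.):
    # embed_sum: (num_codebooks, codebook_size, dim)
    # cluster_size: (num_codebooks, codebook_size)
    # scale each code's EMA step by how many vectors it received this batch,
    # saturating at min_count so well-used codes get the full update
    return (cluster_size / min_count).clamp(max = 1.)

# e.g. quantized, indices, loss = vq(x, ema_update_weight = count_weighted_ema)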
@@ -688,9 +701,9 @@ def init_embed_(self, data, mask = None):
 
         embed_sum = embed * rearrange(cluster_size, '... -> ... 1')
 
-        self.embed.data.copy_(embed)
         self.embed_avg.data.copy_(embed_sum)
         self.cluster_size.data.copy_(cluster_size)
+        self.update_ema()
         self.initted.data.copy_(torch.Tensor([True]))
 
     def replace(self, batch_samples, batch_mask):
@@ -731,7 +744,8 @@ def forward(
         sample_codebook_temp = None,
         mask = None,
         freeze_codebook = False,
-        codebook_transform_fn: Callable | None = None
+        codebook_transform_fn: Callable | None = None,
+        ema_update_weight: Tensor | None = None
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -800,13 +814,15 @@
             bins = embed_onehot.sum(dim = 1)
             self.all_reduce_fn(bins)
 
-            ema_inplace(self.cluster_size.data, bins, self.decay)
-
             embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
             embed_sum = embed_sum.contiguous()
             self.all_reduce_fn(embed_sum)
 
-            ema_inplace(self.embed_avg.data, embed_sum, self.decay)
+            if callable(ema_update_weight):
+                ema_update_weight = ema_update_weight(embed_sum, bins)
+
+            ema_inplace(self.cluster_size.data, bins, self.decay, ema_update_weight)
+            ema_inplace(self.embed_avg.data, embed_sum, self.decay, ema_update_weight)
 
             if not self.manual_ema_update:
                 self.update_ema()
@@ -1047,7 +1063,8 @@ def forward(
         sample_codebook_temp = None,
         freeze_codebook = None,
         return_loss_breakdown = False,
-        codebook_transform_fn: Callable | None = None
+        codebook_transform_fn: Callable | None = None,
+        ema_update_weight: Tensor | None = None
     ):
         orig_input, input_requires_grad = x, x.requires_grad
 
@@ -1103,7 +1120,8 @@ def forward(
             sample_codebook_temp = sample_codebook_temp,
             mask = mask,
             freeze_codebook = freeze_codebook,
-            codebook_transform_fn = codebook_transform_fn
+            codebook_transform_fn = codebook_transform_fn,
+            ema_update_weight = ema_update_weight
         )
 
         # quantize
