
Commit 0fa6e78

address #221
1 parent: fe903ce

3 files changed: +71 −16 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.22.18"
+version = "1.23.0"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_readme.py

Lines changed: 27 additions & 0 deletions
@@ -434,3 +434,30 @@ def test_vq_custom_ema_update_weighting(
 
     assert torch.allclose(codebook_before[did_not_update], codebook_after[did_not_update], atol = 1e-6)
     assert (codebook_before[did_update] != codebook_after[did_update]).all()
+
+def test_accum_ema_update():
+    from vector_quantize_pytorch import VectorQuantize
+
+    vq = VectorQuantize(
+        dim = 256,
+        use_cosine_sim = True,
+        codebook_dim = 128,
+        codebook_size = 8,        # codebook size
+        decay = 0.8,              # the exponential moving average decay, lower means the dictionary will change faster
+        commitment_weight = 1.,   # the weight on the commitment loss
+    )
+
+    x = torch.randn(16, 1024, 256)
+
+    codebook_before = vq.codebook.clone()
+
+    vq.train()
+
+    _ = vq(x, accum_ema_update = True)
+    _ = vq(x, accum_ema_update = True)
+
+    assert torch.allclose(codebook_before, vq.codebook, atol = 1e-6)
+
+    _ = vq(x)
+
+    assert not torch.allclose(codebook_before, vq.codebook, atol = 1e-6)
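
The new test exercises the flag end to end: two training-mode calls with accum_ema_update = True leave the codebook untouched, and the next plain call folds the buffered statistics into a single EMA update. Below is a minimal sketch of the same accumulate-then-apply pattern in a micro-batch loop; the batch shapes, loop structure, and hyperparameters are illustrative assumptions, only the accum_ema_update flag comes from this commit.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512, decay = 0.8, commitment_weight = 1.)
vq.train()

micro_batches = [torch.randn(4, 1024, 256) for _ in range(4)]

for i, x in enumerate(micro_batches):
    last = i == len(micro_batches) - 1

    # on all but the last micro-batch, only accumulate cluster sizes and
    # embed sums; the codebook itself stays fixed
    quantized, indices, commit_loss = vq(x, accum_ema_update = not last)

# the final call (accum_ema_update = False) applies the accumulated statistics
# in one EMA step, so the codebook moves once per effective batch rather than
# once per micro-batch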

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 43 additions & 15 deletions
@@ -61,7 +61,22 @@ def log(t, eps = 1e-20):
 def entropy(prob, eps = 1e-5):
     return (-prob * log(prob, eps = eps)).sum(dim = -1)
 
+def accum_grad_(t, grad):
+    if exists(t.grad):
+        t.grad.add_(grad)
+    else:
+        t.grad = grad.clone().detach()
+
 def ema_inplace(old, new, decay, weight = None):
+
+    # if old.grad is populated, add it to new and set it to None
+
+    if exists(old.grad):
+        new.add_(old.grad)
+        old.grad = None
+
+    # take care of custom weighting
+
     weight = default(weight, 1.)
 
     if is_tensor(weight):

@@ -71,7 +86,7 @@ def ema_inplace(old, new, decay, weight = None):
         assert weight.ndim == 2 and weight.shape == old.shape[:2]
         weight = append_dims_to(weight, old.ndim)
 
-    old.lerp_(new, (1. - decay) * weight)
+    old.data.lerp_(new, (1. - decay) * weight)
 
 def pack_one(t, pattern):
     packed, ps = pack([t], pattern)
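
The two helpers above make the EMA update deferrable: accum_grad_ parks a batch's codebook statistics in the buffer's otherwise unused .grad slot, and ema_inplace drains that slot into the next batch's statistics before the lerp, so several accumulated batches collapse into one EMA step against their sum. The standalone sketch below illustrates that behaviour with simplified copies of the two helpers; the custom-weighting path is dropped and the add is done out of place so the example can check itself, so treat it as an illustration rather than the library code.

import torch

def accum_grad_(t, grad):
    # stash statistics in the tensor's .grad field without touching its value
    if t.grad is not None:
        t.grad.add_(grad)
    else:
        t.grad = grad.clone().detach()

def ema_inplace(old, new, decay):
    # fold any stashed statistics into the incoming batch, then take one EMA step
    # (simplified: no per-code weighting, out-of-place add)
    if old.grad is not None:
        new = new + old.grad
        old.grad = None
    old.data.lerp_(new, 1. - decay)

cluster_size = torch.zeros(8)
batches = [torch.rand(8) for _ in range(3)]

# accumulate the first two batches, apply everything on the third
accum_grad_(cluster_size, batches[0])
accum_grad_(cluster_size, batches[1])
ema_inplace(cluster_size, batches[2], decay = 0.8)

# same result as one EMA step against the summed statistics
reference = torch.zeros(8).lerp_(sum(batches), 1. - 0.8)
assert torch.allclose(cluster_size, reference)
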
@@ -511,7 +526,8 @@ def forward(
         mask = None,
         freeze_codebook = False,
         codebook_transform_fn: Callable | None = None,
-        ema_update_weight: Tensor | Callable | None = None
+        ema_update_weight: Tensor | Callable | None = None,
+        accum_ema_update = False
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)

@@ -603,12 +619,16 @@ def forward(
             if callable(ema_update_weight):
                 ema_update_weight = ema_update_weight(embed_sum, cluster_size)
 
-            ema_inplace(self.cluster_size.data, cluster_size, self.decay, ema_update_weight)
-            ema_inplace(self.embed_avg.data, embed_sum, self.decay, ema_update_weight)
+            if accum_ema_update:
+                accum_grad_(self.cluster_size, cluster_size)
+                accum_grad_(self.embed_avg, embed_sum)
+            else:
+                ema_inplace(self.cluster_size, cluster_size, self.decay, ema_update_weight)
+                ema_inplace(self.embed_avg, embed_sum, self.decay, ema_update_weight)
 
-            if not self.manual_ema_update:
-                self.update_ema()
-                self.expire_codes_(x)
+                if not self.manual_ema_update:
+                    self.update_ema()
+                    self.expire_codes_(x)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -743,7 +763,8 @@ def forward(
         mask = None,
         freeze_codebook = False,
         codebook_transform_fn: Callable | None = None,
-        ema_update_weight: Tensor | None = None
+        ema_update_weight: Tensor | None = None,
+        accum_ema_update = False
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)

@@ -819,12 +840,17 @@ def forward(
             if callable(ema_update_weight):
                 ema_update_weight = ema_update_weight(embed_sum, bins)
 
-            ema_inplace(self.cluster_size.data, bins, self.decay, ema_update_weight)
-            ema_inplace(self.embed_avg.data, embed_sum, self.decay, ema_update_weight)
+            if accum_ema_update:
+                accum_grad_(self.cluster_size, bins)
+                accum_grad_(self.embed_avg, embed_sum)
+            else:
+
+                ema_inplace(self.cluster_size, bins, self.decay, ema_update_weight)
+                ema_inplace(self.embed_avg, embed_sum, self.decay, ema_update_weight)
 
-            if not self.manual_ema_update:
-                self.update_ema()
-                self.expire_codes_(x)
+                if not self.manual_ema_update:
+                    self.update_ema()
+                    self.expire_codes_(x)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -1062,7 +1088,8 @@ def forward(
         freeze_codebook = None,
         return_loss_breakdown = False,
         codebook_transform_fn: Callable | None = None,
-        ema_update_weight: Tensor | None = None
+        ema_update_weight: Tensor | None = None,
+        accum_ema_update = False
     ):
         orig_input, input_requires_grad = x, x.requires_grad
 
@@ -1119,7 +1146,8 @@ def forward(
             mask = mask,
            freeze_codebook = freeze_codebook,
            codebook_transform_fn = codebook_transform_fn,
-            ema_update_weight = ema_update_weight
+            ema_update_weight = ema_update_weight,
+            accum_ema_update = accum_ema_update
         )
 
         # quantize
