Commit b5cb143

address #162

1 parent 492e666 commit b5cb143

4 files changed: +45 -15 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.17.4"
+version = "1.17.5"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_readme.py

Lines changed: 4 additions & 1 deletion

@@ -65,10 +65,12 @@ def test_vq_mask():
 @pytest.mark.parametrize('implicit_neural_codebook', (True, False))
 @pytest.mark.parametrize('use_cosine_sim', (True, False))
 @pytest.mark.parametrize('train', (True, False))
+@pytest.mark.parametrize('shared_codebook', (True, False))
 def test_residual_vq(
     implicit_neural_codebook,
     use_cosine_sim,
-    train
+    train,
+    shared_codebook
 ):
     from vector_quantize_pytorch import ResidualVQ
 
@@ -78,6 +80,7 @@ def test_residual_vq(
         codebook_size = 128,
         implicit_neural_codebook = implicit_neural_codebook,
         use_cosine_sim = use_cosine_sim,
+        shared_codebook = shared_codebook
     )
 
     x = torch.randn(1, 256, 32)
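
A minimal usage sketch of the shared-codebook path the new test parametrization exercises (argument values are illustrative, mirroring the constructor arguments in the test above):

import torch
from vector_quantize_pytorch import ResidualVQ

# every quantizer layer reuses a single codebook; with this commit the EMA
# write-back into the codebook is deferred until the residual loop finishes
residual_vq = ResidualVQ(
    dim = 32,
    num_quantizers = 4,
    codebook_size = 128,
    shared_codebook = True
)

x = torch.randn(1, 256, 32)
quantized, indices, commit_loss = residual_vq(x)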

vector_quantize_pytorch/residual_vq.py

Lines changed: 12 additions & 0 deletions

@@ -137,6 +137,11 @@ def __init__(
                 ema_update = False
             )
 
+        if shared_codebook:
+            vq_kwargs.update(
+                manual_ema_update = True
+            )
+
         self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for _ in range(num_quantizers)])
 
         assert all([not vq.has_projections for vq in self.layers])
@@ -157,6 +162,8 @@ def __init__(
 
         # sharing codebook logic
 
+        self.shared_codebook = shared_codebook
+
         if not shared_codebook:
             return
 
@@ -349,6 +356,11 @@ def forward(
             all_indices.append(embed_indices)
             all_losses.append(loss)
 
+        # if shared codebook, update ema only at end
+
+        if self.shared_codebook:
+            first(self.layers)._codebook.update_ema()
+
         # project out, if needed
 
         quantized_out = self.project_out(quantized_out)
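
Rough sketch of what the deferred update buys with a shared codebook (an illustrative simplification, not the actual ResidualVQ.forward; first and update_ema are the helpers used in the diff above): every layer holds the same codebook module, so each residual pass accumulates into the shared cluster_size / embed_avg EMA buffers while manual_ema_update suppresses the per-pass write into embed, and the normalized codebook is written once after the loop.

# illustrative simplification of the residual loop with a shared codebook
residual = x
quantized_out = 0.

for layer in self.layers:                        # all layers share one _codebook
    quantized, indices, loss = layer(residual)   # accumulates EMA stats only
    residual = residual - quantized.detach()
    quantized_out = quantized_out + quantized

if self.shared_codebook:
    # single write of the accumulated averages into the shared codebook
    first(self.layers)._codebook.update_ema()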

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 28 additions & 13 deletions

@@ -280,6 +280,7 @@ def __init__(
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
         ema_update = True,
+        manual_ema_update = False,
         affine_param = False,
         sync_affine_param = False,
         affine_param_batch_decay = 0.99,
@@ -290,6 +291,7 @@ def __init__(
 
         self.decay = decay
         self.ema_update = ema_update
+        self.manual_ema_update = manual_ema_update
 
         init_fn = uniform_init if not kmeans_init else torch.zeros
         embed = init_fn(num_codebooks, codebook_size, dim)
@@ -458,6 +460,12 @@ def expire_codes_(self, batch_samples):
         batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
         self.replace(batch_samples, batch_mask = expired_codes)
 
+    def update_ema(self):
+        cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
+
+        embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
+        self.embed.data.copy_(embed_normalized)
+
     @autocast('cuda', enabled = False)
     def forward(
         self,
@@ -551,11 +559,9 @@ def forward(
 
             ema_inplace(self.embed_avg.data, embed_sum, self.decay)
 
-            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
-
-            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
-            self.embed.data.copy_(embed_normalized)
-            self.expire_codes_(x)
+            if not self.manual_ema_update:
+                self.update_ema()
+                self.expire_codes_(x)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -582,11 +588,14 @@ def __init__(
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
         ema_update = True,
+        manual_ema_update = False
     ):
         super().__init__()
         self.transform_input = l2norm
 
         self.ema_update = ema_update
+        self.manual_ema_update = manual_ema_update
+
         self.decay = decay
 
         if not kmeans_init:
@@ -671,6 +680,14 @@ def expire_codes_(self, batch_samples):
         batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
         self.replace(batch_samples, batch_mask = expired_codes)
 
+    def update_ema(self):
+        cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
+
+        embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
+        embed_normalized = l2norm(embed_normalized)
+
+        self.embed.data.copy_(embed_normalized)
+
     @autocast('cuda', enabled = False)
     def forward(
         self,
@@ -746,13 +763,9 @@ def forward(
 
             ema_inplace(self.embed_avg.data, embed_sum, self.decay)
 
-            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
-
-            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
-            embed_normalized = l2norm(embed_normalized)
-
-            self.embed.data.copy_(embed_normalized)
-            self.expire_codes_(x)
+            if not self.manual_ema_update:
+                self.update_ema()
+                self.expire_codes_(x)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -802,6 +815,7 @@ def __init__(
         sync_codebook = None,
         sync_affine_param = False,
         ema_update = True,
+        manual_ema_update = False,
         learnable_codebook = False,
         in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook
         affine_param = False,
@@ -881,7 +895,8 @@ def __init__(
             learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
            sample_codebook_temp = sample_codebook_temp,
             gumbel_sample = gumbel_sample_fn,
-            ema_update = ema_update
+            ema_update = ema_update,
+            manual_ema_update = manual_ema_update
         )
 
         if affine_param: