
Commit 25683f2

handle inplace optimizer correctly for shared codebook under residual vq setting
1 parent b5cb143 commit 25683f2

3 files changed (+16, -4 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.17.5"
+version = "1.17.6"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/residual_vq.py

Lines changed: 3 additions & 1 deletion
@@ -139,7 +139,8 @@ def __init__(
 
         if shared_codebook:
             vq_kwargs.update(
-                manual_ema_update = True
+                manual_ema_update = True,
+                manual_in_place_optimizer_update = True
             )
 
         self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for _ in range(num_quantizers)])
@@ -360,6 +361,7 @@ def forward(
 
         if self.shared_codebook:
             first(self.layers)._codebook.update_ema()
+            first(self.layers).update_in_place_optimizer()
 
         # project out, if needed
 
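With shared_codebook = True, every quantizer layer reuses one codebook, so the in-place codebook optimizer should step once per ResidualVQ forward pass, after all layers have contributed gradients, rather than once inside each layer. The residual wrapper therefore asks its layers to defer that update (manual_in_place_optimizer_update = True, alongside the existing manual_ema_update) and runs both deferred updates on the first layer once the quantization loop has finished. A minimal usage sketch from the caller's side, assuming the constructor arguments shown in the project README; the Adam factory and the hyperparameters are illustrative choices, not part of this commit:

# minimal sketch, assuming ResidualVQ forwards these keyword arguments to VectorQuantize
from functools import partial

import torch
from torch.optim import Adam
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 4,
    codebook_size = 512,
    shared_codebook = True,                              # all layers point at one codebook
    learnable_codebook = True,                           # in-place optimizer updates the codebook directly
    ema_update = False,                                  # assumed off, since the codebook is learned by gradient
    in_place_codebook_optimizer = partial(Adam, lr = 1e-3)
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)         # optimizer now steps once per forward, not once per layer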

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 12 additions & 2 deletions
@@ -818,6 +818,7 @@ def __init__(
         manual_ema_update = False,
         learnable_codebook = False,
         in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook
+        manual_in_place_optimizer_update = False,
         affine_param = False,
         affine_param_batch_decay = 0.99,
         affine_param_codebook_decay = 0.9,
@@ -913,6 +914,7 @@ def __init__(
         self._codebook = codebook_class(**codebook_kwargs)
 
         self.in_place_codebook_optimizer = in_place_codebook_optimizer(self._codebook.parameters()) if exists(in_place_codebook_optimizer) else None
+        self.manual_in_place_optimizer_update = manual_in_place_optimizer_update
 
         self.codebook_size = codebook_size
 
@@ -966,6 +968,13 @@ def get_output_from_indices(self, indices):
         codes = self.get_codes_from_indices(indices)
         return self.project_out(codes)
 
+    def update_in_place_optimizer(self):
+        if not exists(self.in_place_codebook_optimizer):
+            return
+
+        self.in_place_codebook_optimizer.step()
+        self.in_place_codebook_optimizer.zero_grad()
+
     def forward(
         self,
         x,
@@ -1057,8 +1066,9 @@ def forward(
             loss = F.mse_loss(quantize, x.detach())
 
             loss.backward()
-            self.in_place_codebook_optimizer.step()
-            self.in_place_codebook_optimizer.zero_grad()
+
+            if not self.manual_in_place_optimizer_update:
+                self.update_in_place_optimizer()
 
             inplace_optimize_loss = loss
 
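At the VectorQuantize level, the new manual_in_place_optimizer_update flag mirrors manual_ema_update: forward() still computes the in-place codebook loss and calls loss.backward(), but the optimizer step and zero_grad are deferred until update_in_place_optimizer() is called explicitly. A rough sketch of the resulting control flow, reusing a single layer in place of the shared-codebook wiring (that reuse, and the hyperparameters, are simplifications for illustration, not the ResidualVQ internals):

# sketch, assuming one repeatedly applied VectorQuantize layer stands in for the shared codebook
from functools import partial

import torch
from torch.optim import Adam
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    learnable_codebook = True,
    ema_update = False,                               # assumed off when the codebook is gradient-learned
    in_place_codebook_optimizer = partial(Adam, lr = 3e-4),
    manual_in_place_optimizer_update = True           # forward() stops after loss.backward()
)

residual = torch.randn(1, 1024, 256)

for _ in range(4):                                    # plays the role of the residual quantizer layers
    quantized, indices, commit_loss = vq(residual)    # codebook gradients accumulate across passes
    residual = residual - quantized.detach()

vq.update_in_place_optimizer()                        # a single step() + zero_grad() covering all passes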
