Commit b38a4c2

address #26 (comment)
1 parent 3edf4dd commit b38a4c2

File tree: 2 files changed (+18, -3 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.22"
+version = "1.14.23"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 17 additions & 2 deletions
@@ -26,6 +26,15 @@ def identity(t):
 def l2norm(t):
     return F.normalize(t, p = 2, dim = -1)
 
+def Sequential(*modules):
+    modules = [*filter(exists, modules)]
+    if len(modules) == 0:
+        return None
+    elif len(modules) == 1:
+        return modules[0]
+
+    return nn.Sequential(*modules)
+
 def cdist(x, y):
     x2 = reduce(x ** 2, 'b n d -> b n', 'sum')
     y2 = reduce(y ** 2, 'b n d -> b n', 'sum')
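
For context, a brief standalone sketch (not part of the diff) of how the new Sequential helper behaves: it drops None entries, returns None when nothing remains, and returns a lone module unwrapped, so an optional layer can be passed inline. The exists helper is assumed to be the repository's usual "is not None" check; the dimensions and flag below are illustrative.

from torch import nn

def exists(val):
    # assumed helper: True when a value is provided
    return val is not None

def Sequential(*modules):
    # keep only the modules that were actually provided
    modules = [*filter(exists, modules)]
    if len(modules) == 0:
        return None
    elif len(modules) == 1:
        return modules[0]
    return nn.Sequential(*modules)

# an optional LayerNorm passed inline; with the flag off it becomes None and is filtered out
use_layernorm = False
proj = Sequential(
    nn.Linear(64, 32),
    nn.LayerNorm(32) if use_layernorm else None
)
assert isinstance(proj, nn.Linear)  # a single surviving module is returned unwrapped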
@@ -702,6 +711,7 @@ def __init__(
         kmeans_iters = 10,
         sync_kmeans = True,
         use_cosine_sim = False,
+        layernorm_after_project_in = False, # proposed by @SaltyChtao here https://github.com/lucidrains/vector-quantize-pytorch/issues/26#issuecomment-1324711561
         threshold_ema_dead_code = 0,
         channel_last = True,
         accept_image_fmap = False,
@@ -721,7 +731,7 @@ def __init__(
         in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook
         affine_param = False,
         affine_param_batch_decay = 0.99,
-        affine_param_codebook_decay = 0.9,
+        affine_param_codebook_decay = 0.9,
         sync_update_v = 0. # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
     ):
         super().__init__()
@@ -733,7 +743,12 @@ def __init__(
         codebook_input_dim = codebook_dim * heads
 
         requires_projection = codebook_input_dim != dim
-        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
+
+        self.project_in = Sequential(
+            nn.Linear(dim, codebook_input_dim),
+            nn.LayerNorm(codebook_input_dim) if layernorm_after_project_in else None
+        ) if requires_projection else nn.Identity()
+
         self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
 
         self.has_projections = requires_projection
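
A hedged usage sketch (not from this commit) of how the new flag would be passed when constructing the quantizer, so that a LayerNorm follows the input projection. The dimension and codebook values are illustrative; the (quantized, indices, commitment loss) return shown is the library's usual forward signature.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_dim = 32,                   # differs from dim, so project_in is created
    codebook_size = 512,
    layernorm_after_project_in = True    # option added in this commit
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)  # quantized: (1, 1024, 256)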
