You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: vector_quantize_pytorch/vector_quantize_pytorch.py
+17-2Lines changed: 17 additions & 2 deletions
Original file line number
Diff line number
Diff line change
@@ -26,6 +26,15 @@ def identity(t):
26
26
defl2norm(t):
27
27
returnF.normalize(t, p=2, dim=-1)
28
28
29
+
defSequential(*modules):
30
+
modules= [*filter(exists, modules)]
31
+
iflen(modules) ==0:
32
+
returnNone
33
+
eliflen(modules) ==1:
34
+
returnmodules[0]
35
+
36
+
returnnn.Sequential(*modules)
37
+
29
38
defcdist(x, y):
30
39
x2=reduce(x**2, 'b n d -> b n', 'sum')
31
40
y2=reduce(y**2, 'b n d -> b n', 'sum')
@@ -702,6 +711,7 @@ def __init__(
702
711
kmeans_iters=10,
703
712
sync_kmeans=True,
704
713
use_cosine_sim=False,
714
+
layernorm_after_project_in=False, # proposed by @SaltyChtao here https://github.com/lucidrains/vector-quantize-pytorch/issues/26#issuecomment-1324711561
705
715
threshold_ema_dead_code=0,
706
716
channel_last=True,
707
717
accept_image_fmap=False,
@@ -721,7 +731,7 @@ def __init__(
721
731
in_place_codebook_optimizer: Callable[..., Optimizer] =None, # Optimizer used to update the codebook embedding if using learnable_codebook
722
732
affine_param=False,
723
733
affine_param_batch_decay=0.99,
724
-
affine_param_codebook_decay=0.9,
734
+
affine_param_codebook_decay=0.9,
725
735
sync_update_v=0.# the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
0 commit comments