Skip to content

Commit 483ed6a

Browse files
committed
random projection quantizer needs to have input normalized to mean 0 std of 1
1 parent b16be13 commit 483ed6a

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.1.2',
6+
version = '1.1.4',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/random_projection_quantizer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ def __init__(
2525

2626
self.register_buffer('rand_projs', rand_projs)
2727

28+
# in section 3 of https://arxiv.org/abs/2202.01855
29+
# "The input data is normalized to have 0 mean and standard deviation of 1 ... to prevent collapse"
30+
31+
self.norm = nn.LayerNorm(dim, elementwise_affine = False)
32+
2833
self.vq = VectorQuantize(
2934
dim = codebook_dim * num_codebooks,
3035
heads = num_codebooks,
@@ -37,6 +42,8 @@ def __init__(
3742
@torch.no_grad()
3843
def forward(self, x):
3944

45+
x = self.norm(x)
46+
4047
x = einsum('b n d, h d e -> b n h e', x, self.rand_projs)
4148
x, ps = pack([x], 'b n *')
4249

0 commit comments

Comments
 (0)