Commit 0c6cea2

residual vq should only do one projection to the codebook dim, and one projection out
1 parent 6fd0547 commit 0c6cea2
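
In short: the input is now projected once into the codebook dimension, all residual quantizers operate in that space, and the summed output is projected back out once, rather than relying on each inner VectorQuantize to project separately. A minimal sketch of the pattern, condensed from the diff below (sizes are illustrative and the surrounding module is elided):

import torch.nn as nn

# one projection into the codebook space, shared by the whole residual stack
dim, codebook_dim, heads = 512, 64, 1    # illustrative sizes; heads must be 1 here
codebook_input_dim = codebook_dim * heads

requires_projection = codebook_input_dim != dim
project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

# forward pass: x = project_in(x); run the residual quantizer loop on x;
# then quantized_out = project_out(quantized_out)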

File tree

2 files changed: +16 -2 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.6.32',
+  version = '1.7.0',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/residual_vq.py

Lines changed: 15 additions & 1 deletion
@@ -28,7 +28,9 @@ class ResidualVQ(nn.Module):
     def __init__(
         self,
         *,
+        dim,
         num_quantizers,
+        codebook_dim = None,
         shared_codebook = False,
         heads = 1,
         quantize_dropout = False,
@@ -39,11 +41,17 @@ def __init__(
     ):
         super().__init__()
         assert heads == 1, 'residual vq is not compatible with multi-headed codes'
+        codebook_dim = default(codebook_dim, dim)
+        codebook_input_dim = codebook_dim * heads
+
+        requires_projection = codebook_input_dim != dim
+        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
+        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
 
         self.num_quantizers = num_quantizers
 
         self.accept_image_fmap = accept_image_fmap
-        self.layers = nn.ModuleList([VectorQuantize(accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)])
+        self.layers = nn.ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)])
 
         self.quantize_dropout = quantize_dropout and num_quantizers > 1
 
@@ -114,6 +122,8 @@ def forward(
     ):
         num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device
 
+        x = self.project_in(x)
+
         assert not (self.accept_image_fmap and exists(indices))
 
         quantized_out = 0.
@@ -169,6 +179,10 @@ def forward(
             all_indices.append(embed_indices)
             all_losses.append(loss)
 
+        # project out, if needed
+
+        quantized_out = self.project_out(quantized_out)
+
         # whether to early return the cross entropy loss
 
         if return_loss:
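
For completeness, a usage sketch under assumed values (the codebook_size, tensor shape, and dimensions below are illustrative, not part of the commit): passing the new codebook_dim makes the residual stack quantize in a lower-dimensional space while inputs and outputs stay at dim.

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,              # model dimension seen by project_in / project_out
    codebook_dim = 32,      # new in this commit: quantization happens in this smaller space
    num_quantizers = 8,
    codebook_size = 1024    # forwarded to each inner VectorQuantize via **kwargs
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)    # quantized is projected back out to dim = 256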

0 commit comments