
Commit 7e34d13

assert that residual quantizers never have projections except before the residual quantizing
1 parent cbbc77b · commit 7e34d13

7 files changed: 21 additions, 5 deletions

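In a residual quantizer, projecting into and out of codebook space should happen exactly once, at the wrapper level: the wrapper projects the input down, the inner quantizer layers add and subtract quantized residuals directly in that space, and the accumulated sum is projected back up at the end. If an inner layer carried its own projections, the residual arithmetic would mix vectors from different spaces. This commit exposes a has_projections flag on every quantizer and asserts in each residual wrapper that none of its inner layers project. A minimal sketch of the pattern (the ToyResidualQuantizer below is hypothetical, not part of the library):

import torch
from torch import nn

class ToyResidualQuantizer(nn.Module):
    # hypothetical wrapper mirroring the pattern enforced in this commit
    def __init__(self, dim, codebook_dim, layers):
        super().__init__()
        requires_projection = codebook_dim != dim
        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        self.has_projections = requires_projection
        self.layers = nn.ModuleList(layers)
        # the residual arithmetic in forward() only makes sense if every inner
        # quantizer operates directly in codebook space, hence the assertion
        assert all(not layer.has_projections for layer in self.layers)

    def forward(self, x):
        x = self.project_in(x)
        quantized_out = torch.zeros_like(x)
        residual = x
        for layer in self.layers:
            # assumes each layer returns (quantized, ...) like the library's quantizers
            quantized, *rest = layer(residual)
            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized
        return self.project_out(quantized_out)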

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.11.5',
+  version = '1.11.6',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 5 additions & 2 deletions
@@ -68,8 +68,11 @@ def __init__(
         self.keep_num_codebooks_dim = keep_num_codebooks_dim

         self.dim = default(dim, len(_levels) * num_codebooks)
-        self.project_in = nn.Linear(self.dim, effective_codebook_dim) if self.dim != effective_codebook_dim else nn.Identity()
-        self.project_out = nn.Linear(effective_codebook_dim, self.dim) if self.dim != effective_codebook_dim else nn.Identity()
+
+        has_projections = self.dim != effective_codebook_dim
+        self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
+        self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
+        self.has_projections = has_projections

         self.codebook_size = self._levels.prod().item()
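
A quick illustration of the new flag (a sketch, assuming the FSQ constructor exported by the package matches the signature visible in the diff: levels, optional dim, num_codebooks):

from vector_quantize_pytorch import FSQ

# dim defaults to len(levels) * num_codebooks, so nothing needs projecting
fsq = FSQ(levels = [8, 5, 5, 5])
print(fsq.has_projections)        # False: project_in / project_out are nn.Identity

# an explicit dim that differs from the effective codebook dim adds Linear projections
fsq_wide = FSQ(levels = [8, 5, 5, 5], dim = 256)
print(fsq_wide.has_projections)   # True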

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 4 additions & 2 deletions
@@ -84,8 +84,10 @@ def __init__(
         codebook_dims = codebook_dim * num_codebooks
         dim = default(dim, codebook_dims)

-        self.project_in = nn.Linear(dim, codebook_dims) if dim != codebook_dims else nn.Identity()
-        self.project_out = nn.Linear(codebook_dims, dim) if dim != codebook_dims else nn.Identity()
+        has_projections = dim != codebook_dims
+        self.project_in = nn.Linear(dim, codebook_dims) if has_projections else nn.Identity()
+        self.project_out = nn.Linear(codebook_dims, dim) if has_projections else nn.Identity()
+        self.has_projections = has_projections

         self.dim = dim
         self.codebook_dim = codebook_dim

vector_quantize_pytorch/residual_fsq.py

Lines changed: 3 additions & 0 deletions
@@ -47,6 +47,7 @@ def __init__(
         requires_projection = codebook_dim != dim
         self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
         self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
+        self.has_projections = requires_projection

         self.num_quantizers = num_quantizers

@@ -68,6 +69,8 @@ def __init__(

             self.layers.append(fsq)

+        assert all([not fsq.has_projections for fsq in self.layers])
+
         self.codebook_size = self.layers[0].codebook_size

         self.register_buffer('scales', torch.stack(scales), persistent = False)
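
The intent of the assertion: the wrapper itself may project (when dim differs from the codebook dimension implied by levels), but the FSQ layers it stacks must not, so every residual is formed in a single space. A hedged usage sketch, assuming the README-style ResidualFSQ constructor (dim, levels, num_quantizers):

from vector_quantize_pytorch import ResidualFSQ

rfsq = ResidualFSQ(dim = 256, levels = [8, 5, 5, 5], num_quantizers = 4)
print(rfsq.has_projections)                              # True: the wrapper projects 256 -> codebook dim once
print(any(fsq.has_projections for fsq in rfsq.layers))   # False: guaranteed by the new assert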

vector_quantize_pytorch/residual_lfq.py

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,7 @@ def __init__(
         requires_projection = codebook_dim != dim
         self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
         self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
+        self.has_projections = requires_projection

         self.num_quantizers = num_quantizers

@@ -61,6 +62,8 @@ def __init__(

             self.layers.append(lfq)

+        assert all([not lfq.has_projections for lfq in self.layers])
+
         self.quantize_dropout = quantize_dropout and num_quantizers > 1

         assert quantize_dropout_cutoff_index >= 0

vector_quantize_pytorch/residual_vq.py

Lines changed: 3 additions & 0 deletions
@@ -47,12 +47,15 @@ def __init__(
         requires_projection = codebook_input_dim != dim
         self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
         self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
+        self.has_projections = requires_projection

         self.num_quantizers = num_quantizers

         self.accept_image_fmap = accept_image_fmap
         self.layers = nn.ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)])

+        assert all([not vq.has_projections for vq in self.layers])
+
         self.quantize_dropout = quantize_dropout and num_quantizers > 1

         assert quantize_dropout_cutoff_index >= 0
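
Here the assertion holds by construction: each inner VectorQuantize is built with dim = codebook_dim and codebook_dim = codebook_dim (see the ModuleList in the hunk above), so requires_projection inside each layer is False and only the wrapper projects. A sketch, assuming the README-style ResidualVQ constructor and default single-head layers:

from vector_quantize_pytorch import ResidualVQ

rvq = ResidualVQ(dim = 256, codebook_dim = 32, num_quantizers = 8, codebook_size = 1024)
print(rvq.has_projections)                               # True: 256 -> 32 happens once, in the wrapper
print(all(not vq.has_projections for vq in rvq.layers))  # True: inner layers see dim == codebook_dim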

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 2 additions & 0 deletions
@@ -732,6 +732,8 @@ def __init__(
         self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
         self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

+        self.has_projections = requires_projection
+
         self.eps = eps
         self.commitment_weight = commitment_weight
         self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss
