Skip to content

Commit ce34332

Browse files
committed
allow for 1d channel first inputs into FSQ, needed for a contracting project
1 parent 358f4cd commit ce34332

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.14.6',
6+
version = '1.14.7',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def __init__(
4747
num_codebooks = 1,
4848
keep_num_codebooks_dim: Optional[bool] = None,
4949
scale: Optional[float] = None,
50-
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64)
50+
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
51+
channel_first: bool = False
5152
):
5253
super().__init__()
5354
_levels = torch.tensor(levels, dtype=int32)
@@ -71,14 +72,17 @@ def __init__(
7172

7273
self.dim = default(dim, len(_levels) * num_codebooks)
7374

75+
self.channel_first = channel_first
76+
7477
has_projections = self.dim != effective_codebook_dim
7578
self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
7679
self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
80+
7781
self.has_projections = has_projections
7882

7983
self.codebook_size = self._levels.prod().item()
8084

81-
implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out = False)
85+
implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
8286
self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
8387

8488
self.allowed_dtypes = allowed_dtypes
@@ -103,33 +107,35 @@ def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
103107
def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
104108
half_width = self._levels // 2
105109
return (zhat - half_width) / half_width
106-
110+
111+
def _indices_to_codes(self, indices: Tensor):
112+
indices = rearrange(indices, '... -> ... 1')
113+
codes_non_centered = (indices // self._basis) % self._levels
114+
codes = self._scale_and_shift_inverse(codes_non_centered)
115+
return codes
116+
107117
def codes_to_indices(self, zhat: Tensor) -> Tensor:
108118
"""Converts a `code` to an index in the codebook."""
109119
assert zhat.shape[-1] == self.codebook_dim
110120
zhat = self._scale_and_shift(zhat)
111121
return (zhat * self._basis).sum(dim=-1).to(int32)
112-
122+
113123
def indices_to_codes(
114124
self,
115-
indices: Tensor,
116-
project_out = True
125+
indices: Tensor
117126
) -> Tensor:
118127
"""Inverse of `codes_to_indices`."""
119128

120129
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
121130

122-
indices = rearrange(indices, '... -> ... 1')
123-
codes_non_centered = (indices // self._basis) % self._levels
124-
codes = self._scale_and_shift_inverse(codes_non_centered)
131+
codes = self._indices_to_codes(indices)
125132

126133
if self.keep_num_codebooks_dim:
127134
codes = rearrange(codes, '... c d -> ... (c d)')
128135

129-
if project_out:
130-
codes = self.project_out(codes)
136+
codes = self.project_out(codes)
131137

132-
if is_img_or_video:
138+
if is_img_or_video or self.channel_first:
133139
codes = rearrange(codes, 'b ... d -> b d ...')
134140

135141
return codes
@@ -146,10 +152,11 @@ def forward(self, z: Tensor) -> Tensor:
146152

147153
orig_dtype = z.dtype
148154
is_img_or_video = z.ndim >= 4
155+
need_move_channel_last = is_img_or_video or self.channel_first
149156

150157
# standardize image or video into (batch, seq, dimension)
151158

152-
if is_img_or_video:
159+
if need_move_channel_last:
153160
z = rearrange(z, 'b d ... -> b ... d')
154161
z, ps = pack_one(z, 'b * d')
155162

@@ -180,7 +187,7 @@ def forward(self, z: Tensor) -> Tensor:
180187

181188
# reconstitute image or video dimensions
182189

183-
if is_img_or_video:
190+
if need_move_channel_last:
184191
out = unpack_one(out, ps, 'b * d')
185192
out = rearrange(out, 'b ... d -> b d ...')
186193

0 commit comments

Comments (0)