Skip to content

Commit 4d26b61

Browse files
committed
Able to set channel-first inputs for LFQ (handle channel-first 1D sequences)
1 parent f8f357a commit 4d26b61

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.14.19"
3+
version = "1.14.20"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,8 @@ def __init__(
8787
projection_has_bias = True,
8888
soft_clamp_input_value = None,
8989
cosine_sim_project_in = False,
90-
cosine_sim_project_in_scale = None
90+
cosine_sim_project_in_scale = None,
91+
channel_first = None
9192
):
9293
super().__init__()
9394

@@ -97,8 +98,9 @@ def __init__(
9798
assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
9899

99100
codebook_size = default(codebook_size, lambda: 2 ** dim)
100-
codebook_dim = int(log2(codebook_size))
101+
self.codebook_size = codebook_size
101102

103+
codebook_dim = int(log2(codebook_size))
102104
codebook_dims = codebook_dim * num_codebooks
103105
dim = default(dim, codebook_dims)
104106

@@ -122,6 +124,10 @@ def __init__(
122124
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
123125
self.keep_num_codebooks_dim = keep_num_codebooks_dim
124126

127+
# channel first
128+
129+
self.channel_first = channel_first
130+
125131
# straight through activation
126132

127133
self.activation = straight_through_activation
@@ -174,6 +180,7 @@ def indices_to_codes(
174180
project_out = True
175181
):
176182
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
183+
should_transpose = default(self.channel_first, is_img_or_video)
177184

178185
if not self.keep_num_codebooks_dim:
179186
indices = rearrange(indices, '... -> ... 1')
@@ -194,7 +201,7 @@ def indices_to_codes(
194201

195202
# rearrange codes back to original shape
196203

197-
if is_img_or_video:
204+
if should_transpose:
198205
codes = rearrange(codes, 'b ... d -> b d ...')
199206

200207
return codes
@@ -218,10 +225,11 @@ def forward(
218225
x = x.float()
219226

220227
is_img_or_video = x.ndim >= 4
228+
should_transpose = default(self.channel_first, is_img_or_video)
221229

222230
# standardize image or video into (batch, seq, dimension)
223231

224-
if is_img_or_video:
232+
if should_transpose:
225233
x = rearrange(x, 'b d ... -> b ... d')
226234
x, ps = pack_one(x, 'b * d')
227235

@@ -333,7 +341,7 @@ def forward(
333341

334342
# reconstitute image or video dimensions
335343

336-
if is_img_or_video:
344+
if should_transpose:
337345
x = unpack_one(x, ps, 'b * d')
338346
x = rearrange(x, 'b ... d -> b d ...')
339347

0 commit comments

Comments (0)