Skip to content

Commit 567dc9c

Browse files
committed
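Add a `channel_last` option (defaults to True) to specify whether the input has its feature dimension last; when False, the input is transposed before quantization and the quantized output is transposed back.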
1 parent baf249e commit 567dc9c

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '0.3.7',
6+
version = '0.3.8',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
author = 'Phil Wang',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,8 @@ def __init__(
242242
kmeans_init = False,
243243
kmeans_iters = 10,
244244
use_cosine_sim = False,
245-
threshold_ema_dead_code = 0
245+
threshold_ema_dead_code = 0,
246+
channel_last = True
246247
):
247248
super().__init__()
248249
n_embed = default(n_embed, codebook_size)
@@ -271,12 +272,18 @@ def __init__(
271272
)
272273

273274
self.codebook_size = codebook_size
275+
self.channel_last = channel_last
274276

275277
@property
276278
def codebook(self):
277279
return self._codebook.codebook
278280

279281
def forward(self, x):
282+
need_transpose = not self.channel_last
283+
284+
if need_transpose:
285+
x = rearrange(x, 'b n d -> b d n')
286+
280287
x = self.project_in(x)
281288

282289
quantize, embed_ind = self._codebook(x)
@@ -288,4 +295,8 @@ def forward(self, x):
288295
quantize = x + (quantize - x).detach()
289296

290297
quantize = self.project_out(quantize)
298+
299+
if need_transpose:
300+
quantize = rearrange(quantize, 'b n d -> b d n')
301+
291302
return quantize, embed_ind, commit_loss

0 commit comments

Comments
 (0)