File tree Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Original file line number Diff line number Diff line change 3
3
setup (
4
4
name = 'vector_quantize_pytorch' ,
5
5
packages = find_packages (),
6
- version = '0.3.7 ' ,
6
+ version = '0.3.8 ' ,
7
7
license = 'MIT' ,
8
8
description = 'Vector Quantization - Pytorch' ,
9
9
author = 'Phil Wang' ,
Original file line number Diff line number Diff line change @@ -242,7 +242,8 @@ def __init__(
242
242
kmeans_init = False ,
243
243
kmeans_iters = 10 ,
244
244
use_cosine_sim = False ,
245
- threshold_ema_dead_code = 0
245
+ threshold_ema_dead_code = 0 ,
246
+ channel_last = True
246
247
):
247
248
super ().__init__ ()
248
249
n_embed = default (n_embed , codebook_size )
@@ -271,12 +272,18 @@ def __init__(
271
272
)
272
273
273
274
self .codebook_size = codebook_size
275
+ self .channel_last = channel_last
274
276
275
277
@property
276
278
def codebook (self ):
277
279
return self ._codebook .codebook
278
280
279
281
def forward (self , x ):
282
+ need_transpose = not self .channel_last
283
+
284
+ if need_transpose :
285
+ x = rearrange (x , 'b n d -> b d n' )
286
+
280
287
x = self .project_in (x )
281
288
282
289
quantize , embed_ind = self ._codebook (x )
@@ -288,4 +295,8 @@ def forward(self, x):
288
295
quantize = x + (quantize - x ).detach ()
289
296
290
297
quantize = self .project_out (quantize )
298
+
299
+ if need_transpose :
300
+ quantize = rearrange (quantize , 'b n d -> b d n' )
301
+
291
302
return quantize , embed_ind , commit_loss
You can’t perform that action at this time.
0 commit comments