Skip to content

Commit 4ae176d

Browse files
committed
fix channel_last = False for VectorQuantize when using get_output_from_indices
1 parent 0024008 commit 4ae176d

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.14.4',
6+
version = '1.14.5',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -826,17 +826,21 @@ def get_codes_from_indices(self, indices):
826826
is_multiheaded = codebook.ndim > 2
827827

828828
if not is_multiheaded:
829-
return codebook[indices]
829+
codes = codebook[indices]
830+
else:
831+
indices, ps = pack_one(indices, 'b * h')
832+
indices = rearrange(indices, 'b n h -> b h n')
833+
834+
indices = repeat(indices, 'b h n -> b h n d', d = codebook.shape[-1])
835+
codebook = repeat(codebook, 'h n d -> b h n d', b = indices.shape[0])
830836

831-
indices, ps = pack_one(indices, 'b * h')
832-
indices = rearrange(indices, 'b n h -> b h n')
837+
codes = codebook.gather(2, indices)
838+
codes = rearrange(codes, 'b h n d -> b n (h d)')
839+
codes = unpack_one(codes, ps, 'b * d')
833840

834-
indices = repeat(indices, 'b h n -> b h n d', d = codebook.shape[-1])
835-
codebook = repeat(codebook, 'h n d -> b h n d', b = indices.shape[0])
841+
if not self.channel_last:
842+
codes = rearrange(codes, 'b ... d -> b d ...')
836843

837-
codes = codebook.gather(2, indices)
838-
codes = rearrange(codes, 'b h n d -> b n (h d)')
839-
codes = unpack_one(codes, ps, 'b * d')
840844
return codes
841845

842846
def get_output_from_indices(self, indices):

0 commit comments

Comments
 (0)