Skip to content

Commit b16be13

Browse files
committed
add ability to easily get codes for multi-headed vq
1 parent 37620e4 commit b16be13

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.1.1',
6+
version = '1.1.2',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch.distributed as distributed
55
from torch.cuda.amp import autocast
66

7-
from einops import rearrange, repeat
7+
from einops import rearrange, repeat, pack, unpack
88
from contextlib import contextmanager
99

1010
def exists(val):
@@ -470,6 +470,7 @@ def __init__(
470470
sync_codebook = False
471471
):
472472
super().__init__()
473+
self.dim = dim
473474
self.heads = heads
474475
self.separate_codebook_per_head = separate_codebook_per_head
475476

@@ -518,6 +519,25 @@ def codebook(self):
518519

519520
return rearrange(codebook, '1 ... -> ...')
520521

522+
def get_codes_from_indices(self, indices):
523+
codebook = self.codebook
524+
is_multiheaded = codebook.ndim > 2
525+
526+
if not is_multiheaded:
527+
codes = codebook[indices]
528+
return rearrange(codes, '... h d -> ... (h d)')
529+
530+
indices, ps = pack([indices], 'b * h')
531+
indices = rearrange(indices, 'b n h -> b h n')
532+
533+
indices = repeat(indices, 'b h n -> b h n d', d = codebook.shape[-1])
534+
codebook = repeat(codebook, 'h n d -> b h n d', b = indices.shape[0])
535+
536+
codes = codebook.gather(2, indices)
537+
codes = rearrange(codes, 'b h n d -> b n (h d)')
538+
codes, = unpack(codes, ps, 'b * d')
539+
return codes
540+
521541
def forward(
522542
self,
523543
x,

0 commit comments

Comments
 (0)