@@ -4,7 +4,7 @@
 import torch.distributed as distributed
 from torch.cuda.amp import autocast
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, pack, unpack
 from contextlib import contextmanager
 
 def exists(val):
@@ -470,6 +470,7 @@ def __init__(
         sync_codebook = False
     ):
         super().__init__()
+        self.dim = dim
         self.heads = heads
         self.separate_codebook_per_head = separate_codebook_per_head
 
@@ -518,6 +519,25 @@ def codebook(self):
 
         return rearrange(codebook, '1 ... -> ...')
 
+    def get_codes_from_indices(self, indices):
+        codebook = self.codebook
+        is_multiheaded = codebook.ndim > 2
+
+        if not is_multiheaded:
+            codes = codebook[indices]
+            return rearrange(codes, '... h d -> ... (h d)')
+
+        indices, ps = pack([indices], 'b * h')
+        indices = rearrange(indices, 'b n h -> b h n')
+
+        indices = repeat(indices, 'b h n -> b h n d', d = codebook.shape[-1])
+        codebook = repeat(codebook, 'h n d -> b h n d', b = indices.shape[0])
+
+        codes = codebook.gather(2, indices)
+        codes = rearrange(codes, 'b h n d -> b n (h d)')
+        codes, = unpack(codes, ps, 'b * d')
+        return codes
+
     def forward(
         self,
         x,