Commit 3f8db79

address #204
1 parent 2ebb9d1 commit 3f8db79

File tree

2 files changed: +23 −5 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.22.3"
+version = "1.22.4"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 22 additions & 4 deletions
@@ -559,9 +559,18 @@ def forward(
 
         else:
             if exists(codebook_transform_fn):
-                quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
+                # quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
+
+                repeated_embed_ind = repeat(embed_ind, 'h b n -> h b n 1 d', d = transformed_embed.shape[-1])
+                quantize = transformed_embed.gather(-2, repeated_embed_ind)
+                quantize = rearrange(quantize, 'h b n 1 d -> h b n d')
+
             else:
-                quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)
+                # quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)
+
+                repeated_embed = repeat(embed, 'h c d -> h b c d', b = embed_ind.shape[1])
+                repeated_embed_ind = repeat(embed_ind, 'h b n -> h b n d', d = embed.shape[-1])
+                quantize = repeated_embed.gather(-2, repeated_embed_ind)
 
         if self.training and self.ema_update and not freeze_codebook:
 

@@ -767,9 +776,18 @@ def forward(
 
         else:
             if exists(codebook_transform_fn):
-                quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
+                # quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
+
+                repeated_embed_ind = repeat(embed_ind, 'h b n -> h b n 1 d', d = transformed_embed.shape[-1])
+                quantize = transformed_embed.gather(-2, repeated_embed_ind)
+                quantize = rearrange(quantize, 'h b n 1 d -> h b n d')
+
             else:
-                quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)
+                # quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)
+
+                repeated_embed = repeat(embed, 'h c d -> h b c d', b = embed_ind.shape[1])
+                repeated_embed_ind = repeat(embed_ind, 'h b n -> h b n d', d = embed.shape[-1])
+                quantize = repeated_embed.gather(-2, repeated_embed_ind)
 
         if self.training and self.ema_update and not freeze_codebook:
             if exists(mask):
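
As a sanity check on the change above: a minimal sketch (not part of the commit; shapes and values are illustrative) verifying that the gather-based lookups reproduce what the commented-out einx.get_at calls computed, for both branches.

# Equivalence check, assuming h = heads, b = batch, n = tokens,
# c = codebook size, d = dim (all sizes illustrative).
import torch
from einops import rearrange, repeat

h, b, n, c, d = 2, 3, 5, 7, 4
embed = torch.randn(h, c, d)                    # one codebook per head
transformed_embed = torch.randn(h, b, n, c, d)  # per-token transformed codebook
embed_ind = torch.randint(0, c, (h, b, n))

# Branch 1: replaces einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind).
# Expand indices to (h, b, n, 1, d), gather over the codebook axis, drop the singleton.
repeated_embed_ind = repeat(embed_ind, 'h b n -> h b n 1 d', d = transformed_embed.shape[-1])
quantize = transformed_embed.gather(-2, repeated_embed_ind)
quantize = rearrange(quantize, 'h b n 1 d -> h b n d')

# Reference via plain advanced indexing.
ref = transformed_embed[
    torch.arange(h)[:, None, None],
    torch.arange(b)[:, None],
    torch.arange(n),
    embed_ind
]
assert torch.equal(quantize, ref)

# Branch 2: replaces einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind).
# Broadcast the codebook over batch, expand indices over dim, gather over the codebook axis.
repeated_embed = repeat(embed, 'h c d -> h b c d', b = embed_ind.shape[1])
repeated_embed_ind = repeat(embed_ind, 'h b n -> h b n d', d = embed.shape[-1])
quantize = repeated_embed.gather(-2, repeated_embed_ind)

ref = embed[torch.arange(h)[:, None, None], embed_ind]
assert torch.equal(quantize, ref)

Both routes compute the same per-head codebook lookup; the gather version relies only on einops repeat/rearrange plus core PyTorch indexing rather than einx's string-spec get_at.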
