Skip to content

Commit 905e438

Browse files
committed
offer get_output_from_indices for VectorQuantize, ResidualVQ, and ResidualLFQ - which takes into account linear projection from codebook dimension back to the input dimension
1 parent 40e499d commit 905e438

File tree

4 files changed

+18
-4
lines changed

4 files changed

+18
-4
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.10.1',
6+
version = '1.10.2',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/residual_lfq.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from vector_quantize_pytorch.lookup_free_quantization import LFQ
1212

13-
from einops import rearrange, repeat, pack, unpack
13+
from einops import rearrange, repeat, reduce, pack, unpack
1414

1515
# helper functions
1616

@@ -111,6 +111,11 @@ def get_codes_from_indices(self, indices):
111111

112112
return all_codes
113113

114+
def get_output_from_indices(self, indices):
115+
codes = self.get_codes_from_indices(indices)
116+
codes_summed = reduce(codes, 'q ... -> ...', 'sum')
117+
return self.project_out(codes_summed)
118+
114119
def forward(
115120
self,
116121
x,

vector_quantize_pytorch/residual_vq.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch.nn.functional as F
99
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
1010

11-
from einops import rearrange, repeat, pack, unpack
11+
from einops import rearrange, repeat, reduce, pack, unpack
1212

1313
# helper functions
1414

@@ -113,6 +113,11 @@ def get_codes_from_indices(self, indices):
113113

114114
return all_codes
115115

116+
def get_output_from_indices(self, indices):
117+
codes = self.get_codes_from_indices(indices)
118+
codes_summed = reduce(codes, 'q ... -> ...', 'sum')
119+
return self.project_out(codes_summed)
120+
116121
def forward(
117122
self,
118123
x,

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,10 @@ def get_codes_from_indices(self, indices):
834834
codes = unpack_one(codes, ps, 'b * d')
835835
return codes
836836

837+
def get_output_from_indices(self, indices):
838+
codes = self.get_codes_from_indices(indices)
839+
return self.project_out(codes)
840+
837841
def forward(
838842
self,
839843
x,
@@ -964,7 +968,7 @@ def calculate_ce_loss(codes):
964968
embed_ind = rearrange(embed_ind, 'b (h w) ... -> b h w ...', h = height, w = width)
965969

966970
if only_one:
967-
embed_ind = rearrange(embed_ind, 'b 1 -> b')
971+
embed_ind = rearrange(embed_ind, 'b 1 ... -> b ...')
968972

969973
# aggregate loss
970974

0 commit comments

Comments
 (0)