Commit 4fc8e0a

support variable length sequences for residual VQ and LFQ
1 parent b5da2a9 commit 4fc8e0a

File tree

  setup.py
  vector_quantize_pytorch/lookup_free_quantization.py
  vector_quantize_pytorch/residual_lfq.py
  vector_quantize_pytorch/residual_vq.py

4 files changed: +30 -7 lines

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.11.8',
+  version = '1.12.0',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 15 additions & 3 deletions

@@ -161,7 +161,8 @@ def forward(
         self,
         x,
         inv_temperature = 100.,
-        return_loss_breakdown = False
+        return_loss_breakdown = False,
+        mask = None,
     ):
         """
         einstein notation
@@ -216,8 +217,14 @@ def forward(
         per_sample_entropy = entropy(prob).mean()

+        # account for mask
+
+        if exists(mask):
+            prob = prob[mask]
+
         # distribution over all available tokens in the batch
-        avg_prob = reduce(prob, 'b n c d -> c d', 'mean')
+
+        avg_prob = reduce(prob, '... c d -> c d', 'mean')
         codebook_entropy = entropy(avg_prob).mean()

         # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
@@ -231,7 +238,12 @@ def forward(
         # commit loss

         if self.training:
-            commit_loss = F.mse_loss(original_input, quantized.detach())
+            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
+
+            if exists(mask):
+                commit_loss = commit_loss[mask]
+
+            commit_loss = commit_loss.mean()
         else:
             commit_loss = self.zero

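The pattern in these hunks is: compute the per-element losses without reduction, select only the positions where the boolean mask is True, then reduce, so padded positions contribute nothing to the entropy or commitment terms. Below is a minimal stand-alone sketch of that pattern; the tensor shapes and the lengths-to-mask construction are assumptions for illustration, not taken from this diff.

import torch
import torch.nn.functional as F

# assumed shapes: batch of 2 sequences padded to length 6, feature dim 4,
# one codebook ('c') with 8 codes ('d'), mirroring the 'b n c d' layout above
original_input = torch.randn(2, 6, 4)
quantized      = torch.randn(2, 6, 4)
prob           = torch.softmax(torch.randn(2, 6, 1, 8), dim = -1)

# boolean mask built from per-sequence lengths: True at valid positions
lengths = torch.tensor([6, 3])
mask = torch.arange(6)[None, :] < lengths[:, None]          # (2, 6)

# commit loss: per-element, keep only valid positions, then average
commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
commit_loss = commit_loss[mask].mean()

# codebook entropy term: boolean indexing flattens (batch, seq) down to the
# valid positions, which is why the reduce pattern changes to '... c d -> c d'
avg_prob = prob[mask].mean(dim = 0)                          # (1, 8), i.e. (c, d)
codebook_entropy = -(avg_prob * torch.log(avg_prob.clamp(min = 1e-5))).sum(dim = -1).mean()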
vector_quantize_pytorch/residual_lfq.py

Lines changed: 4 additions & 1 deletion

@@ -122,6 +122,7 @@ def get_output_from_indices(self, indices):
     def forward(
         self,
         x,
+        mask = None,
         return_all_codes = False,
         rand_quantize_dropout_fixed_seed = None
     ):
@@ -161,7 +162,7 @@ def forward(
                 all_losses.append(null_loss)
                 continue

-            quantized, indices, loss = layer(residual)
+            quantized, indices, loss = layer(residual, mask = mask)

             residual = residual - quantized.detach()
             quantized_out = quantized_out + quantized
@@ -236,6 +237,7 @@ def get_output_from_indices(self, indices):
     def forward(
         self,
         x,
+        mask = None,
         return_all_codes = False
     ):
         shape, split_dim = x.shape, self.split_dim
@@ -246,6 +248,7 @@ def forward(
         x = x.chunk(self.groups, dim = split_dim)

         forward_kwargs = dict(
+            mask = mask,
             return_all_codes = return_all_codes,
             rand_quantize_dropout_fixed_seed = random.randint(0, 1e7)
         )

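On the residual LFQ side the change is plumbing: the new mask keyword is accepted by ResidualLFQ.forward and GroupedResidualLFQ.forward and forwarded unchanged to every LFQ layer. A hedged usage sketch follows; the top-level import and the constructor arguments are assumptions based on the library's README conventions rather than something shown in this diff, and the mask is assumed to be a boolean (batch, seq) tensor marking valid positions.

import torch
from vector_quantize_pytorch import ResidualLFQ  # assumed top-level export

residual_lfq = ResidualLFQ(
    dim = 256,             # assumed constructor args, see the repository README
    codebook_size = 2 ** 8,
    num_quantizers = 4
)

x = torch.randn(2, 1024, 256)

# two sequences, the second padded from length 512 up to 1024
lengths = torch.tensor([1024, 512])
mask = torch.arange(1024)[None, :] < lengths[:, None]   # (2, 1024) bool

quantized, indices, losses = residual_lfq(x, mask = mask)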
vector_quantize_pytorch/residual_vq.py

Lines changed: 10 additions & 2 deletions

@@ -124,6 +124,7 @@ def get_output_from_indices(self, indices):
     def forward(
         self,
         x,
+        mask = None,
         indices = None,
         return_all_codes = False,
         sample_codebook_temp = None,
@@ -175,7 +176,12 @@ def forward(
             if return_loss:
                 layer_indices = indices[..., quantizer_index]

-            quantized, *rest = layer(residual, indices = layer_indices, sample_codebook_temp = sample_codebook_temp)
+            quantized, *rest = layer(
+                residual,
+                mask = mask,
+                indices = layer_indices,
+                sample_codebook_temp = sample_codebook_temp,
+            )

             residual = residual - quantized.detach()
             quantized_out = quantized_out + quantized
@@ -263,7 +269,8 @@ def forward(
         x,
         indices = None,
         return_all_codes = False,
-        sample_codebook_temp = None
+        sample_codebook_temp = None,
+        mask = None,
     ):
         shape, split_dim = x.shape, self.split_dim
         assert shape[split_dim] == self.dim
@@ -279,6 +286,7 @@ def forward(
         forward_kwargs = dict(
             return_all_codes = return_all_codes,
             sample_codebook_temp = sample_codebook_temp,
+            mask = mask,
             rand_quantize_dropout_fixed_seed = random.randint(0, 1e7)
         )

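The same plumbing applies to ResidualVQ and GroupedResidualVQ: the mask is threaded through forward_kwargs so each group, and each quantizer within it, receives it. A hedged sketch of calling ResidualVQ on a padded batch; the constructor arguments mirror the README and are assumptions, not something introduced by this commit.

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 8,
    codebook_size = 1024
)

# pad two variable length sequences into one batch, describe the padding with a mask
seqs = [torch.randn(700, 256), torch.randn(1024, 256)]
x = torch.nn.utils.rnn.pad_sequence(seqs, batch_first = True)             # (2, 1024, 256)
mask = torch.arange(x.shape[1])[None, :] < torch.tensor([700, 1024])[:, None]

quantized, indices, commit_loss = residual_vq(x, mask = mask)             # outputs as before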