Commit 50ec361

if mask is passed into VQ, output 0 for padding
1 parent 74a27c8 commit 50ec361

File tree

3 files changed: +12 -3 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.15.4"
+version = "1.15.5"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_readme.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def test_vq_mask():
     assert torch.allclose(quantized, mask_quantized[:, :512])
     assert torch.allclose(indices, mask_indices[:, :512])

-    assert torch.allclose(mask_quantized[:, 512:], x[:, 512:])
+    assert (mask_quantized[:, 512:] == 0.).all()
     assert (mask_indices[:, 512:] == -1).all()

 def test_residual_vq():
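
For reference, a minimal sketch of the behavior the updated assertion demands, assuming the library's public VectorQuantize API; the shapes and names here are illustrative, not copied from the test's actual setup:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512)

x = torch.randn(1, 1024, 256)

# first 512 positions are real tokens, the remainder is padding
mask = (torch.arange(1024) < 512).unsqueeze(0)

quantized, indices, commit_loss = vq(x, mask = mask)

# after this commit, padded positions come back zeroed with index -1,
# instead of echoing the original input as before
assert (quantized[:, 512:] == 0.).all()
assert (indices[:, 512:] == -1).all()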

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 10 additions & 1 deletion
@@ -759,6 +759,7 @@ def __init__(
         affine_param_batch_decay = 0.99,
         affine_param_codebook_decay = 0.9,
         sync_update_v = 0., # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
+        return_zeros_for_masked_padding = True
     ):
         super().__init__()
         self.dim = dim
@@ -855,6 +856,9 @@ def __init__(

         self.register_buffer('zero', torch.tensor(0.), persistent = False)

+        # for variable lengthed sequences, whether to take care of masking out the padding to 0 (or return the original input)
+        self.return_zeros_for_masked_padding = return_zeros_for_masked_padding
+
     @property
     def codebook(self):
         codebook = self._codebook.embed
@@ -1133,11 +1137,16 @@ def calculate_ce_loss(codes):
         # if masking, only return quantized for where mask has True

         if exists(mask):
+            masked_out_value = orig_input
+
+            if self.return_zeros_for_masked_padding:
+                masked_out_value = torch.zeros_like(orig_input)
+
             quantize = einx.where(
                 'b n, b n d, b n d -> b n d',
                 mask,
                 quantize,
-                orig_input
+                masked_out_value
             )

             embed_ind = einx.where(
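
The new flag defaults to True, so masked padding now comes back as zeros; passing return_zeros_for_masked_padding = False restores the previous behavior of returning the original input at padded positions. A hedged sketch of opting out (constructor arguments other than the new flag are illustrative):

import torch
from vector_quantize_pytorch import VectorQuantize

# opt out of the new default and keep the pre-1.15.5 pass-through behavior
vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    return_zeros_for_masked_padding = False
)

x = torch.randn(1, 1024, 256)
mask = (torch.arange(1024) < 512).unsqueeze(0)

quantized, indices, commit_loss = vq(x, mask = mask)

# padded positions mirror the input rather than being zeroed
assert torch.allclose(quantized[:, 512:], x[:, 512:])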
