Commit 74a27c8

address #149
1 parent 4c514db commit 74a27c8

3 files changed: +54 -5 lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.15.3"
+version = "1.15.4"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -25,7 +25,7 @@ classifiers=[
 dependencies = [
     "torch>=2.0",
     "einops>=0.8.0",
-    "einx>=0.2.2",
+    "einx>=0.3.0",
 ]
 
 [project.urls]

tests/test_readme.py

Lines changed: 25 additions & 0 deletions
@@ -37,6 +37,31 @@ def test_vq_eval():
     quantized, indices, commit_loss = vq(x)
     assert torch.allclose(quantized, vq.get_output_from_indices(indices))
 
+def test_vq_mask():
+    from vector_quantize_pytorch import VectorQuantize
+
+    vq = VectorQuantize(
+        dim = 256,
+        codebook_size = 512,     # codebook size
+        decay = 1.,              # the exponential moving average decay, lower means the dictionary will change faster
+        commitment_weight = 1.   # the weight on the commitment loss
+    )
+
+    x = torch.randn(1, 1024, 256)
+    lens = torch.full((1,), 512)
+
+    vq.train()
+
+    quantized, indices, commit_loss = vq(x[:, :512])
+    mask_quantized, mask_indices, mask_commit_loss = vq(x, lens = lens)
+
+    assert torch.allclose(commit_loss, mask_commit_loss)
+    assert torch.allclose(quantized, mask_quantized[:, :512])
+    assert torch.allclose(indices, mask_indices[:, :512])
+
+    assert torch.allclose(mask_quantized[:, 512:], x[:, 512:])
+    assert (mask_indices[:, 512:] == -1).all()
+
 def test_residual_vq():
     from vector_quantize_pytorch import ResidualVQ
 
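Note: mirroring the new test above, a minimal usage sketch of the `lens` keyword from the caller's side (shapes, lengths and constructor arguments here are illustrative, not taken from the repo):

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512)

x = torch.randn(2, 1024, 256)        # (batch, seq, dim)
lens = torch.tensor([512, 768])      # number of valid tokens per sequence

quantized, indices, commit_loss = vq(x, lens = lens)

# per the assertions in test_vq_mask, positions past each length are returned
# unquantized (copied straight from the input) and their indices are set to -1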

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 27 additions & 3 deletions
@@ -3,12 +3,13 @@
 
 import torch
 from torch.nn import Module
-from torch import nn, einsum
+from torch import nn, einsum, Tensor
 import torch.nn.functional as F
 import torch.distributed as distributed
 from torch.optim import Optimizer
 from torch.cuda.amp import autocast
 
+import einx
 from einops import rearrange, repeat, reduce, pack, unpack
 
 from typing import Callable
@@ -63,6 +64,10 @@ def pack_one(t, pattern):
 def unpack_one(t, ps, pattern):
     return unpack(t, ps, pattern)[0]
 
+def lens_to_mask(lens, max_length):
+    seq = torch.arange(max_length, device = lens.device)
+    return seq < lens[:, None]
+
 def uniform_init(*shape):
     t = torch.empty(shape)
     nn.init.kaiming_uniform_(t)
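A quick worked example of the new `lens_to_mask` helper (values are illustrative; the function body is copied from the hunk above so the snippet runs standalone):

import torch

def lens_to_mask(lens, max_length):
    # boolean mask that is True for the first lens[i] positions of row i
    seq = torch.arange(max_length, device = lens.device)
    return seq < lens[:, None]

print(lens_to_mask(torch.tensor([2, 4]), 5))
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True,  True, False]])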
@@ -897,12 +902,22 @@ def forward(
         x,
         indices = None,
         mask = None,
+        lens = None,
         sample_codebook_temp = None,
         freeze_codebook = False,
         return_loss_breakdown = False,
     ):
         orig_input = x
 
+        # handle masking, either passed in as `mask` or `lens`
+
+        assert not (exists(mask) and exists(lens))
+
+        if exists(lens):
+            mask = lens_to_mask(lens, x.shape[1])
+
+        # handle one token given
+
         only_one = x.ndim == 2
 
         if only_one:
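So `lens` is a convenience around the existing `mask` keyword: it is converted to a boolean mask before any quantization happens, and supplying both at once trips the new assert. A minimal equivalence sketch (an assumption-laden standalone check, run in eval mode so the codebook is not updated between the two calls; shapes are illustrative):

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512)
vq.eval()

x = torch.randn(1, 8, 256)
lens = torch.full((1,), 5)
mask = torch.arange(8) < lens[:, None]   # same construction as lens_to_mask

quantized_a, indices_a, _ = vq(x, lens = lens)
quantized_b, indices_b, _ = vq(x, mask = mask)

assert torch.allclose(quantized_a, quantized_b)
assert torch.equal(indices_a, indices_b)
# vq(x, mask = mask, lens = lens) would raise an AssertionError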
@@ -917,6 +932,7 @@ def forward(
         # rearrange inputs
 
         if self.accept_image_fmap:
+            assert not exists(mask)
             height, width = x.shape[-2:]
             x = rearrange(x, 'b c h w -> b (h w) c')
 
@@ -1117,12 +1133,20 @@ def calculate_ce_loss(codes):
         # if masking, only return quantized for where mask has True
 
         if exists(mask):
-            quantize = torch.where(
-                rearrange(mask, '... -> ... 1'),
+            quantize = einx.where(
+                'b n, b n d, b n d -> b n d',
+                mask,
                 quantize,
                 orig_input
             )
 
+            embed_ind = einx.where(
+                'b n, b n ..., -> b n ...',
+                mask,
+                embed_ind,
+                -1
+            )
+
         if not return_loss_breakdown:
             return quantize, embed_ind, loss
 
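For readers less familiar with einx: the swap from torch.where to einx.where is notational, the pattern spells out how the boolean mask broadcasts against the value tensors, and a trailing empty operand stands for a scalar fill (the -1 used for masked-out indices). A small standalone sketch of the equivalence (shapes and values are illustrative only):

import torch
import einx
from einops import rearrange

mask     = torch.tensor([[True, True, False]])   # (b, n)
quantize = torch.randn(1, 3, 4)                  # (b, n, d)
orig     = torch.randn(1, 3, 4)                  # (b, n, d)
indices  = torch.tensor([[3, 7, 5]])             # (b, n)

# einx.where with an explicit broadcast pattern ...
out_einx = einx.where('b n, b n d, b n d -> b n d', mask, quantize, orig)
# ... matches the old torch.where with a manually appended singleton dim
out_torch = torch.where(rearrange(mask, '... -> ... 1'), quantize, orig)
assert torch.allclose(out_einx, out_torch)

# scalar fill for masked-out positions, as done for embed_ind above
masked_indices = einx.where('b n, b n ..., -> b n ...', mask, indices, -1)
assert torch.equal(masked_indices, torch.tensor([[3, 7, -1]]))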
