Commit a1b4a71

add ability to use cosine similarity for measuring distance to codes
1 parent 345ae69 commit a1b4a71

3 files changed: 173 additions & 51 deletions

README.md

Lines changed: 18 additions & 0 deletions
````diff
@@ -90,6 +90,24 @@ x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x)
 ```
 
+### Cosine Similarity
+
+The <a href="https://openreview.net/forum?id=pfNyExj7z2">Improved VQGAN paper</a> also proposes to l2 normalize the codes and the encoded vectors, which boils down to using cosine similarity for the distance. They claim enforcing the vectors on a sphere leads to improvements in code usage and downstream reconstruction. You can turn this on by setting `use_cosine_sim = True`
+
+```python
+import torch
+from vector_quantize_pytorch import VectorQuantize
+
+vq = VectorQuantize(
+    dim = 256,
+    codebook_size = 256,
+    use_cosine_sim = True   # set this to True
+)
+
+x = torch.randn(1, 1024, 256)
+quantized, indices, commit_loss = vq(x)
+```
+
 ## Citations
 
 ```bibtex
````
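
A note on the claim in the README text above: l2-normalizing both the encoded vectors and the codes and then taking a plain dot product is exactly cosine similarity. A minimal standalone check (the tensor names here are illustrative, not part of the library):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 256)      # stand-in for encoded vectors
codes = torch.randn(8, 256)  # stand-in for codebook entries

# cosine similarity computed directly, pairwise between every vector and every code
cos = F.cosine_similarity(x.unsqueeze(1), codes.unsqueeze(0), dim = -1)

# the same quantity via l2 normalization followed by a plain dot product
dot = F.normalize(x, dim = -1) @ F.normalize(codes, dim = -1).t()

assert torch.allclose(cos, dot, atol = 1e-5)
```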

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.3.1',
+  version = '0.3.2',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',
```

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 154 additions & 50 deletions
```diff
@@ -15,7 +15,7 @@ def ema_inplace(moving_avg, new, decay):
 def laplace_smoothing(x, n_categories, eps = 1e-5):
     return (x + eps) / (x.sum() + n_categories * eps)
 
-def kmeans(x, num_clusters, num_iters = 10):
+def kmeans(x, num_clusters, num_iters = 10, use_cosine_sim = False):
     samples = rearrange(x, '... d -> (...) d')
     num_samples, dim, dtype, device = *samples.shape, x.dtype, x.device
 
@@ -27,9 +27,13 @@ def kmeans(x, num_clusters, num_iters = 10):
     means = samples[indices]
 
     for _ in range(num_iters):
-        diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
-        dists = (diffs ** 2).sum(dim = -1)
-        buckets = dists.argmin(dim = -1)
+        if use_cosine_sim:
+            dists = samples @ means.t()
+            buckets = dists.max(dim = -1).indices
+        else:
+            diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
+            dists = (diffs ** 2).sum(dim = -1)
+            buckets = dists.argmin(dim = -1)
 
         bins = torch.bincount(buckets, minlength = num_clusters)
         zero_mask = bins == 0
```
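
The branch added to `kmeans` above is spherical k-means: with unit-norm samples and centroids, the dot product is the cosine similarity, so assignment switches from an argmin over squared distances to an argmax over similarities, and (in the hunk below) the updated means are re-projected onto the unit sphere. A toy, standalone sketch of one such iteration (variable names are illustrative; the library's version also guards against empty clusters):

```python
import torch
import torch.nn.functional as F

samples = F.normalize(torch.randn(1024, 64), dim = -1)  # unit-norm data points
means = F.normalize(torch.randn(16, 64), dim = -1)      # unit-norm centroids

# assignment: cosine similarity reduces to a dot product on the unit sphere
dists = samples @ means.t()
buckets = dists.max(dim = -1).indices

# update: average the members of each cluster, then renormalize back onto the sphere
new_means = torch.zeros_like(means)
new_means.scatter_add_(0, buckets[:, None].expand_as(samples), samples)
counts = torch.bincount(buckets, minlength = means.shape[0]).clamp(min = 1)
new_means = F.normalize(new_means / counts[:, None], dim = -1)
```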
```diff
@@ -38,86 +42,186 @@ def kmeans(x, num_clusters, num_iters = 10):
         new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype)
         new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples)
         new_means = new_means / bins[..., None]
+
+        if use_cosine_sim:
+            new_means = F.normalize(new_means, dim = -1)
+
         means = torch.where(zero_mask[..., None], means, new_means)
 
-    return rearrange(means, 'n d -> d n')
+    return means
 
-class VectorQuantize(nn.Module):
+# distance types
+
+class EuclideanCodebook(nn.Module):
     def __init__(
         self,
         dim,
         codebook_size,
-        decay = 0.8,
-        commitment = 1.,
-        eps = 1e-5,
-        n_embed = None,
         kmeans_init = False,
         kmeans_iters = 10,
-        codebook_dim = None
+        decay = 0.8,
+        eps = 1e-5
     ):
         super().__init__()
-        n_embed = default(n_embed, codebook_size)
-        self.n_embed = n_embed
-
-        codebook_dim = default(codebook_dim, dim)
-        requires_projection = codebook_dim != dim
-        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
-        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
-
         self.decay = decay
-        self.eps = eps
-        self.commitment = commitment
-
         init_fn = torch.randn if not kmeans_init else torch.zeros
-        embed = init_fn(codebook_dim, n_embed)
+        embed = init_fn(codebook_size, dim)
 
+        self.codebook_size = codebook_size
         self.kmeans_iters = kmeans_iters
+        self.eps = eps
+
         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
-        self.register_buffer('cluster_size', torch.zeros(n_embed))
+        self.register_buffer('cluster_size', torch.zeros(codebook_size))
         self.register_buffer('embed', embed)
         self.register_buffer('embed_avg', embed.clone())
 
-    @property
-    def codebook(self):
-        return self.embed.transpose(0, 1)
-
     def init_embed_(self, data):
-        embed = kmeans(data, self.n_embed, self.kmeans_iters)
+        embed = kmeans(data, self.codebook_size, self.kmeans_iters)
         self.embed.data.copy_(embed)
         self.embed_avg.data.copy_(embed.clone())
         self.initted.data.copy_(torch.Tensor([True]))
 
-    def forward(self, input):
-        input = self.project_in(input)
-
+    def forward(self, x):
         if not self.initted:
-            self.init_embed_(input)
+            self.init_embed_(x)
+
+        shape, dtype = x.shape, x.dtype
+        flatten = rearrange(x, '... d -> (...) d')
+        embed = self.embed.t()
 
-        dtype = input.dtype
-        flatten = rearrange(input, '... d -> (...) d')
-        dist = (
+        dist = -(
             flatten.pow(2).sum(1, keepdim=True)
-            - 2 * flatten @ self.embed
-            + self.embed.pow(2).sum(0, keepdim=True)
+            - 2 * flatten @ embed
+            + embed.pow(2).sum(0, keepdim=True)
         )
 
-        _, embed_ind = (-dist).max(1)
-        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
-        embed_ind = embed_ind.view(*input.shape[:-1])
-
-        commit_loss = 0.
-        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
+        embed_ind = dist.max(dim = -1).indices
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(x.dtype)
+        embed_ind = embed_ind.view(*shape[:-1])
+        quantize = F.embedding(embed_ind, self.embed)
 
         if self.training:
             ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
-            embed_sum = flatten.transpose(0, 1) @ embed_onehot
-            ema_inplace(self.embed_avg, embed_sum, self.decay)
-            cluster_size = laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum()
-            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
+            embed_sum = flatten.t() @ embed_onehot
+            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
             self.embed.data.copy_(embed_normalized)
 
-        commit_loss = F.mse_loss(quantize.detach(), input) * self.commitment
-        quantize = input + (quantize - input).detach()
+        return quantize, embed_ind
+
+class CosineSimCodebook(nn.Module):
+    def __init__(
+        self,
+        dim,
+        codebook_size,
+        kmeans_init = False,
+        kmeans_iters = 10,
+        decay = 0.8,
+        eps = 1e-5
+    ):
+        super().__init__()
+        self.decay = decay
+
+        if not kmeans_init:
+            embed = F.normalize(torch.randn(codebook_size, dim), dim = -1)
+        else:
+            embed = torch.zeros(codebook_size, dim)
+
+        self.codebook_size = codebook_size
+        self.kmeans_iters = kmeans_iters
+        self.eps = eps
+
+        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
+        self.register_buffer('embed', embed)
+
+    def init_embed_(self, data):
+        embed = kmeans(data, self.codebook_size, self.kmeans_iters, use_cosine_sim = True)
+        self.embed.data.copy_(embed)
+        self.initted.data.copy_(torch.Tensor([True]))
+
+    def forward(self, x):
+        shape, dtype = x.shape, x.dtype
+        flatten = rearrange(x, '... d -> (...) d')
+        flatten = F.normalize(flatten, dim = -1)
+        embed = F.normalize(self.embed, dim = - 1)
+
+        if not self.initted:
+            self.init_embed_(flatten)
+
+        dist = flatten @ embed.t()
+        embed_ind = dist.max(dim = -1).indices
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind = embed_ind.view(*shape[:-1])
+
+        quantize = F.embedding(embed_ind, self.embed)
+
+        if self.training:
+            bins = embed_onehot.sum(0)
+            zero_mask = (bins == 0)
+            bins = bins.masked_fill(zero_mask, 1.)
+
+            embed_sum = flatten.t() @ embed_onehot
+            embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
+            embed_normalized = F.normalize(embed_normalized, dim = -1)
+            embed_normalized = torch.where(zero_mask[..., None], embed, embed_normalized)
+            ema_inplace(self.embed, embed_normalized, self.decay)
+
+        return quantize, embed_ind
+
+# main class
+
+class VectorQuantize(nn.Module):
+    def __init__(
+        self,
+        dim,
+        codebook_size,
+        n_embed = None,
+        codebook_dim = None,
+        decay = 0.8,
+        commitment = 1.,
+        eps = 1e-5,
+        kmeans_init = False,
+        kmeans_iters = 10,
+        use_cosine_sim = False
+    ):
+        super().__init__()
+        n_embed = default(n_embed, codebook_size)
+
+        codebook_dim = default(codebook_dim, dim)
+        requires_projection = codebook_dim != dim
+        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
+        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
+
+        self.eps = eps
+        self.commitment = commitment
+
+        klass = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
+
+        self._codebook = klass(
+            dim = codebook_dim,
+            codebook_size = n_embed,
+            kmeans_init = kmeans_init,
+            kmeans_iters = kmeans_iters,
+            decay = decay,
+            eps = eps
+        )
+
+    @property
+    def codebook(self):
+        return self._codebook.codebook
+
+    def forward(self, x):
+        dtype = x.dtype
+        x = self.project_in(x)
+
+        quantize, embed_ind = self._codebook(x)
+
+        commit_loss = 0.
+        if self.training:
+            commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
+            quantize = x + (quantize - x).detach()
 
         quantize = self.project_out(quantize)
         return quantize, embed_ind, commit_loss
```
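
As a quick usage check of the refactor: both codebook flavors sit behind the unchanged `VectorQuantize` interface and are selected by the new `use_cosine_sim` flag. A sketch with illustrative shapes and assertions, not part of the library's tests:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

x = torch.randn(1, 1024, 256)

for use_cosine_sim in (False, True):
    vq = VectorQuantize(
        dim = 256,
        codebook_size = 512,
        use_cosine_sim = use_cosine_sim  # False -> EuclideanCodebook, True -> CosineSimCodebook
    )
    quantized, indices, commit_loss = vq(x)

    assert quantized.shape == x.shape  # quantized output keeps the input shape
    assert indices.shape == (1, 1024)  # one code index per input vector
```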
