Commit b1f5d8e

add code expiry / replacement strategy from soundstream paper
1 parent 9ad29ef commit b1f5d8e

3 files changed: +73 -17 lines changed

README.md

Lines changed: 19 additions & 1 deletion
@@ -90,7 +90,7 @@ x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x)
 ```
 
-### Cosine Similarity
+### Cosine similarity
 
 The <a href="https://openreview.net/forum?id=pfNyExj7z2">Improved VQGAN paper</a> also proposes to l2 normalize the codes and the encoded vectors, which boils down to using cosine similarity for the distance. They claim enforcing the vectors on a sphere leads to improvements in code usage and downstream reconstruction. You can turn this on by setting `use_cosine_sim = True`.
 
@@ -108,6 +108,24 @@ x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x)
 ```
 
+### Expiring stale codes
+
+Finally, the SoundStream paper has a scheme where they replace codes that have not been used in a certain number of consecutive batches with a randomly selected vector from the current batch. You can set this threshold for consecutive misses before replacement with the `max_codebook_misses_before_expiry` keyword. (I know it is a bit long, but I couldn't think of a better name)
+
+```python
+import torch
+from vector_quantize_pytorch import VectorQuantize
+
+vq = VectorQuantize(
+    dim = 256,
+    codebook_size = 512,
+    max_codebook_misses_before_expiry = 5   # should actively replace any codes that were missed 5 times in a row during training
+)
+
+x = torch.randn(1, 1024, 256)
+quantized, indices, commit_loss = vq(x)
+```
+
 ## Citations
 
 ```bibtex
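
One detail worth noting from the diff to `vector_quantize_pytorch.py` below: `expire_codes_` is only called when the module is in training mode, so stale codes are replaced as a side effect of the forward pass during training, and no expiry happens at eval time. A rough usage sketch built on the README example above (the random inputs are a stand-in for real encoder outputs):

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    max_codebook_misses_before_expiry = 5
)

vq.train()  # expiry bookkeeping only runs in training mode

for _ in range(10):
    x = torch.randn(1, 1024, 256)            # stand-in for encoder output
    quantized, indices, commit_loss = vq(x)  # forward pass updates the per-code miss counts
    # any code whose miss count reaches 5 is overwritten with a vector from this batch

vq.eval()  # in eval mode expire_codes_ is not called and commit_loss is returned as 0.
```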

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.3.2',
+  version = '0.3.3',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 53 additions & 15 deletions
@@ -18,26 +18,29 @@ def ema_inplace(moving_avg, new, decay):
 def laplace_smoothing(x, n_categories, eps = 1e-5):
     return (x + eps) / (x.sum() + n_categories * eps)
 
-def kmeans(x, num_clusters, num_iters = 10, use_cosine_sim = False):
-    samples = rearrange(x, '... d -> (...) d')
-    num_samples, dim, dtype, device = *samples.shape, x.dtype, x.device
+def sample_vectors(samples, num):
+    num_samples, device = samples.shape[0], samples.device
 
-    if num_samples >= num_clusters:
-        indices = torch.randperm(num_samples, device=device)[:num_clusters]
+    if num_samples >= num:
+        indices = torch.randperm(num_samples, device = device)[:num]
     else:
-        indices = torch.randint(0, num_samples, (num_clusters,), device=device)
+        indices = torch.randint(0, num_samples, (num,), device = device)
 
-    means = samples[indices]
+    return samples[indices]
+
+def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
+    dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
+
+    means = sample_vectors(samples, num_clusters)
 
     for _ in range(num_iters):
         if use_cosine_sim:
             dists = samples @ means.t()
-            buckets = dists.max(dim = -1).indices
         else:
             diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
-            dists = (diffs ** 2).sum(dim = -1)
-            buckets = dists.argmin(dim = -1)
+            dists = -(diffs ** 2).sum(dim = -1)
 
+        buckets = dists.max(dim = -1).indices
         bins = torch.bincount(buckets, minlength = num_clusters)
         zero_mask = bins == 0
         bins = bins.masked_fill(zero_mask, 1)
@@ -85,14 +88,18 @@ def init_embed_(self, data):
         self.embed_avg.data.copy_(embed.clone())
         self.initted.data.copy_(torch.Tensor([True]))
 
-    def forward(self, x):
-        if not self.initted:
-            self.init_embed_(x)
+    def replace(self, samples, mask):
+        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
+        self.embed.data.copy_(modified_codebook)
 
+    def forward(self, x):
         shape, dtype = x.shape, x.dtype
         flatten = rearrange(x, '... d -> (...) d')
         embed = self.embed.t()
 
+        if not self.initted:
+            self.init_embed_(flatten)
+
         dist = -(
             flatten.pow(2).sum(1, keepdim=True)
             - 2 * flatten @ embed
@@ -144,15 +151,20 @@ def init_embed_(self, data):
         self.embed.data.copy_(embed)
         self.initted.data.copy_(torch.Tensor([True]))
 
+    def replace(self, samples, mask):
+        samples = l2norm(samples)
+        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
+        self.embed.data.copy_(modified_codebook)
+
     def forward(self, x):
         shape, dtype = x.shape, x.dtype
         flatten = rearrange(x, '... d -> (...) d')
         flatten = l2norm(flatten)
-        embed = l2norm(self.embed)
 
         if not self.initted:
             self.init_embed_(flatten)
 
+        embed = l2norm(self.embed)
         dist = flatten @ embed.t()
         embed_ind = dist.max(dim = -1).indices
         embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
@@ -187,7 +199,8 @@ def __init__(
         eps = 1e-5,
         kmeans_init = False,
         kmeans_iters = 10,
-        use_cosine_sim = False
+        use_cosine_sim = False,
+        max_codebook_misses_before_expiry = 0
     ):
         super().__init__()
         n_embed = default(n_embed, codebook_size)
@@ -211,20 +224,45 @@ def __init__(
             eps = eps
         )
 
+        self.codebook_size = codebook_size
+        self.max_codebook_misses_before_expiry = max_codebook_misses_before_expiry
+
+        if max_codebook_misses_before_expiry > 0:
+            codebook_misses = torch.zeros(codebook_size)
+            self.register_buffer('codebook_misses', codebook_misses)
+
     @property
     def codebook(self):
         return self._codebook.codebook
 
+    def expire_codes_(self, embed_ind, batch_samples):
+        if self.max_codebook_misses_before_expiry == 0:
+            return
+
+        embed_ind = rearrange(embed_ind, '... -> (...)')
+        misses = torch.bincount(embed_ind, minlength = self.codebook_size) == 0
+        self.codebook_misses += misses
+
+        expired_codes = self.codebook_misses >= self.max_codebook_misses_before_expiry
+        if not torch.any(expired_codes):
+            return
+
+        self.codebook_misses.masked_fill_(expired_codes, 0)
+        batch_samples = rearrange(batch_samples, '... d -> (...) d')
+        self._codebook.replace(batch_samples, mask = expired_codes)
+
     def forward(self, x):
         dtype = x.dtype
         x = self.project_in(x)
 
         quantize, embed_ind = self._codebook(x)
 
         commit_loss = 0.
+
         if self.training:
             commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
             quantize = x + (quantize - x).detach()
+            self.expire_codes_(embed_ind, x)
 
         quantize = self.project_out(quantize)
         return quantize, embed_ind, commit_loss
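
To make the bookkeeping easier to follow outside the class, here is a minimal standalone sketch of the expiry strategy the diff above implements: keep a per-code miss counter, bump it for every code that goes unused in a batch, and once a counter reaches the threshold, reset it and overwrite that code with a vector sampled from the current batch. The function and variable names here (`track_and_expire`, `misses`, `batch_vectors`) are illustrative, not part of the library's API.

```python
import torch

def track_and_expire(codebook, misses, embed_ind, batch_vectors, threshold):
    # codebook:      (codebook_size, dim) code vectors
    # misses:        (codebook_size,) running count of batches in which each code went unused
    # embed_ind:     code indices selected for the current batch (any shape)
    # batch_vectors: (num_vectors, dim) flattened encoder outputs for the current batch
    codebook_size = codebook.shape[0]

    # a code is "missed" this batch if it was never selected
    used = torch.bincount(embed_ind.flatten(), minlength = codebook_size) > 0
    misses = misses + (~used).float()

    # codes whose miss count reached the threshold are replaced and their counters reset
    expired = misses >= threshold
    if expired.any():
        misses = misses.masked_fill(expired, 0)
        rand_idx = torch.randint(0, batch_vectors.shape[0], (codebook_size,))
        codebook = torch.where(expired[:, None], batch_vectors[rand_idx], codebook)

    return codebook, misses

# toy run: a 4-code codebook where codes 2 and 3 are never used
codebook = torch.randn(4, 8)
misses = torch.zeros(4)
for _ in range(5):
    embed_ind = torch.tensor([0, 1, 0, 1])   # only codes 0 and 1 get picked
    batch_vectors = torch.randn(16, 8)
    codebook, misses = track_and_expire(codebook, misses, embed_ind, batch_vectors, threshold = 5)
# after 5 batches, codes 2 and 3 have been overwritten with vectors from the last batch
```

Sampling replacements from the current batch keeps revived codes close to the distribution the encoder is currently producing, rather than reinitializing them at random.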
