Commit d28d851

Add kmeans init of codebook, as proposed in the SoundStream paper; also make sure commitment loss is not calculated on eval
1 parent 43595a8 commit d28d851

3 files changed: 76 additions & 8 deletions

README.md

Lines changed: 20 additions & 0 deletions

@@ -48,6 +48,26 @@ quantized, indices, commit_loss = residual_vq(x)
 # (batch, seq, dim), (quantizer, batch, seq), (quantizer, batch)
 ```
 
+## Initialization
+
+The SoundStream paper proposes that the codebook be initialized with the kmeans centroids of the first batch. You can easily turn this feature on with a single flag, `kmeans_init = True`, for either the `VectorQuantize` or `ResidualVQ` class.
+
+```python
+import torch
+from vector_quantize_pytorch import ResidualVQ
+
+residual_vq = ResidualVQ(
+    dim = 256,
+    codebook_size = 256,
+    num_quantizers = 4,
+    kmeans_init = True,   # set to True
+    kmeans_iters = 10     # number of kmeans iterations to calculate the centroids for the codebook on init
+)
+
+x = torch.randn(1, 1024, 256)
+quantized, indices, commit_loss = residual_vq(x)
+```
+
 ## Citations
 
 ```bibtex
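
The README example covers `ResidualVQ` only, while the prose notes the flag also applies to the bare `VectorQuantize` class. Below is a minimal sketch of that usage, restricted to constructor arguments that appear in the `vector_quantize_pytorch.py` diff further down, and assuming `VectorQuantize` is importable from the package top level like `ResidualVQ`:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    kmeans_init = True,   # codebook starts at zeros and is filled with kmeans centroids on the first forward pass
    kmeans_iters = 10     # iterations used by the kmeans helper when computing those centroids
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)   # (1, 1024, 256), (1, 1024), scalar commitment loss
```

As the diff shows, the initialization is lazy: with `kmeans_init = True` the `initted` buffer starts out false, and the codebook is only seeded once the first batch passes through `forward`.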

setup.py

Lines changed: 2 additions & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.2.2',
+  version = '0.3.0',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',
@@ -16,6 +16,7 @@
     'quantization'
   ],
   install_requires=[
+    'einops',
     'torch'
   ],
   classifiers=[

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 54 additions & 7 deletions

@@ -1,6 +1,7 @@
 import torch
-from torch import nn
+from torch import nn, einsum
 import torch.nn.functional as F
+from einops import rearrange, repeat
 
 def exists(val):
     return val is not None
@@ -11,9 +12,36 @@ def default(val, d):
 def ema_inplace(moving_avg, new, decay):
     moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
 
-def laplace_smoothing(x, n_categories, eps=1e-5):
+def laplace_smoothing(x, n_categories, eps = 1e-5):
     return (x + eps) / (x.sum() + n_categories * eps)
 
+def kmeans(x, num_clusters, num_iters = 10):
+    samples = rearrange(x, '... d -> (...) d')
+    num_samples, dim, dtype, device = *samples.shape, x.dtype, x.device
+
+    if num_samples >= num_clusters:
+        indices = torch.randperm(num_samples, device = device)[:num_clusters]
+    else:
+        indices = torch.randint(0, num_samples, (num_clusters,), device = device)
+
+    means = samples[indices]
+
+    for _ in range(num_iters):
+        diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
+        dists = (diffs ** 2).sum(dim = -1)
+        buckets = dists.argmin(dim = -1)
+
+        bins = torch.bincount(buckets, minlength = num_clusters)
+        zero_mask = bins == 0
+        bins = bins.masked_fill(zero_mask, 1)
+
+        new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype)
+        new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples)
+        new_means = new_means / bins[..., None]
+        means = torch.where(zero_mask[..., None], means, new_means)
+
+    return rearrange(means, 'n d -> d n')
+
 class VectorQuantize(nn.Module):
     def __init__(
         self,
@@ -23,6 +51,8 @@ def __init__(
         commitment = 1.,
         eps = 1e-5,
         n_embed = None,
+        kmeans_init = False,
+        kmeans_iters = 10
     ):
         super().__init__()
         n_embed = default(n_embed, codebook_size)
@@ -33,26 +63,42 @@ def __init__(
         self.eps = eps
         self.commitment = commitment
 
-        embed = torch.randn(dim, n_embed)
-        self.register_buffer('embed', embed)
+        init_fn = torch.randn if not kmeans_init else torch.zeros
+        embed = init_fn(dim, n_embed)
+
+        self.kmeans_iters = kmeans_iters
+        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('cluster_size', torch.zeros(n_embed))
+        self.register_buffer('embed', embed)
         self.register_buffer('embed_avg', embed.clone())
 
     @property
     def codebook(self):
         return self.embed.transpose(0, 1)
 
+    def init_embed_(self, data):
+        embed = kmeans(data, self.n_embed, self.kmeans_iters)
+        self.embed.data.copy_(embed)
+        self.embed_avg.data.copy_(embed.clone())
+        self.initted.data.copy_(torch.Tensor([True]))
+
     def forward(self, input):
+        if not self.initted:
+            self.init_embed_(input)
+
         dtype = input.dtype
         flatten = input.reshape(-1, self.dim)
         dist = (
             flatten.pow(2).sum(1, keepdim=True)
             - 2 * flatten @ self.embed
             + self.embed.pow(2).sum(0, keepdim=True)
         )
+
         _, embed_ind = (-dist).max(1)
         embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
         embed_ind = embed_ind.view(*input.shape[:-1])
+
+        commit_loss = 0.
         quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
 
         if self.training:
@@ -63,6 +109,7 @@ def forward(self, input):
             embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
             self.embed.data.copy_(embed_normalized)
 
-        loss = F.mse_loss(quantize.detach(), input) * self.commitment
-        quantize = input + (quantize - input).detach()
-        return quantize, embed_ind, loss
+            commit_loss = F.mse_loss(quantize.detach(), input) * self.commitment
+            quantize = input + (quantize - input).detach()
+
+        return quantize, embed_ind, commit_loss
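
The other half of the commit, not computing the commitment loss on eval, comes from pre-setting `commit_loss = 0.` and moving the `F.mse_loss` term and the straight-through pass inside the `if self.training:` branch. A short sketch of the resulting behavior under the standard PyTorch `train()` / `eval()` switches, using only the API shown in this commit:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512, kmeans_init = True)
x = torch.randn(1, 1024, 256)

vq.train()
quantized, indices, commit_loss = vq(x)
# training mode: the first forward pass seeds the codebook with kmeans centroids,
# EMA codebook updates run, and commit_loss is the scaled MSE between input and codes

vq.eval()
quantized, indices, commit_loss = vq(x)
# eval mode: the branch under `if self.training:` is skipped, so no codebook updates
# happen and commit_loss comes back as the plain float 0.
```

Note that the straight-through estimator line is also inside the training branch now, so in eval mode `quantized` is simply the raw codebook lookup.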
