
Commit 284a671

add residual vector quantization, from soundstream paper from google

1 parent 2c064ce

File tree

5 files changed: +104 -12 lines


README.md

Lines changed: 46 additions & 1 deletion
@@ -1,4 +1,4 @@
-## Vector Quantization, in Pytorch
+## Vector Quantization - Pytorch
 
 A vector quantization library originally transcribed from Deepmind's tensorflow implementation, made conveniently into a package. It uses exponential moving averages to update the dictionary.
 
@@ -26,3 +26,48 @@ vq = VectorQuantize(
 x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)
 ```
+
+## Variants
+
+This <a href="https://arxiv.org/abs/2107.03312">paper</a> proposes to use multiple vector quantizers to recursively quantize the residuals of the waveform. You can use this with the `ResidualVQ` class and one extra initialization parameter.
+
+```python
+import torch
+from vector_quantize_pytorch import ResidualVQ
+
+residual_vq = ResidualVQ(
+    dim = 256,
+    num_quantizers = 8,      # specify number of quantizers
+    n_embed = 1024,          # codebook size
+)
+
+x = torch.randn(1, 1024, 256)
+quantized, indices, commit_loss = residual_vq(x)
+
+# (1, 1024, 256), (8, 1, 1024), (8, 1)
+# (batch, seq, dim), (quantizer, batch, seq), (quantizer, batch)
+```
+
+## Citations
+
+```bibtex
+@misc{oord2018neural,
+    title   = {Neural Discrete Representation Learning},
+    author  = {Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
+    year    = {2018},
+    eprint  = {1711.00937},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG}
+}
+```
+
+```bibtex
+@misc{zeghidour2021soundstream,
+    title   = {SoundStream: An End-to-End Neural Audio Codec},
+    author  = {Neil Zeghidour and Alejandro Luebs and Ahmed Omran and Jan Skoglund and Marco Tagliasacchi},
+    year    = {2021},
+    eprint  = {2107.03312},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.SD}
+}
+```
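
The README above says the dictionary is updated with exponential moving averages rather than gradient descent. Below is a rough sketch of how such an EMA codebook update typically works, reusing the `laplace_smoothing` helper visible in the `vector_quantize_pytorch.py` diff further down. The function name and tensor layout here are illustrative assumptions, not necessarily this library's exact internals:

```python
import torch
import torch.nn.functional as F

def laplace_smoothing(x, n_categories, eps = 1e-5):
    # same helper as shown in vector_quantize_pytorch.py
    return (x + eps) / (x.sum() + n_categories * eps)

def ema_codebook_update(embed, embed_avg, cluster_size, flatten, embed_ind, decay = 0.8, eps = 1e-5):
    # embed, embed_avg: (dim, n_embed) - codebook and its running numerator
    # cluster_size: (n_embed,) - running count of assignments per code
    # flatten: (n, dim) encoder outputs; embed_ind: (n,) nearest-code indices
    n_embed = embed.shape[1]
    onehot = F.one_hot(embed_ind, n_embed).type(flatten.dtype)  # (n, n_embed)

    # EMA of how often each code was selected this step
    cluster_size.mul_(decay).add_(onehot.sum(0), alpha = 1 - decay)

    # EMA of the sum of vectors assigned to each code
    embed_sum = flatten.t() @ onehot                            # (dim, n_embed)
    embed_avg.mul_(decay).add_(embed_sum, alpha = 1 - decay)

    # laplace smoothing keeps rarely used codes from dividing by zero
    n = cluster_size.sum()
    smoothed_size = laplace_smoothing(cluster_size, n_embed, eps) * n
    embed.copy_(embed_avg / smoothed_size.unsqueeze(0))
```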

setup.py

Lines changed: 15 additions & 10 deletions
@@ -3,21 +3,26 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.1.0',
+  version = '0.2.0',
   license='MIT',
-  description = 'Simple Vector Quantization, in Pytorch',
+  description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',
   author_email = 'lucidrains@gmail.com',
   url = 'https://github.com/lucidrains/vector-quantizer-pytorch',
-  keywords = ['artificial intelligence', 'deep learning', 'pytorch'],
+  keywords = [
+    'artificial intelligence',
+    'deep learning',
+    'pytorch',
+    'quantization'
+  ],
   install_requires=[
-    'torch'
+    'torch'
   ],
   classifiers=[
-    'Development Status :: 4 - Beta',
-    'Intended Audience :: Developers',
-    'Topic :: Scientific/Engineering :: Artificial Intelligence',
-    'License :: OSI Approved :: MIT License',
-    'Programming Language :: Python :: 3.6',
+    'Development Status :: 4 - Beta',
+    'Intended Audience :: Developers',
+    'Topic :: Scientific/Engineering :: Artificial Intelligence',
+    'License :: OSI Approved :: MIT License',
+    'Programming Language :: Python :: 3.6',
   ],
-)
+)

vector_quantize_pytorch/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1 +1,3 @@
 from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
+from vector_quantize_pytorch.residual_vq import ResidualVQ
+

vector_quantize_pytorch/residual_vq.py
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import torch
+from torch import nn
+from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
+
+class ResidualVQ(nn.Module):
+    """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
+    def __init__(
+        self,
+        *,
+        num_quantizers,
+        n_embed,
+        **kwargs
+    ):
+        super().__init__()
+        self.layers = nn.ModuleList([VectorQuantize(n_embed = n_embed, **kwargs) for _ in range(num_quantizers)])
+
+    def forward(self, x):
+        quantized_out = 0.
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        for layer in self.layers:
+            quantized, indices, loss = layer(residual)
+            residual = residual - quantized
+            quantized_out = quantized_out + quantized
+
+            all_indices.append(indices)
+            all_losses.append(loss)
+
+        all_losses, all_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, all_indices, all_losses
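
A note on the loop above: after every stage, `quantized_out + residual` still equals the original `x`, so each quantizer only has to model what the previous stages left unexplained. Here is a minimal toy sketch of that same recursion, with round-to-grid standing in for `VectorQuantize` (purely illustrative, not the library's API):

```python
import torch

def toy_quantize(x, step):
    # stand-in quantizer: snap values to a grid of the given step size
    return torch.round(x / step) * step

x = torch.randn(4)
quantized_out = torch.zeros_like(x)
residual = x

# finer grids at each stage, mimicking successive quantizers;
# the running sum approximates x better and better
for step in (1.0, 0.5, 0.25):
    quantized = toy_quantize(residual, step)
    residual = residual - quantized
    quantized_out = quantized_out + quantized
    print(f'step {step}: max error {(x - quantized_out).abs().max().item():.4f}')
```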

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 8 additions & 1 deletion
@@ -9,7 +9,14 @@ def laplace_smoothing(x, n_categories, eps=1e-5):
     return (x + eps) / (x.sum() + n_categories * eps)
 
 class VectorQuantize(nn.Module):
-    def __init__(self, dim, n_embed, decay=0.8, commitment=1., eps=1e-5):
+    def __init__(
+        self,
+        dim,
+        n_embed,
+        decay = 0.8,
+        commitment = 1.,
+        eps = 1e-5
+    ):
         super().__init__()
 
         self.dim = dim
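
For context on the `commitment` hyperparameter in the signature above: in EMA-based vector quantization it conventionally scales the commitment loss of van den Oord et al., applied together with a straight-through estimator. A hedged sketch of that standard pattern follows; the helper name is hypothetical, and this is the textbook formulation rather than necessarily a line-for-line match for this file:

```python
import torch
import torch.nn.functional as F

def quantize_with_commitment(x, quantize, commitment = 1.):
    # commitment loss pulls encoder outputs toward their chosen codes;
    # the codebook side gets no gradient here, since it is maintained
    # by the EMA update instead
    commit_loss = F.mse_loss(quantize.detach(), x) * commitment

    # straight-through estimator: the forward pass uses the quantized
    # values, the backward pass copies gradients to x unchanged
    quantize = x + (quantize - x).detach()
    return quantize, commit_loss
```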
