Skip to content

Commit 71b259e

Browse files
committed
add the structured quantized dropout from encodec paper
1 parent 9ecc12a commit 71b259e

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,13 @@ if __name__ == '__main__':
320320
title = {Autoregressive Image Generation using Residual Quantization}
321321
}
322322
```
323+
324+
```bibtex
325+
@article{Defossez2022HighFN,
326+
title = {High Fidelity Neural Audio Compression},
327+
title = {High Fidelity Neural Audio Compression},
328+
author = {Alexandre D{\'e}fossez and Jade Copet and Gabriel Synnaeve and Yossi Adi},
329+
year = {2022},
330+
volume = {abs/2210.13438}
331+
}
332+
```

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.0.4',
6+
version = '1.0.6',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/residual_vq.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from math import ceil
12
from functools import partial
23
from random import randrange
34

@@ -8,6 +9,11 @@
89

910
from einops import rearrange, repeat, pack, unpack
1011

12+
def round_up_multiple(num, mult):
    """Round *num* up to the nearest multiple of *mult*.

    Uses exact integer ceiling division ``-(-num // mult)`` instead of
    ``ceil(num / mult)``: true division converts to float, which loses
    precision for integers larger than 2**53 and can silently return the
    wrong multiple. The integer form is exact for all int inputs and
    behaves identically for the small values used here.
    """
    return -(-num // mult) * mult
14+
15+
# main class
16+
1117
class ResidualVQ(nn.Module):
1218
""" Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
1319
def __init__(
@@ -18,6 +24,7 @@ def __init__(
1824
heads = 1,
1925
quantize_dropout = False,
2026
quantize_dropout_cutoff_index = 0,
27+
quantize_dropout_multiple_of = 1,
2128
accept_image_fmap = False,
2229
**kwargs
2330
):
@@ -32,7 +39,9 @@ def __init__(
3239
self.quantize_dropout = quantize_dropout
3340

3441
assert quantize_dropout_cutoff_index >= 0
42+
3543
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
44+
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
3645

3746
if not shared_codebook:
3847
return
@@ -92,7 +101,7 @@ def forward(
92101
x,
93102
return_all_codes = False
94103
):
95-
num_quant, device = self.num_quantizers, x.device
104+
num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device
96105
quantized_out = 0.
97106
residual = x
98107

@@ -104,6 +113,9 @@ def forward(
104113
if should_quantize_dropout:
105114
rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)
106115

116+
if quant_dropout_multiple_of != 1:
117+
rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1
118+
107119
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
108120

109121
for quantizer_index, layer in enumerate(self.layers):

0 commit comments

Comments
 (0)