1
+ from math import ceil
1
2
from functools import partial
2
3
from random import randrange
3
4
8
9
9
10
from einops import rearrange , repeat , pack , unpack
10
11
12
+ def round_up_multiple (num , mult ):
13
+ return ceil (num / mult ) * mult
14
+
15
+ # main class
16
+
11
17
class ResidualVQ (nn .Module ):
12
18
""" Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
13
19
def __init__ (
@@ -18,6 +24,7 @@ def __init__(
18
24
heads = 1 ,
19
25
quantize_dropout = False ,
20
26
quantize_dropout_cutoff_index = 0 ,
27
+ quantize_dropout_multiple_of = 1 ,
21
28
accept_image_fmap = False ,
22
29
** kwargs
23
30
):
@@ -32,7 +39,9 @@ def __init__(
32
39
self .quantize_dropout = quantize_dropout
33
40
34
41
assert quantize_dropout_cutoff_index >= 0
42
+
35
43
self .quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
44
+ self .quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
36
45
37
46
if not shared_codebook :
38
47
return
@@ -92,7 +101,7 @@ def forward(
92
101
x ,
93
102
return_all_codes = False
94
103
):
95
- num_quant , device = self .num_quantizers , x .device
104
+ num_quant , quant_dropout_multiple_of , device = self .num_quantizers , self . quantize_dropout_multiple_of , x .device
96
105
quantized_out = 0.
97
106
residual = x
98
107
@@ -104,6 +113,9 @@ def forward(
104
113
if should_quantize_dropout :
105
114
rand_quantize_dropout_index = randrange (self .quantize_dropout_cutoff_index , num_quant )
106
115
116
+ if quant_dropout_multiple_of != 1 :
117
+ rand_quantize_dropout_index = round_up_multiple (rand_quantize_dropout_index + 1 , quant_dropout_multiple_of ) - 1
118
+
107
119
null_indices_shape = (x .shape [0 ], * x .shape [- 2 :]) if self .accept_image_fmap else tuple (x .shape [:2 ])
108
120
109
121
for quantizer_index , layer in enumerate (self .layers ):
0 commit comments