def exists(val):
    return val is not None

+ def first(it):
+     return it[0]
+
def default(val, d):
    return val if exists(val) else d
@@ -34,6 +37,64 @@ def round_up_multiple(num, mult):
def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

+ # the MLP for generating the implicit neural codebook
+ # from Huijben et al. https://arxiv.org/abs/2401.14732
+
+ class MLP(Module):
+     def __init__(
+         self,
+         dim,
+         dim_hidden = None,
+         depth = 4,             # they used 4 layers in the paper
+         l2norm_output = False
+     ):
+         super().__init__()
+         dim_hidden = default(dim_hidden, dim)
+
+         self.proj_in = nn.Linear(2 * dim, dim)
+
+         layers = ModuleList([])
+
+         for _ in range(depth):
+             layers.append(nn.Sequential(
+                 nn.Linear(dim, dim_hidden),
+                 nn.SiLU(),
+                 nn.Linear(dim_hidden, dim)
+             ))
+
+         self.layers = layers
+         self.l2norm_output = l2norm_output
+
+     def forward(
+         self,
+         codes,
+         *,
+         condition
+     ):
+         one_headed = codes.ndim == 2
+
+         if one_headed:
+             codes = rearrange(codes, 'c d -> 1 c d')
+
+         heads, num_codes, batch, seq_len = codes.shape[0], codes.shape[-2], condition.shape[0], condition.shape[-2]
+
+         codes = repeat(codes, 'h c d -> h b n c d', n = seq_len, b = batch)
+         condition = repeat(condition, 'b n d -> h b n c d', c = num_codes, h = heads)
+
+         x = torch.cat((condition, codes), dim = -1)
+         x = self.proj_in(x)
+
+         for layer in self.layers:
+             x = layer(x) + x
+
+         if self.l2norm_output:
+             x = F.normalize(x, dim = -1)
+
+         if not one_headed:
+             return x
+
+         return rearrange(x, '1 ... -> ...')
+
# main class

class ResidualVQ(Module):
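To make the shapes concrete, here is a minimal standalone sketch of calling the MLP above; all sizes are made-up examples, not part of the diff. Each position in the conditioning sequence gets its own transformed copy of the codebook:

    import torch

    mlp = MLP(dim = 256, depth = 4)

    codes = torch.randn(512, 256)           # single codebook: (num_codes, dim)
    condition = torch.randn(2, 1024, 256)   # running quantized output: (batch, seq, dim)

    out = mlp(codes, condition = condition) # one codebook per position
    assert out.shape == (2, 1024, 512, 256) # (batch, seq, num_codes, dim)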
@@ -50,7 +111,9 @@ def __init__(
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_multiple_of = 1,
        accept_image_fmap = False,
-         **kwargs
+         implicit_neural_codebook = False,   # QINCo from https://arxiv.org/abs/2401.14732
+         mlp_kwargs: dict = dict(),
+         **vq_kwargs
    ):
        super().__init__()
        assert heads == 1, 'residual vq is not compatible with multi-headed codes'
@@ -65,7 +128,16 @@ def __init__(
        self.num_quantizers = num_quantizers

        self.accept_image_fmap = accept_image_fmap
-         self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)])
+
+         self.implicit_neural_codebook = implicit_neural_codebook
+
+         if implicit_neural_codebook:
+             vq_kwargs.update(
+                 learnable_codebook = True,
+                 ema_update = False
+             )
+
+         self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for _ in range(num_quantizers)])

        assert all([not vq.has_projections for vq in self.layers])
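Forcing `learnable_codebook = True` and `ema_update = False` follows from the QINCo design: every codebook past the first is produced by an MLP conditioned on the running output, so the base codes must be trained by gradient descent rather than by exponential moving average. A hypothetical construction, with all argument values chosen only for illustration:

    residual_vq = ResidualVQ(
        dim = 256,
        num_quantizers = 8,
        codebook_size = 1024,              # forwarded to each VectorQuantize via **vq_kwargs
        implicit_neural_codebook = True,   # turns on the QINCo-style MLPs
        mlp_kwargs = dict(depth = 4)       # forwarded to each MLP
    )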
@@ -76,6 +148,12 @@ def __init__(
        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4

+         # set up the MLPs for the implicit neural codebooks
+
+         self.mlps = ModuleList([MLP(dim = codebook_dim, l2norm_output = first(self.layers).use_cosine_sim, **mlp_kwargs) for _ in range(num_quantizers - 1)])
+
+         # shared codebook logic
+
        if not shared_codebook:
            return
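Only `num_quantizers - 1` MLPs are created because the first quantizer has no previous output to condition on. The intended chain, sketched below with illustrative names that do not appear in the diff:

    # quantizer 0 reads its raw codebook
    #     out_0 = codebook_0[index_0]
    # every later quantizer i >= 1 reads a codebook predicted from the sum so far
    #     codebook_i' = mlps[i - 1](codebook_i, condition = out_{i-1})
    #     out_i       = out_{i-1} + codebook_i'[index_i]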
@@ -120,7 +198,31 @@ def get_codes_from_indices(self, indices):
        mask = indices == -1.
        indices = indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later

-         all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)
+         if not self.implicit_neural_codebook:
+             # gather all the codes
+
+             all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)
+
+         else:
+             # if using the implicit neural codebook, the codes must be derived layer by layer
+
+             code_transform_mlps = (None, *self.mlps)
+
+             all_codes = []
+             quantized_out = 0.
+
+             for codebook, layer_indices, maybe_transform_mlp in zip(self.codebooks, indices.unbind(dim = -1), code_transform_mlps):
+
+                 if exists(maybe_transform_mlp):
+                     codebook = maybe_transform_mlp(codebook, condition = quantized_out)
+                     layer_codes = get_at('b n [c] d, b n -> b n d', codebook, layer_indices)
+                 else:
+                     layer_codes = get_at('[c] d, b n -> b n d', codebook, layer_indices)
+
+                 all_codes.append(layer_codes)
+                 quantized_out += layer_codes
+
+             all_codes = torch.stack(all_codes)

        # mask out any codes that were dropout-ed
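A hedged round-trip sketch of this method, reusing the hypothetical `residual_vq` from above: since each layer's code is rederived with the same conditioning chain used during quantization, the per-quantizer codes returned here should sum back to the quantized output.

    x = torch.randn(2, 1024, 256)
    quantized, indices, commit_loss = residual_vq(x)      # indices: (2, 1024, 8)

    codes = residual_vq.get_codes_from_indices(indices)   # (8, 2, 1024, 256)
    assert torch.allclose(codes.sum(dim = 0), quantized, atol = 1e-5)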
@@ -195,9 +297,16 @@ def forward(
        null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
        null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)

+         # set up the MLPs for the implicit neural codebook
+
+         maybe_code_transforms = (None,) * len(self.layers)
+
+         if self.implicit_neural_codebook:
+             maybe_code_transforms = (None, *self.mlps)
+
        # go through the layers

-         for quantizer_index, layer in enumerate(self.layers):
+         for quantizer_index, (vq, maybe_mlp) in enumerate(zip(self.layers, maybe_code_transforms)):

            if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
@@ -208,12 +317,20 @@ def forward(
            if return_loss:
                layer_indices = indices[..., quantizer_index]

-             quantized, *rest = layer(
+             # set up the code transform function to be passed into the VectorQuantize forward
+
+             if exists(maybe_mlp):
+                 maybe_mlp = partial(maybe_mlp, condition = quantized_out)
+
+             # vector quantize forward
+
+             quantized, *rest = vq(
                residual,
                mask = mask,
                indices = layer_indices,
                sample_codebook_temp = sample_codebook_temp,
-                 freeze_codebook = freeze_codebook
+                 freeze_codebook = freeze_codebook,
+                 codebook_transform_fn = maybe_mlp
            )

            residual = residual - quantized.detach()
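Putting it together, a hypothetical training step through this forward path; the loss composition is an example, not from the diff. With `implicit_neural_codebook = True`, gradients reach both the base codebooks (`learnable_codebook = True`) and the per-layer MLPs through `codebook_transform_fn`:

    import torch.nn.functional as F

    x = torch.randn(2, 1024, 256)
    quantized, indices, commit_loss = residual_vq(x)

    loss = F.mse_loss(quantized, x) + commit_loss.sum()
    loss.backward()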