@@ -38,22 +38,25 @@ class ResidualFSQ(Module):
38
38
def __init__ (
39
39
self ,
40
40
* ,
41
- dim ,
42
41
levels : List [int ],
43
42
num_quantizers ,
43
+ dim = None ,
44
+ is_channel_first = False ,
44
45
quantize_dropout = False ,
45
46
quantize_dropout_cutoff_index = 0 ,
46
47
quantize_dropout_multiple_of = 1 ,
47
48
** kwargs
48
49
):
49
50
super ().__init__ ()
50
51
codebook_dim = len (levels )
52
+ dim = default (dim , codebook_dim )
51
53
52
54
requires_projection = codebook_dim != dim
53
55
self .project_in = nn .Linear (dim , codebook_dim ) if requires_projection else nn .Identity ()
54
56
self .project_out = nn .Linear (codebook_dim , dim ) if requires_projection else nn .Identity ()
55
57
self .has_projections = requires_projection
56
58
59
+ self .is_channel_first = is_channel_first
57
60
self .num_quantizers = num_quantizers
58
61
59
62
self .levels = levels
@@ -143,10 +146,18 @@ def forward(
143
146
):
144
147
num_quant , quant_dropout_multiple_of , device = self .num_quantizers , self .quantize_dropout_multiple_of , x .device
145
148
149
+ # handle channel first - move the channel dim to the last position, then pack all dims between batch and channel into one
150
+
151
+ if self .is_channel_first :
152
+ x = rearrange (x , 'b d ... -> b ... d' )
153
+ x , ps = pack ([x ], 'b * d' )
154
+
155
+ # maybe project in - project_in is an identity unless dim differs from the codebook dim (len(levels))
156
+
146
157
x = self .project_in (x )
147
158
148
159
quantized_out = 0.
149
- residual = first ( self . layers ). bound ( x )
160
+ residual = x
150
161
151
162
all_indices = []
152
163
@@ -175,6 +186,7 @@ def forward(
175
186
continue
176
187
177
188
quantized , indices = layer (residual / scale )
189
+
178
190
quantized = quantized * scale
179
191
180
192
residual = residual - quantized .detach ()
@@ -190,6 +202,17 @@ def forward(
190
202
191
203
all_indices = torch .stack (all_indices , dim = - 1 )
192
204
205
+ # channel first out - unpack the packed dims using the saved shapes, then move the last dim back to position 1
206
+
207
+ if self .is_channel_first :
208
+ quantized_out , = unpack (quantized_out , ps , 'b * d' )
209
+ all_indices , = unpack (all_indices , ps , 'b * d' )
210
+
211
+ quantized_out = rearrange (quantized_out , 'b ... d -> b d ...' )
212
+ all_indices = rearrange (all_indices , 'b ... d -> b d ...' )
213
+
214
+ # return
215
+
193
216
ret = (quantized_out , all_indices )
194
217
195
218
if not return_all_codes :
0 commit comments