@@ -28,7 +28,9 @@ class ResidualVQ(nn.Module):
28
28
def __init__ (
29
29
self ,
30
30
* ,
31
+ dim ,
31
32
num_quantizers ,
33
+ codebook_dim = None ,
32
34
shared_codebook = False ,
33
35
heads = 1 ,
34
36
quantize_dropout = False ,
@@ -39,11 +41,17 @@ def __init__(
39
41
):
40
42
super ().__init__ ()
41
43
assert heads == 1 , 'residual vq is not compatible with multi-headed codes'
44
+ codebook_dim = default (codebook_dim , dim )
45
+ codebook_input_dim = codebook_dim * heads
46
+
47
+ requires_projection = codebook_input_dim != dim
48
+ self .project_in = nn .Linear (dim , codebook_input_dim ) if requires_projection else nn .Identity ()
49
+ self .project_out = nn .Linear (codebook_input_dim , dim ) if requires_projection else nn .Identity ()
42
50
43
51
self .num_quantizers = num_quantizers
44
52
45
53
self .accept_image_fmap = accept_image_fmap
46
- self .layers = nn .ModuleList ([VectorQuantize (accept_image_fmap = accept_image_fmap , ** kwargs ) for _ in range (num_quantizers )])
54
+ self .layers = nn .ModuleList ([VectorQuantize (dim = codebook_dim , codebook_dim = codebook_dim , accept_image_fmap = accept_image_fmap , ** kwargs ) for _ in range (num_quantizers )])
47
55
48
56
self .quantize_dropout = quantize_dropout and num_quantizers > 1
49
57
@@ -114,6 +122,8 @@ def forward(
114
122
):
115
123
num_quant , quant_dropout_multiple_of , return_loss , device = self .num_quantizers , self .quantize_dropout_multiple_of , exists (indices ), x .device
116
124
125
+ x = self .project_in (x )
126
+
117
127
assert not (self .accept_image_fmap and exists (indices ))
118
128
119
129
quantized_out = 0.
@@ -169,6 +179,10 @@ def forward(
169
179
all_indices .append (embed_indices )
170
180
all_losses .append (loss )
171
181
182
+ # project out, if needed
183
+
184
+ quantized_out = self .project_out (quantized_out )
185
+
172
186
# whether to early return the cross entropy loss
173
187
174
188
if return_loss :
0 commit comments