@@ -54,7 +54,9 @@ def __init__(
54
54
codebook_size = None ,
55
55
entropy_loss_weight = 0.1 ,
56
56
diversity_gamma = 2.5 ,
57
- straight_through_activation = nn .Tanh ()
57
+ straight_through_activation = nn .Tanh (),
58
+ num_codebooks = 1 ,
59
+ keep_num_codebooks_dim = None
58
60
):
59
61
super ().__init__ ()
60
62
@@ -66,13 +68,19 @@ def __init__(
66
68
codebook_size = default (codebook_size , lambda : 2 ** dim )
67
69
codebook_dim = int (log2 (codebook_size ))
68
70
69
- dim = default (dim , codebook_dim )
71
+ codebook_dims = codebook_dim * num_codebooks
72
+ dim = default (dim , codebook_dims )
70
73
71
- self .project_in = nn .Linear (dim , codebook_dim ) if dim != codebook_dim else nn .Identity ()
72
- self .project_out = nn .Linear (codebook_dim , dim ) if dim != codebook_dim else nn .Identity ()
74
+ self .project_in = nn .Linear (dim , codebook_dims ) if dim != codebook_dims else nn .Identity ()
75
+ self .project_out = nn .Linear (codebook_dims , dim ) if dim != codebook_dims else nn .Identity ()
73
76
74
77
self .dim = dim
75
78
self .codebook_dim = codebook_dim
79
+ self .num_codebooks = num_codebooks
80
+
81
+ keep_num_codebooks_dim = default (keep_num_codebooks_dim , num_codebooks > 1 )
82
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim )
83
+ self .keep_num_codebooks_dim = keep_num_codebooks_dim
76
84
77
85
# straight through activation
78
86
@@ -95,11 +103,16 @@ def indices_to_codes(
95
103
):
96
104
is_img_or_video = indices .ndim >= 3
97
105
106
+ if not self .keep_num_codebooks_dim :
107
+ indices = rearrange (indices , '... -> ... 1' )
108
+
98
109
# indices to codes, which are bits of either -1 or 1
99
110
100
111
bits = ((indices [..., None ].int () & self .mask ) != 0 ).float ()
101
112
codes = bits * 2 - 1
102
113
114
+ codes = rearrange (codes , '... c d -> ... (c d)' )
115
+
103
116
# whether to project codes out to original dimensions
104
117
# if the input feature dimensions were not log2(codebook size)
105
118
@@ -123,6 +136,7 @@ def forward(
123
136
b - batch
124
137
n - sequence (or flattened spatial dimensions)
125
138
d - feature dimension, which is also log2(codebook size)
139
+ c - number of codebook dim
126
140
"""
127
141
128
142
is_img_or_video = x .ndim >= 4
@@ -133,10 +147,14 @@ def forward(
133
147
x = rearrange (x , 'b d ... -> b ... d' )
134
148
x , ps = pack_one (x , 'b * d' )
135
149
136
- assert x .shape [- 1 ] == self .dim
150
+ assert x .shape [- 1 ] == self .dim , f'expected dimension of { self . dim } but received { x . shape [ - 1 ] } '
137
151
138
152
x = self .project_in (x )
139
153
154
+ # split out number of codebooks
155
+
156
+ x = rearrange (x , 'b n (c d) -> b n c d' , c = self .num_codebooks )
157
+
140
158
# quantize by eq 3.
141
159
142
160
ones = torch .ones_like (x )
@@ -152,7 +170,7 @@ def forward(
152
170
153
171
# calculate indices
154
172
155
- indices = reduce ((x > 0 ).int () * self .mask .int (), 'b n d -> b n' , 'sum' )
173
+ indices = reduce ((x > 0 ).int () * self .mask .int (), 'b n c d -> b n c ' , 'sum' )
156
174
157
175
# entropy aux loss
158
176
@@ -161,7 +179,7 @@ def forward(
161
179
162
180
bit_entropy = binary_entropy (prob ).mean ()
163
181
164
- avg_prob = reduce (prob , 'b n d -> b d' , 'mean' )
182
+ avg_prob = reduce (prob , 'b n c d -> b c d' , 'mean' )
165
183
codebook_entropy = binary_entropy (avg_prob ).mean ()
166
184
167
185
# 1. entropy will be nudged to be low for each bit, so each scalar commits to one latent binary bit or the other
@@ -174,6 +192,10 @@ def forward(
174
192
175
193
entropy_aux_loss = entropy_aux_loss * self .entropy_loss_weight
176
194
195
+ # merge back codebook dim
196
+
197
+ x = rearrange (x , 'b n c d -> b n (c d)' )
198
+
177
199
# project out to feature dimension if needed
178
200
179
201
x = self .project_out (x )
@@ -184,6 +206,11 @@ def forward(
184
206
x = unpack_one (x , ps , 'b * d' )
185
207
x = rearrange (x , 'b ... d -> b d ...' )
186
208
187
- indices = unpack_one (indices , ps , 'b *' )
209
+ indices = unpack_one (indices , ps , 'b * c' )
210
+
211
+ # whether to remove single codebook dim
212
+
213
+ if not self .keep_num_codebooks_dim :
214
+ indices = rearrange (indices , '... 1 -> ...' )
188
215
189
216
return Return (x , indices , entropy_aux_loss )
0 commit comments