@@ -37,7 +37,8 @@ def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False):
         if use_cosine_sim:
             dists = samples @ means.t()
         else:
-            diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
+            diffs = rearrange(samples, 'n d -> n () d') \
+                - rearrange(means, 'c d -> () c d')
             dists = -(diffs ** 2).sum(dim = -1)

         buckets = dists.max(dim = -1).indices
@@ -66,7 +67,8 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         decay = 0.8,
-        eps = 1e-5
+        eps = 1e-5,
+        threshold_ema_dead_code = 2
     ):
         super().__init__()
         self.decay = decay
@@ -76,6 +78,7 @@ def __init__(
         self.codebook_size = codebook_size
         self.kmeans_iters = kmeans_iters
         self.eps = eps
+        self.threshold_ema_dead_code = threshold_ema_dead_code

         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('cluster_size', torch.zeros(codebook_size))
@@ -89,9 +92,22 @@ def init_embed_(self, data):
         self.initted.data.copy_(torch.Tensor([True]))

     def replace(self, samples, mask):
-        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
+        modified_codebook = torch.where(
+            mask[..., None],
+            sample_vectors(samples, self.codebook_size),
+            self.embed
+        )
         self.embed.data.copy_(modified_codebook)

+    def expire_codes_(self, batch_samples):
+        if self.threshold_ema_dead_code == 0:
+            return
+
+        expired_codes = self.cluster_size < self.threshold_ema_dead_code
+        if torch.any(expired_codes):
+            batch_samples = rearrange(batch_samples, '... d -> (...) d')
+            self.replace(batch_samples, mask = expired_codes)
+
     def forward(self, x):
         shape, dtype = x.shape, x.dtype
         flatten = rearrange(x, '... d -> (...) d')
@@ -107,7 +123,7 @@ def forward(self, x):
         )

         embed_ind = dist.max(dim = -1).indices
-        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(x.dtype)
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
         embed_ind = embed_ind.view(*shape[:-1])
         quantize = F.embedding(embed_ind, self.embed)

@@ -118,6 +134,7 @@ def forward(self, x):
             cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
             embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
             self.embed.data.copy_(embed_normalized)
+            self.expire_codes_(x)

         return quantize, embed_ind

@@ -129,7 +146,8 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         decay = 0.8,
-        eps = 1e-5
+        eps = 1e-5,
+        threshold_ema_dead_code = 2
     ):
         super().__init__()
         self.decay = decay
@@ -142,20 +160,35 @@ def __init__(
         self.codebook_size = codebook_size
         self.kmeans_iters = kmeans_iters
         self.eps = eps
+        self.threshold_ema_dead_code = threshold_ema_dead_code

         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('embed', embed)

     def init_embed_(self, data):
-        embed = kmeans(data, self.codebook_size, self.kmeans_iters, use_cosine_sim = True)
+        embed = kmeans(data, self.codebook_size, self.kmeans_iters,
+                       use_cosine_sim = True)
         self.embed.data.copy_(embed)
         self.initted.data.copy_(torch.Tensor([True]))

     def replace(self, samples, mask):
         samples = l2norm(samples)
-        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
+        modified_codebook = torch.where(
+            mask[..., None],
+            sample_vectors(samples, self.codebook_size),
+            self.embed
+        )
         self.embed.data.copy_(modified_codebook)

+    def expire_codes_(self, batch_samples):
+        if self.threshold_ema_dead_code == 0:
+            return
+
+        expired_codes = self.cluster_size < self.threshold_ema_dead_code
+        if torch.any(expired_codes):
+            batch_samples = rearrange(batch_samples, '... d -> (...) d')
+            self.replace(batch_samples, mask = expired_codes)
+
     def forward(self, x):
         shape, dtype = x.shape, x.dtype
         flatten = rearrange(x, '... d -> (...) d')
@@ -180,8 +213,10 @@ def forward(self, x):
             embed_sum = flatten.t() @ embed_onehot
             embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
             embed_normalized = l2norm(embed_normalized)
-            embed_normalized = torch.where(zero_mask[..., None], embed, embed_normalized)
+            embed_normalized = torch.where(zero_mask[..., None], embed,
+                                           embed_normalized)
             ema_inplace(self.embed, embed_normalized, self.decay)
+            self.expire_codes_(x)

         return quantize, embed_ind

@@ -200,59 +235,41 @@ def __init__(
         kmeans_init = False,
         kmeans_iters = 10,
         use_cosine_sim = False,
-        max_codebook_misses_before_expiry = 0
+        threshold_ema_dead_code = 0
     ):
         super().__init__()
         n_embed = default(n_embed, codebook_size)

         codebook_dim = default(codebook_dim, dim)
         requires_projection = codebook_dim != dim
-        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
-        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
+        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection \
+            else nn.Identity()
+        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection \
+            else nn.Identity()

         self.eps = eps
         self.commitment = commitment

-        klass = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
+        codebook_class = EuclideanCodebook if not use_cosine_sim \
+            else CosineSimCodebook

-        self._codebook = klass(
+        self._codebook = codebook_class(
             dim = codebook_dim,
             codebook_size = n_embed,
             kmeans_init = kmeans_init,
             kmeans_iters = kmeans_iters,
             decay = decay,
-            eps = eps
+            eps = eps,
+            threshold_ema_dead_code = threshold_ema_dead_code
         )

         self.codebook_size = codebook_size
-        self.max_codebook_misses_before_expiry = max_codebook_misses_before_expiry
-
-        if max_codebook_misses_before_expiry > 0:
-            codebook_misses = torch.zeros(codebook_size)
-            self.register_buffer('codebook_misses', codebook_misses)

     @property
     def codebook(self):
         return self._codebook.codebook

-    def expire_codes_(self, embed_ind, batch_samples):
-        if self.max_codebook_misses_before_expiry == 0:
-            return
-
-        embed_ind = rearrange(embed_ind, '... -> (...)')
-        misses = torch.bincount(embed_ind, minlength = self.codebook_size) == 0
-        self.codebook_misses += misses
-
-        expired_codes = self.codebook_misses >= self.max_codebook_misses_before_expiry
-        if not torch.any(expired_codes):
-            return
-
-        self.codebook_misses.masked_fill_(expired_codes, 0)
-        batch_samples = rearrange(batch_samples, '... d -> (...) d')
-        self._codebook.replace(batch_samples, mask = expired_codes)
-
     def forward(self, x):
-        dtype = x.dtype
         x = self.project_in(x)

         quantize, embed_ind = self._codebook(x)
@@ -262,7 +279,6 @@ def forward(self, x):
         if self.training:
            commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
            quantize = x + (quantize - x).detach()
-           self.expire_codes_(embed_ind, x)

        quantize = self.project_out(quantize)
        return quantize, embed_ind, commit_loss
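
As a usage note, here is a minimal sketch of the `threshold_ema_dead_code` knob this commit introduces, which moves dead-code expiry into the codebook classes: any code whose EMA cluster size falls below the threshold is replaced with a vector sampled from the current batch. The import path and the `dim`/`codebook_size` keyword arguments are assumptions based on the surrounding code, not confirmed by this diff.

```python
# Hedged sketch: exercising the new threshold_ema_dead_code argument.
# Assumes the module is importable as shown and other defaults are unchanged.
import torch
from vector_quantize_pytorch import VectorQuantize  # assumed import path

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    decay = 0.8,
    threshold_ema_dead_code = 2   # codes with EMA cluster size < 2 get resampled from the batch
)

x = torch.randn(1, 1024, 256)
# In training mode the forward pass also runs expire_codes_ on the batch.
quantize, embed_ind, commit_loss = vq(x)
```

Setting `threshold_ema_dead_code = 0` (the `VectorQuantize` default in this diff) disables expiry, matching the early-return in `expire_codes_`.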