@@ -193,6 +193,7 @@ def __init__(
193
193
decay = 0.8 ,
194
194
eps = 1e-5 ,
195
195
threshold_ema_dead_code = 2 ,
196
+ reset_cluster_size = None ,
196
197
use_ddp = False ,
197
198
learnable_codebook = False ,
198
199
sample_codebook_temp = 0
@@ -208,6 +209,7 @@ def __init__(
208
209
self .kmeans_iters = kmeans_iters
209
210
self .eps = eps
210
211
self .threshold_ema_dead_code = threshold_ema_dead_code
212
+ self .reset_cluster_size = default (reset_cluster_size , threshold_ema_dead_code )
211
213
self .sample_codebook_temp = sample_codebook_temp
212
214
213
215
assert not (use_ddp and num_codebooks > 1 and kmeans_init ), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'
@@ -250,7 +252,12 @@ def replace(self, batch_samples, batch_mask):
250
252
continue
251
253
252
254
sampled = self .sample_fn (rearrange (samples , '... -> 1 ...' ), mask .sum ().item ())
253
- self .embed .data [ind ][mask ] = rearrange (sampled , '1 ... -> ...' )
255
+ sampled = rearrange (sampled , '1 ... -> ...' )
256
+
257
+ self .embed .data [ind ][mask ] = sampled
258
+
259
+ self .cluster_size .data [ind ][mask ] = self .reset_cluster_size
260
+ self .embed_avg .data [ind ][mask ] = sampled * self .reset_cluster_size
254
261
255
262
def expire_codes_ (self , batch_samples ):
256
263
if self .threshold_ema_dead_code == 0 :
@@ -323,6 +330,7 @@ def __init__(
323
330
decay = 0.8 ,
324
331
eps = 1e-5 ,
325
332
threshold_ema_dead_code = 2 ,
333
+ reset_cluster_size = None ,
326
334
use_ddp = False ,
327
335
learnable_codebook = False ,
328
336
sample_codebook_temp = 0.
@@ -341,6 +349,7 @@ def __init__(
341
349
self .kmeans_iters = kmeans_iters
342
350
self .eps = eps
343
351
self .threshold_ema_dead_code = threshold_ema_dead_code
352
+ self .reset_cluster_size = default (reset_cluster_size , threshold_ema_dead_code )
344
353
self .sample_codebook_temp = sample_codebook_temp
345
354
346
355
self .sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
@@ -382,7 +391,10 @@ def replace(self, batch_samples, batch_mask):
382
391
continue
383
392
384
393
sampled = self .sample_fn (rearrange (samples , '... -> 1 ...' ), mask .sum ().item ())
385
- self .embed .data [ind ][mask ] = rearrange (sampled , '1 ... -> ...' )
394
+ sampled = rearrange (sampled , '1 ... -> ...' )
395
+
396
+ self .embed .data [ind ][mask ] = sampled
397
+ self .cluster_size .data [ind ][mask ] = self .reset_cluster_size
386
398
387
399
def expire_codes_ (self , batch_samples ):
388
400
if self .threshold_ema_dead_code == 0 :
0 commit comments