@@ -280,7 +280,6 @@ def __init__(
280
280
threshold_ema_dead_code = 2 ,
281
281
reset_cluster_size = None ,
282
282
use_ddp = False ,
283
- distributed_replace_codes = True ,
284
283
learnable_codebook = False ,
285
284
gumbel_sample = gumbel_sample ,
286
285
sample_codebook_temp = 1. ,
@@ -315,8 +314,7 @@ def __init__(
315
314
316
315
self .sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
317
316
318
- self .distributed_replace_codes = distributed_replace_codes
319
- self .replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans and distributed_replace_codes else batched_sample_vectors
317
+ self .replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
320
318
321
319
self .kmeans_all_reduce_fn = distributed .all_reduce if use_ddp and sync_kmeans else noop
322
320
self .all_reduce_fn = distributed .all_reduce if use_ddp else noop
@@ -448,9 +446,6 @@ def replace(self, batch_samples, batch_mask):
448
446
sampled = self .replace_sample_fn (rearrange (samples , '... -> 1 ...' ), mask .sum ().item ())
449
447
sampled = rearrange (sampled , '1 ... -> ...' )
450
448
451
- if not self .distributed_replace_codes :
452
- sampled = maybe_distributed_mean (sampled )
453
-
454
449
self .embed .data [ind ][mask ] = sampled
455
450
self .cluster_size .data [ind ][mask ] = self .reset_cluster_size
456
451
self .embed_avg .data [ind ][mask ] = sampled * self .reset_cluster_size
@@ -559,7 +554,6 @@ def __init__(
559
554
threshold_ema_dead_code = 2 ,
560
555
reset_cluster_size = None ,
561
556
use_ddp = False ,
562
- distributed_replace_codes = True ,
563
557
learnable_codebook = False ,
564
558
gumbel_sample = gumbel_sample ,
565
559
sample_codebook_temp = 1. ,
@@ -590,8 +584,7 @@ def __init__(
590
584
591
585
self .sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
592
586
593
- self .distributed_replace_codes = distributed_replace_codes
594
- self .replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans and distributed_replace_codes else batched_sample_vectors
587
+ self .replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
595
588
596
589
self .kmeans_all_reduce_fn = distributed .all_reduce if use_ddp and sync_kmeans else noop
597
590
self .all_reduce_fn = distributed .all_reduce if use_ddp else noop
@@ -638,9 +631,6 @@ def replace(self, batch_samples, batch_mask):
638
631
sampled = self .replace_sample_fn (rearrange (samples , '... -> 1 ...' ), mask .sum ().item ())
639
632
sampled = rearrange (sampled , '1 ... -> ...' )
640
633
641
- if not self .distributed_replace_codes :
642
- sampled = maybe_distributed_mean (sampled )
643
-
644
634
self .embed .data [ind ][mask ] = sampled
645
635
self .embed_avg .data [ind ][mask ] = sampled * self .reset_cluster_size
646
636
self .cluster_size .data [ind ][mask ] = self .reset_cluster_size
@@ -762,7 +752,6 @@ def __init__(
762
752
stochastic_sample_codes = False ,
763
753
sample_codebook_temp = 1. ,
764
754
straight_through = False ,
765
- distributed_replace_codes = True ,
766
755
reinmax = False , # using reinmax for improved straight-through, assuming straight through helps at all
767
756
sync_codebook = None ,
768
757
sync_affine_param = False ,
@@ -845,8 +834,7 @@ def __init__(
845
834
learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook ,
846
835
sample_codebook_temp = sample_codebook_temp ,
847
836
gumbel_sample = gumbel_sample_fn ,
848
- ema_update = ema_update ,
849
- distributed_replace_codes = distributed_replace_codes
837
+ ema_update = ema_update
850
838
)
851
839
852
840
if affine_param :
0 commit comments