@@ -280,6 +280,7 @@ def __init__(
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
         ema_update = True,
+        manual_ema_update = False,
         affine_param = False,
         sync_affine_param = False,
         affine_param_batch_decay = 0.99,
@@ -290,6 +291,7 @@ def __init__(
 
         self.decay = decay
         self.ema_update = ema_update
+        self.manual_ema_update = manual_ema_update
 
         init_fn = uniform_init if not kmeans_init else torch.zeros
         embed = init_fn(num_codebooks, codebook_size, dim)
@@ -458,6 +460,12 @@ def expire_codes_(self, batch_samples):
         batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
         self.replace(batch_samples, batch_mask = expired_codes)
 
+    def update_ema(self):
+        cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
+
+        embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
+        self.embed.data.copy_(embed_normalized)
+
     @autocast('cuda', enabled = False)
     def forward(
         self,
@@ -551,11 +559,9 @@ def forward(
 
             ema_inplace(self.embed_avg.data, embed_sum, self.decay)
 
-            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
-
-            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
-            self.embed.data.copy_(embed_normalized)
-            self.expire_codes_(x)
+            if not self.manual_ema_update:
+                self.update_ema()
+                self.expire_codes_(x)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
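For reference, the factored-out `update_ema` is exactly the normalization step that previously ran inline in `forward`: the EMA embedding sums are divided by Laplace-smoothed EMA cluster counts. A minimal standalone sketch of that computation, with the `laplace_smoothing` helper reproduced under an assumed signature matching its use above:

```python
import torch

def laplace_smoothing(x, n_categories, eps = 1e-5):
    # additive smoothing over the last dim, so no cluster count is exactly zero
    return (x + eps) / (x.sum(dim = -1, keepdim = True) + n_categories * eps)

def update_ema_sketch(cluster_size, embed_avg):
    # cluster_size: (num_codebooks, codebook_size) EMA counts per code
    # embed_avg:    (num_codebooks, codebook_size, dim) EMA sums of assigned vectors
    codebook_size = cluster_size.shape[-1]

    # smooth the counts, then rescale back to total-count units
    smoothed = laplace_smoothing(cluster_size, codebook_size) * cluster_size.sum(dim = -1, keepdim = True)

    # each code becomes the (smoothed) mean of the vectors assigned to it
    return embed_avg / smoothed.unsqueeze(-1)
```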
@@ -582,11 +588,14 @@ def __init__(
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
         ema_update = True,
+        manual_ema_update = False
     ):
         super().__init__()
         self.transform_input = l2norm
 
         self.ema_update = ema_update
+        self.manual_ema_update = manual_ema_update
+
         self.decay = decay
 
         if not kmeans_init:
@@ -671,6 +680,14 @@ def expire_codes_(self, batch_samples):
         batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
         self.replace(batch_samples, batch_mask = expired_codes)
 
+    def update_ema(self):
+        cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
+
+        embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
+        embed_normalized = l2norm(embed_normalized)
+
+        self.embed.data.copy_(embed_normalized)
+
     @autocast('cuda', enabled = False)
     def forward(
         self,
@@ -746,13 +763,9 @@ def forward(
 
             ema_inplace(self.embed_avg.data, embed_sum, self.decay)
 
-            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
-
-            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
-            embed_normalized = l2norm(embed_normalized)
-
-            self.embed.data.copy_(embed_normalized)
-            self.expire_codes_(x)
+            if not self.manual_ema_update:
+                self.update_ema()
+                self.expire_codes_(x)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -802,6 +815,7 @@ def __init__(
         sync_codebook = None,
         sync_affine_param = False,
         ema_update = True,
+        manual_ema_update = False,
         learnable_codebook = False,
         in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook
         affine_param = False,
@@ -881,7 +895,8 @@ def __init__(
             learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
             sample_codebook_temp = sample_codebook_temp,
             gumbel_sample = gumbel_sample_fn,
-            ema_update = ema_update
+            ema_update = ema_update,
+            manual_ema_update = manual_ema_update
         )
 
         if affine_param:
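Taken together, these changes let a training loop accumulate EMA statistics across several forward passes and normalize the codebook once, on its own schedule. A minimal usage sketch, assuming the wrapper stores its codebook on a `_codebook` attribute (attribute name assumed) and that the usual three-tuple is returned from `forward`:

```python
import torch
from vector_quantize_pytorch import VectorQuantize  # assumes a build that includes this diff

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    ema_update = True,
    manual_ema_update = True  # forward() keeps accumulating EMA stats but no longer writes the codebook
)

update_every = 4

for step in range(16):
    x = torch.randn(1, 1024, 256)
    quantized, indices, commit_loss = vq(x)

    # ... rest of the model's forward / backward / optimizer step here ...

    if (step + 1) % update_every == 0:
        # normalize the accumulated EMA sums into code vectors on our own schedule.
        # note: expire_codes_ is also the caller's responsibility in manual mode,
        # but it expects samples in the codebook's internal 'h ... d' layout, so it
        # is elided from this sketch.
        vq._codebook.update_ema()  # `_codebook` attribute name assumed
```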