@@ -61,7 +61,22 @@ def log(t, eps = 1e-20):
 def entropy(prob, eps = 1e-5):
     return (-prob * log(prob, eps = eps)).sum(dim = -1)
 
+def accum_grad_(t, grad):
+    if exists(t.grad):
+        t.grad.add_(grad)
+    else:
+        t.grad = grad.clone().detach()
+
 def ema_inplace(old, new, decay, weight = None):
+
+    # if old.grad is populated, add it to new and set it to None
+
+    if exists(old.grad):
+        new.add_(old.grad)
+        old.grad = None
+
+    # take care of custom weighting
+
     weight = default(weight, 1.)
 
     if is_tensor(weight):
@@ -71,7 +86,7 @@ def ema_inplace(old, new, decay, weight = None):
         assert weight.ndim == 2 and weight.shape == old.shape[:2]
         weight = append_dims_to(weight, old.ndim)
 
-    old.lerp_(new, (1. - decay) * weight)
+    old.data.lerp_(new, (1. - decay) * weight)
 
 def pack_one(t, pattern):
     packed, ps = pack([t], pattern)
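
The two helpers above compose into an accumulate-then-flush scheme: accum_grad_ parks pending EMA statistics in a buffer's otherwise-unused .grad slot, and ema_inplace folds anything parked there into new before performing the lerp. A minimal standalone sketch of that interaction, outside the diff (the tensors, shapes and decay are made up; exists is the repo's is-not-None helper):

import torch

def exists(v):
    return v is not None

def accum_grad_(t, grad):
    if exists(t.grad):
        t.grad.add_(grad)
    else:
        t.grad = grad.clone().detach()

def ema_inplace(old, new, decay):
    # flush any accumulated statistics into the incoming update
    if exists(old.grad):
        new.add_(old.grad)
        old.grad = None
    old.data.lerp_(new, 1. - decay)

cluster_size = torch.zeros(4)

# two micro-batches accumulate; the statistics buffer does not move yet
accum_grad_(cluster_size, torch.ones(4))
accum_grad_(cluster_size, torch.ones(4))
assert cluster_size.sum() == 0

# the next regular update applies both pending batches in one EMA step
ema_inplace(cluster_size, torch.zeros(4), decay = 0.9)
assert torch.allclose(cluster_size, torch.full((4,), 0.2))

Stashing statistics in .grad reuses an existing slot: since these buffers are not trainable parameters, autograd never writes to their .grad, and the slot is cleared again on flush.
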
@@ -511,7 +526,8 @@ def forward(
         mask = None,
         freeze_codebook = False,
         codebook_transform_fn: Callable | None = None,
-        ema_update_weight: Tensor | Callable | None = None
+        ema_update_weight: Tensor | Callable | None = None,
+        accum_ema_update = False
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -603,12 +619,16 @@ def forward(
         if callable(ema_update_weight):
             ema_update_weight = ema_update_weight(embed_sum, cluster_size)
 
-        ema_inplace(self.cluster_size.data, cluster_size, self.decay, ema_update_weight)
-        ema_inplace(self.embed_avg.data, embed_sum, self.decay, ema_update_weight)
+        if accum_ema_update:
+            accum_grad_(self.cluster_size, cluster_size)
+            accum_grad_(self.embed_avg, embed_sum)
+        else:
+            ema_inplace(self.cluster_size, cluster_size, self.decay, ema_update_weight)
+            ema_inplace(self.embed_avg, embed_sum, self.decay, ema_update_weight)
 
-        if not self.manual_ema_update:
-            self.update_ema()
-            self.expire_codes_(x)
+            if not self.manual_ema_update:
+                self.update_ema()
+                self.expire_codes_(x)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -743,7 +763,8 @@ def forward(
         mask = None,
         freeze_codebook = False,
         codebook_transform_fn: Callable | None = None,
-        ema_update_weight: Tensor | None = None
+        ema_update_weight: Tensor | None = None,
+        accum_ema_update = False
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -819,12 +840,17 @@ def forward(
         if callable(ema_update_weight):
             ema_update_weight = ema_update_weight(embed_sum, bins)
 
-        ema_inplace(self.cluster_size.data, bins, self.decay, ema_update_weight)
-        ema_inplace(self.embed_avg.data, embed_sum, self.decay, ema_update_weight)
+        if accum_ema_update:
+            accum_grad_(self.cluster_size, bins)
+            accum_grad_(self.embed_avg, embed_sum)
+        else:
+
+            ema_inplace(self.cluster_size, bins, self.decay, ema_update_weight)
+            ema_inplace(self.embed_avg, embed_sum, self.decay, ema_update_weight)
 
-        if not self.manual_ema_update:
-            self.update_ema()
-            self.expire_codes_(x)
+            if not self.manual_ema_update:
+                self.update_ema()
+                self.expire_codes_(x)
 
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -1062,7 +1088,8 @@ def forward(
         freeze_codebook = None,
         return_loss_breakdown = False,
         codebook_transform_fn: Callable | None = None,
-        ema_update_weight: Tensor | None = None
+        ema_update_weight: Tensor | None = None,
+        accum_ema_update = False
     ):
         orig_input, input_requires_grad = x, x.requires_grad
 
@@ -1119,7 +1146,8 @@ def forward(
             mask = mask,
             freeze_codebook = freeze_codebook,
             codebook_transform_fn = codebook_transform_fn,
-            ema_update_weight = ema_update_weight
+            ema_update_weight = ema_update_weight,
+            accum_ema_update = accum_ema_update
         )
 
         # quantize
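
Downstream, the new flag composes with gradient accumulation: run all but the last micro-batch with accum_ema_update = True, then let the final non-accumulating forward fold the parked statistics into a single EMA update (which also triggers update_ema / expire_codes_ when manual_ema_update is off). A hedged usage sketch; only the accum_ema_update kwarg comes from this diff, the constructor arguments and training loop are illustrative:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512)  # illustrative config

micro_batches = [torch.randn(1, 1024, 256) for _ in range(4)]

for i, x in enumerate(micro_batches):
    is_last = i == (len(micro_batches) - 1)

    # non-final steps only park statistics in the buffers' .grad slots;
    # the final step folds them all into one ema_inplace call
    quantized, indices, commit_loss = vq(x, accum_ema_update = not is_last)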