@@ -27,14 +27,13 @@ def l2norm(t):
 def log(t, eps = 1e-20):
     return torch.log(t.clamp(min = eps))
 
-def lerp(old, new, decay):
+def ema_inplace(old, new, decay):
     is_mps = getattr(old, 'is_mps', False)
 
     if not is_mps:
         old.lerp_(new, 1 - decay)
-        return old
-
-    return old * decay + new * (1 - decay)
+    else:
+        old.mul_(decay).add_(new * (1 - decay))
 
 def pack_one(t, pattern):
     return pack([t], pattern)
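The helper is renamed from lerp to ema_inplace to reflect that it now always mutates old in place and returns nothing; the else branch composes the same update, old <- decay * old + (1 - decay) * new, from mul_ and add_, presumably as a workaround for missing lerp_ support on the MPS backend (as the is_mps guard suggests). A minimal sketch of why the in-place form suits registered buffers (the EMAStats module and its numbers are illustrative, not part of this commit):

    import torch
    from torch import nn

    def ema_inplace(old, new, decay):
        # same update as the diff above: old <- decay * old + (1 - decay) * new
        is_mps = getattr(old, 'is_mps', False)
        if not is_mps:
            old.lerp_(new, 1 - decay)
        else:
            old.mul_(decay).add_(new * (1 - decay))

    class EMAStats(nn.Module):
        # hypothetical module, only to show the buffer-mutation pattern
        def __init__(self, dim):
            super().__init__()
            self.register_buffer('running', torch.zeros(dim))

        def update(self, batch_stat):
            # mutates the buffer directly; no `self.running = ...` rebinding needed
            ema_inplace(self.running, batch_stat, decay = 0.8)

    m = EMAStats(3)
    m.update(torch.ones(3))
    print(m.running)  # tensor([0.2000, 0.2000, 0.2000])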
@@ -450,11 +449,11 @@ def forward(
             cluster_size = embed_onehot.sum(dim = 1)
 
             self.all_reduce_fn(cluster_size)
-            self.cluster_size = lerp(self.cluster_size, cluster_size, self.decay)
+            ema_inplace(self.cluster_size, cluster_size, self.decay)
 
             embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
             self.all_reduce_fn(embed_sum.contiguous())
-            self.embed_avg = lerp(self.embed_avg, embed_sum, self.decay)
+            ema_inplace(self.embed_avg, embed_sum, self.decay)
 
             cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
 
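At the call sites, the EMA buffers are now updated in place rather than reassigned. For context, this block is the usual EMA codebook update for vector quantization: per-batch cluster counts and per-cluster feature sums are folded into running buffers, and Laplace smoothing keeps the subsequent division by cluster size numerically safe. A rough standalone sketch of the arithmetic (the shapes and the laplace_smoothing body are reconstructed assumptions, not copied from this file):

    import torch

    def laplace_smoothing(x, n_categories, eps = 1e-5):
        # assumed definition: additive smoothing so no normalized count is exactly zero
        return (x + eps) / (x.sum(dim = -1, keepdim = True) + n_categories * eps)

    heads, n, dim, codebook_size, decay = 1, 128, 16, 8, 0.99
    flatten = torch.randn(heads, n, dim)
    codes = torch.randint(0, codebook_size, (heads, n))
    embed_onehot = torch.nn.functional.one_hot(codes, codebook_size).float()

    cluster_size = torch.zeros(heads, codebook_size)
    embed_avg = torch.zeros(heads, codebook_size, dim)

    # the two ema_inplace calls above: EMA of counts and of summed features per code
    cluster_size.lerp_(embed_onehot.sum(dim = 1), 1 - decay)
    embed_avg.lerp_(torch.einsum('h n d, h n c -> h c d', flatten, embed_onehot), 1 - decay)

    # smoothed counts, rescaled so they still sum to the original total
    smoothed = laplace_smoothing(cluster_size, codebook_size) * cluster_size.sum(dim = -1, keepdim = True)
    new_codebook = embed_avg / smoothed.unsqueeze(-1)  # mean feature per cluster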
@@ -607,11 +606,11 @@ def forward(
             bins = embed_onehot.sum(dim = 1)
             self.all_reduce_fn(bins)
 
-            self.cluster_size = lerp(self.cluster_size, bins, self.decay)
+            ema_inplace(self.cluster_size, bins, self.decay)
 
             embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
             self.all_reduce_fn(embed_sum.contiguous())
-            self.embed_avg = lerp(self.embed_avg, embed_sum, self.decay)
+            ema_inplace(self.embed_avg, embed_sum, self.decay)
 
             cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
 