@@ -27,6 +27,15 @@ def l2norm(t):
 def log(t, eps = 1e-20):
     return torch.log(t.clamp(min = eps))
 
+def lerp(old, new, decay):
+    is_mps = getattr(old, 'is_mps', False)
+
+    if not is_mps:
+        old.lerp_(new, 1 - decay)
+        return old
+
+    return old * decay + new * (1 - decay)
+
 
 def pack_one(t, pattern):
     return pack([t], pattern)
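# --- Note (not part of the diff): a quick sanity check of the new `lerp`
# helper. On non-MPS devices it defers to the in-place `Tensor.lerp_`; on MPS
# it falls back to an out-of-place expression, presumably because `lerp_` has
# been unreliable on that backend. Either path computes
# old * decay + new * (1 - decay).

import torch

def lerp(old, new, decay):
    is_mps = getattr(old, 'is_mps', False)

    if not is_mps:
        # in-place: old <- old + (1 - decay) * (new - old)
        old.lerp_(new, 1 - decay)
        return old

    # out-of-place fallback for the MPS backend
    return old * decay + new * (1 - decay)

old, new = torch.zeros(3), torch.ones(3)
assert torch.allclose(lerp(old, new, decay = 0.99), torch.full((3,), 0.01))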
@@ -441,11 +450,11 @@ def forward(
             cluster_size = embed_onehot.sum(dim = 1)
 
             self.all_reduce_fn(cluster_size)
-            self.cluster_size.data.lerp_(cluster_size, 1 - self.decay)
+            self.cluster_size = lerp(self.cluster_size, cluster_size, self.decay)
 
             embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
             self.all_reduce_fn(embed_sum.contiguous())
-            self.embed_avg.data.lerp_(embed_sum, 1 - self.decay)
+            self.embed_avg = lerp(self.embed_avg, embed_sum, self.decay)
 
             cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
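# --- Note (not part of the diff): a minimal standalone sketch of the EMA
# codebook-statistics update these changed lines perform, assuming a single
# process (no all-reduce) so the distributed pieces are omitted.
# `laplace_smoothing` follows its usual definition in this codebase; the
# free-standing names below are illustrative, not the module's attributes.

import torch
from torch import einsum

def laplace_smoothing(x, n_categories, eps = 1e-5):
    return (x + eps) / (x.sum(dim = -1, keepdim = True) + n_categories * eps)

def ema_codebook_update(cluster_size, embed_avg, flatten, embed_onehot, decay = 0.8, eps = 1e-5):
    # per-code assignment counts for this batch: 'h n c' -> 'h c'
    batch_counts = embed_onehot.sum(dim = 1)
    # the EMA step that lerp(old, new, decay) performs, written out
    cluster_size = cluster_size * decay + batch_counts * (1 - decay)

    # sum of the vectors assigned to each code: 'h c d'
    embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
    embed_avg = embed_avg * decay + embed_sum * (1 - decay)

    # smooth the counts (rescaled to keep the same total), then normalize
    codebook_size = cluster_size.shape[-1]
    smoothed = laplace_smoothing(cluster_size, codebook_size, eps) * cluster_size.sum(dim = -1, keepdim = True)
    embed = embed_avg / smoothed.unsqueeze(-1)
    return cluster_size, embed_avg, embed

# tiny usage example
h, n, c, d = 1, 16, 8, 4
flatten = torch.randn(h, n, d)
embed_onehot = torch.nn.functional.one_hot(torch.randint(0, c, (h, n)), c).float()
cs, ea, embed = ema_codebook_update(torch.zeros(h, c), torch.zeros(h, c, d), flatten, embed_onehot)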
@@ -598,11 +607,11 @@ def forward(
             bins = embed_onehot.sum(dim = 1)
             self.all_reduce_fn(bins)
 
-            self.cluster_size.data.lerp_(bins, 1 - self.decay)
+            self.cluster_size = lerp(self.cluster_size, bins, self.decay)
 
             embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
             self.all_reduce_fn(embed_sum.contiguous())
-            self.embed_avg.data.lerp_(embed_sum, 1 - self.decay)
+            self.embed_avg = lerp(self.embed_avg, embed_sum, self.decay)
 
             cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
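# --- Note (not part of the diff): the changed lines assign the result back to
# `self.cluster_size` / `self.embed_avg` instead of mutating `.data` in place.
# A small sketch of why plain reassignment is safe when the attribute is a
# registered buffer: nn.Module.__setattr__ routes a tensor assigned under a
# buffer's name back into the buffer registry, so it still appears in
# state_dict(). `Toy` and its names are illustrative only.

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('cluster_size', torch.zeros(4))

    def update(self, new, decay = 0.99):
        # out-of-place EMA step, then reassign the buffer
        self.cluster_size = self.cluster_size * decay + new * (1 - decay)

m = Toy()
m.update(torch.ones(4))
assert torch.allclose(m.state_dict()['cluster_size'], torch.full((4,), 0.01))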