 import torch
 from torch.nn import Module
-from torch import nn, einsum, Tensor
+from torch import nn, einsum, is_tensor, Tensor
 import torch.nn.functional as F
 import torch.distributed as distributed
 from torch.optim import Optimizer
@@ -34,6 +34,12 @@ def l2norm(t, dim = -1, eps = 1e-6):
 def safe_div(num, den, eps = 1e-6):
     return num / den.clamp(min = eps)
 
+def append_dims_to(t, ndims):
+    assert t.ndim <= ndims
+    append_ndims = ndims - t.ndim
+    shape = t.shape
+    return t.reshape(*shape, *((1,) * append_ndims))
+
 def Sequential(*modules):
     modules = [*filter(exists, modules)]
     if len(modules) == 0:
@@ -55,13 +61,17 @@ def log(t, eps = 1e-20):
 def entropy(prob, eps = 1e-5):
     return (-prob * log(prob, eps = eps)).sum(dim = -1)
 
-def ema_inplace(old, new, decay):
-    is_mps = str(old.device).startswith('mps:')
+def ema_inplace(old, new, decay, weight = None):
+    weight = default(weight, 1.)
 
-    if not is_mps:
-        old.lerp_(new, 1 - decay)
-    else:
-        old.mul_(decay).add_(new * (1 - decay))
+    if is_tensor(weight):
+        if weight.ndim == 1:
+            weight = rearrange(weight, 'c -> 1 c')
+
+        assert weight.ndim == 2 and weight.shape == old.shape[:2]
+        weight = append_dims_to(weight, old.ndim)
+
+    old.lerp_(new, (1. - decay) * weight)
 
 def pack_one(t, pattern):
     packed, ps = pack([t], pattern)
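Note on the hunk above (an illustration, not part of the commit): when `weight` is a tensor with one entry per code, the new `ema_inplace` scales the effective EMA step per code, so a weight of 0. leaves that code's running statistics untouched while 1. applies the full `1 - decay` step. A minimal self-contained sketch of the same arithmetic, using made-up shapes (1 codebook head, 4 codes, dim 2) and a hand-built broadcast in place of `rearrange` + `append_dims_to`:

import torch

decay = 0.9
old = torch.zeros(1, 4, 2)   # running buffer, e.g. embed_avg: (heads, codes, dim)
new = torch.ones(1, 4, 2)    # freshly accumulated statistics

# hypothetical per-code weights: 1. = full EMA step, 0. = freeze that code
weight = torch.tensor([1., 0.5, 0., 1.]).view(1, 4, 1)   # broadcast over the dim axis

old.lerp_(new, (1. - decay) * weight)
print(old[0, :, 0])   # tensor([0.1000, 0.0500, 0.0000, 0.1000])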
@@ -392,9 +402,9 @@ def init_embed_(self, data, mask = None):
 
         embed_sum = embed * rearrange(cluster_size, '... -> ... 1')
 
-        self.embed.data.copy_(embed)
         self.embed_avg.data.copy_(embed_sum)
         self.cluster_size.data.copy_(cluster_size)
+        self.update_ema()
         self.initted.data.copy_(torch.Tensor([True]))
 
     @torch.jit.ignore
@@ -500,7 +510,8 @@ def forward(
         sample_codebook_temp = None,
         mask = None,
         freeze_codebook = False,
-        codebook_transform_fn: Callable | None = None
+        codebook_transform_fn: Callable | None = None,
+        ema_update_weight: Tensor | Callable | None = None
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -585,15 +596,17 @@ def forward(
                 embed_onehot[~mask] = 0.
 
             cluster_size = embed_onehot.sum(dim = 1)
-
             self.all_reduce_fn(cluster_size)
-            ema_inplace(self.cluster_size.data, cluster_size, self.decay)
 
             embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
             embed_sum = embed_sum.contiguous()
             self.all_reduce_fn(embed_sum)
 
-            ema_inplace(self.embed_avg.data, embed_sum, self.decay)
+            if callable(ema_update_weight):
+                ema_update_weight = ema_update_weight(embed_sum, cluster_size)
+
+            ema_inplace(self.cluster_size.data, cluster_size, self.decay, ema_update_weight)
+            ema_inplace(self.embed_avg.data, embed_sum, self.decay, ema_update_weight)
 
             if not self.manual_ema_update:
                 self.update_ema()
@@ -688,9 +701,9 @@ def init_embed_(self, data, mask = None):
 
         embed_sum = embed * rearrange(cluster_size, '... -> ... 1')
 
-        self.embed.data.copy_(embed)
         self.embed_avg.data.copy_(embed_sum)
         self.cluster_size.data.copy_(cluster_size)
+        self.update_ema()
         self.initted.data.copy_(torch.Tensor([True]))
 
     def replace(self, batch_samples, batch_mask):
@@ -731,7 +744,8 @@ def forward(
         sample_codebook_temp = None,
         mask = None,
         freeze_codebook = False,
-        codebook_transform_fn: Callable | None = None
+        codebook_transform_fn: Callable | None = None,
+        ema_update_weight: Tensor | None = None
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -800,13 +814,15 @@ def forward(
             bins = embed_onehot.sum(dim = 1)
             self.all_reduce_fn(bins)
 
-            ema_inplace(self.cluster_size.data, bins, self.decay)
-
             embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
             embed_sum = embed_sum.contiguous()
             self.all_reduce_fn(embed_sum)
 
-            ema_inplace(self.embed_avg.data, embed_sum, self.decay)
+            if callable(ema_update_weight):
+                ema_update_weight = ema_update_weight(embed_sum, bins)
+
+            ema_inplace(self.cluster_size.data, bins, self.decay, ema_update_weight)
+            ema_inplace(self.embed_avg.data, embed_sum, self.decay, ema_update_weight)
 
             if not self.manual_ema_update:
                 self.update_ema()
@@ -1047,7 +1063,8 @@ def forward(
         sample_codebook_temp = None,
         freeze_codebook = None,
         return_loss_breakdown = False,
-        codebook_transform_fn: Callable | None = None
+        codebook_transform_fn: Callable | None = None,
+        ema_update_weight: Tensor | None = None
     ):
         orig_input, input_requires_grad = x, x.requires_grad
 
@@ -1103,7 +1120,8 @@ def forward(
             sample_codebook_temp = sample_codebook_temp,
             mask = mask,
             freeze_codebook = freeze_codebook,
-            codebook_transform_fn = codebook_transform_fn
+            codebook_transform_fn = codebook_transform_fn,
+            ema_update_weight = ema_update_weight
         )
 
         # quantize
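For context (not part of the diff): the last two hunks simply thread `ema_update_weight` from the top-level forward down to the codebook's EMA update. A minimal usage sketch, assuming the wrapping class is this repository's `VectorQuantize`, that the constructor arguments and return signature below are unchanged by this commit, and that a single codebook head is used; the shapes and weight values are illustrative only:

import torch
from vector_quantize_pytorch import VectorQuantize  # assumed import path

# assumed constructor arguments, following the library's usual signature
vq = VectorQuantize(dim = 256, codebook_size = 512, decay = 0.8)

x = torch.randn(1, 1024, 256)

# one EMA weight per code: 1. takes the full (1 - decay) step, 0. freezes that code
per_code_weight = torch.rand(512)

quantized, indices, commit_loss = vq(x, ema_update_weight = per_code_weight)

# the Euclidean codebook's forward also accepts a callable, which receives the
# accumulated (embed_sum, cluster_size) statistics and returns the weight tensor
quantized, indices, commit_loss = vq(
    x,
    ema_update_weight = lambda embed_sum, cluster_size: (cluster_size > 0).float()
)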