 import torch
-from torch import nn
+from torch import nn, einsum
 import torch.nn.functional as F
+from einops import rearrange, repeat
 
 def exists(val):
     return val is not None
@@ -11,9 +12,36 @@ def default(val, d):
 def ema_inplace(moving_avg, new, decay):
     moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
 
-def laplace_smoothing(x, n_categories, eps = 1e-5):
+def laplace_smoothing(x, n_categories, eps = 1e-5):
     return (x + eps) / (x.sum() + n_categories * eps)
 
+def kmeans(x, num_clusters, num_iters = 10):
+    samples = rearrange(x, '... d -> (...) d')
+    num_samples, dim, dtype, device = *samples.shape, x.dtype, x.device
+
+    if num_samples >= num_clusters:
+        indices = torch.randperm(num_samples, device = device)[:num_clusters]
+    else:
+        indices = torch.randint(0, num_samples, (num_clusters,), device = device)
+
+    means = samples[indices]
+
+    for _ in range(num_iters):
+        diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
+        dists = (diffs ** 2).sum(dim = -1)
+        buckets = dists.argmin(dim = -1)
+
+        bins = torch.bincount(buckets, minlength = num_clusters)
+        zero_mask = bins == 0
+        bins = bins.masked_fill(zero_mask, 1)
+
+        new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype)
+        new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples)
+        new_means = new_means / bins[..., None]
+        means = torch.where(zero_mask[..., None], means, new_means)
+
+    return rearrange(means, 'n d -> d n')
+
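For reference, the new kmeans helper above flattens its input to (n, d), picks num_clusters starting points (a random permutation of the samples when there are enough of them, sampling with replacement otherwise), runs num_iters assign/update steps while keeping the previous mean for any empty cluster, and returns the centroids transposed to shape (dim, num_clusters) so they can be copied straight into the embed buffer. A minimal sketch of calling it directly; the tensor sizes are illustrative assumptions, not taken from this diff:

    x = torch.randn(8, 1024, 256)                          # e.g. (batch, seq, dim) encoder outputs
    embed = kmeans(x, num_clusters = 512, num_iters = 10)
    print(embed.shape)                                     # torch.Size([256, 512]), i.e. (dim, num_clusters)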
 class VectorQuantize(nn.Module):
     def __init__(
         self,
@@ -23,6 +51,8 @@ def __init__(
         commitment = 1.,
         eps = 1e-5,
         n_embed = None,
+        kmeans_init = False,
+        kmeans_iters = 10
     ):
         super().__init__()
         n_embed = default(n_embed, codebook_size)
@@ -33,26 +63,42 @@ def __init__(
         self.eps = eps
         self.commitment = commitment
 
-        embed = torch.randn(dim, n_embed)
-        self.register_buffer('embed', embed)
+        init_fn = torch.randn if not kmeans_init else torch.zeros
+        embed = init_fn(dim, n_embed)
+
+        self.kmeans_iters = kmeans_iters
+        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('cluster_size', torch.zeros(n_embed))
+        self.register_buffer('embed', embed)
         self.register_buffer('embed_avg', embed.clone())
 
     @property
     def codebook(self):
         return self.embed.transpose(0, 1)
 
+    def init_embed_(self, data):
+        embed = kmeans(data, self.n_embed, self.kmeans_iters)
+        self.embed.data.copy_(embed)
+        self.embed_avg.data.copy_(embed.clone())
+        self.initted.data.copy_(torch.Tensor([True]))
+
     def forward(self, input):
+        if not self.initted:
+            self.init_embed_(input)
+
         dtype = input.dtype
         flatten = input.reshape(-1, self.dim)
         dist = (
             flatten.pow(2).sum(1, keepdim = True)
             - 2 * flatten @ self.embed
             + self.embed.pow(2).sum(0, keepdim = True)
         )
+
         _, embed_ind = (-dist).max(1)
         embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
         embed_ind = embed_ind.view(*input.shape[:-1])
+
+        commit_loss = 0.
         quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
 
         if self.training:
@@ -63,6 +109,7 @@ def forward(self, input):
             embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
             self.embed.data.copy_(embed_normalized)
 
-        loss = F.mse_loss(quantize.detach(), input) * self.commitment
-        quantize = input + (quantize - input).detach()
-        return quantize, embed_ind, loss
+            commit_loss = F.mse_loss(quantize.detach(), input) * self.commitment
+            quantize = input + (quantize - input).detach()
+
+        return quantize, embed_ind, commit_loss
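Taken together, the changes let the codebook be seeded by k-means over the first batch it sees (via init_embed_ and the initted buffer) instead of a random normal init, and forward now returns the commitment term as commit_loss, which defaults to 0. A minimal usage sketch, assuming the module is importable from the vector_quantize_pytorch package and that dim and codebook_size are the remaining constructor arguments (both names appear in the diff via init_fn(dim, n_embed) and default(n_embed, codebook_size)); the sizes are illustrative:

    import torch
    from vector_quantize_pytorch import VectorQuantize

    vq = VectorQuantize(
        dim = 256,
        codebook_size = 512,
        kmeans_init = True,    # codebook is filled by k-means on the first forward pass
        kmeans_iters = 10
    )

    x = torch.randn(1, 1024, 256)            # (batch, seq, dim)
    quantize, indices, commit_loss = vq(x)   # quantize: (1, 1024, 256), indices: (1, 1024)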