@@ -15,7 +15,7 @@ def ema_inplace(moving_avg, new, decay):
def laplace_smoothing(x, n_categories, eps = 1e-5):
    return (x + eps) / (x.sum() + n_categories * eps)

-def kmeans(x, num_clusters, num_iters = 10):
+def kmeans(x, num_clusters, num_iters = 10, use_cosine_sim = False):
    samples = rearrange(x, '... d -> (...) d')
    num_samples, dim, dtype, device = *samples.shape, x.dtype, x.device

@@ -27,9 +27,13 @@ def kmeans(x, num_clusters, num_iters = 10):
    means = samples[indices]

    for _ in range(num_iters):
-        diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
-        dists = (diffs ** 2).sum(dim = -1)
-        buckets = dists.argmin(dim = -1)
+        if use_cosine_sim:
+            dists = samples @ means.t()
+            buckets = dists.max(dim = -1).indices
+        else:
+            diffs = rearrange(samples, 'n d -> n () d') - rearrange(means, 'c d -> () c d')
+            dists = (diffs ** 2).sum(dim = -1)
+            buckets = dists.argmin(dim = -1)

        bins = torch.bincount(buckets, minlength = num_clusters)
        zero_mask = bins == 0
@@ -38,86 +42,186 @@ def kmeans(x, num_clusters, num_iters = 10):
        new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype)
        new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples)
        new_means = new_means / bins[..., None]
+
+        if use_cosine_sim:
+            new_means = F.normalize(new_means, dim = -1)
+
        means = torch.where(zero_mask[..., None], means, new_means)

-    return rearrange(means, 'n d -> d n')
+    return means

-class VectorQuantize(nn.Module):
+# distance types
+
+class EuclideanCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
-        decay = 0.8,
-        commitment = 1.,
-        eps = 1e-5,
-        n_embed = None,
        kmeans_init = False,
        kmeans_iters = 10,
-        codebook_dim = None
+        decay = 0.8,
+        eps = 1e-5
    ):
        super().__init__()
-        n_embed = default(n_embed, codebook_size)
-        self.n_embed = n_embed
-
-        codebook_dim = default(codebook_dim, dim)
-        requires_projection = codebook_dim != dim
-        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
-        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
-
        self.decay = decay
-        self.eps = eps
-        self.commitment = commitment
-
        init_fn = torch.randn if not kmeans_init else torch.zeros
-        embed = init_fn(codebook_dim, n_embed)
+        embed = init_fn(codebook_size, dim)

+        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
+        self.eps = eps
+
        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
-        self.register_buffer('cluster_size', torch.zeros(n_embed))
+        self.register_buffer('cluster_size', torch.zeros(codebook_size))
        self.register_buffer('embed', embed)
        self.register_buffer('embed_avg', embed.clone())

-    @property
-    def codebook(self):
-        return self.embed.transpose(0, 1)
-
    def init_embed_(self, data):
-        embed = kmeans(data, self.n_embed, self.kmeans_iters)
+        embed = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.initted.data.copy_(torch.Tensor([True]))

-    def forward(self, input):
-        input = self.project_in(input)
-
+    def forward(self, x):
        if not self.initted:
-            self.init_embed_(input)
+            self.init_embed_(x)
+
+        shape, dtype = x.shape, x.dtype
+        flatten = rearrange(x, '... d -> (...) d')
+        embed = self.embed.t()

-        dtype = input.dtype
-        flatten = rearrange(input, '... d -> (...) d')
-        dist = (
+        dist = -(
            flatten.pow(2).sum(1, keepdim = True)
-            - 2 * flatten @ self.embed
-            + self.embed.pow(2).sum(0, keepdim = True)
+            - 2 * flatten @ embed
+            + embed.pow(2).sum(0, keepdim = True)
        )

-        _, embed_ind = (-dist).max(1)
-        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
-        embed_ind = embed_ind.view(*input.shape[:-1])
-
-        commit_loss = 0.
-        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
+        embed_ind = dist.max(dim = -1).indices
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(x.dtype)
+        embed_ind = embed_ind.view(*shape[:-1])
+        quantize = F.embedding(embed_ind, self.embed)

        if self.training:
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
-            embed_sum = flatten.transpose(0, 1) @ embed_onehot
-            ema_inplace(self.embed_avg, embed_sum, self.decay)
-            cluster_size = laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum()
-            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
+            embed_sum = flatten.t() @ embed_onehot
+            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

-        commit_loss = F.mse_loss(quantize.detach(), input) * self.commitment
-        quantize = input + (quantize - input).detach()
+        return quantize, embed_ind
+
+class CosineSimCodebook(nn.Module):
+    def __init__(
+        self,
+        dim,
+        codebook_size,
+        kmeans_init = False,
+        kmeans_iters = 10,
+        decay = 0.8,
+        eps = 1e-5
+    ):
+        super().__init__()
+        self.decay = decay
+
+        if not kmeans_init:
+            embed = F.normalize(torch.randn(codebook_size, dim), dim = -1)
+        else:
+            embed = torch.zeros(codebook_size, dim)
+
+        self.codebook_size = codebook_size
+        self.kmeans_iters = kmeans_iters
+        self.eps = eps
+
+        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
+        self.register_buffer('embed', embed)
+
+    def init_embed_(self, data):
+        embed = kmeans(data, self.codebook_size, self.kmeans_iters, use_cosine_sim = True)
+        self.embed.data.copy_(embed)
+        self.initted.data.copy_(torch.Tensor([True]))
+
+    def forward(self, x):
+        shape, dtype = x.shape, x.dtype
+        flatten = rearrange(x, '... d -> (...) d')
+        flatten = F.normalize(flatten, dim = -1)
+        embed = F.normalize(self.embed, dim = -1)
+
+        if not self.initted:
+            self.init_embed_(flatten)
+
+        dist = flatten @ embed.t()
+        embed_ind = dist.max(dim = -1).indices
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind = embed_ind.view(*shape[:-1])
+
+        quantize = F.embedding(embed_ind, self.embed)
+
+        if self.training:
+            bins = embed_onehot.sum(0)
+            zero_mask = (bins == 0)
+            bins = bins.masked_fill(zero_mask, 1.)
+
+            embed_sum = flatten.t() @ embed_onehot
+            embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
+            embed_normalized = F.normalize(embed_normalized, dim = -1)
+            embed_normalized = torch.where(zero_mask[..., None], embed, embed_normalized)
+            ema_inplace(self.embed, embed_normalized, self.decay)
+
+        return quantize, embed_ind
+
+# main class
+
+class VectorQuantize(nn.Module):
+    def __init__(
+        self,
+        dim,
+        codebook_size,
+        n_embed = None,
+        codebook_dim = None,
+        decay = 0.8,
+        commitment = 1.,
+        eps = 1e-5,
+        kmeans_init = False,
+        kmeans_iters = 10,
+        use_cosine_sim = False
+    ):
+        super().__init__()
+        n_embed = default(n_embed, codebook_size)
+
+        codebook_dim = default(codebook_dim, dim)
+        requires_projection = codebook_dim != dim
+        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
+        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
+
+        self.eps = eps
+        self.commitment = commitment
+
+        klass = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
+
+        self._codebook = klass(
+            dim = codebook_dim,
+            codebook_size = n_embed,
+            kmeans_init = kmeans_init,
+            kmeans_iters = kmeans_iters,
+            decay = decay,
+            eps = eps
+        )
+
+    @property
+    def codebook(self):
+        return self._codebook.codebook
+
+    def forward(self, x):
+        dtype = x.dtype
+        x = self.project_in(x)
+
+        quantize, embed_ind = self._codebook(x)
+
+        commit_loss = 0.
+        if self.training:
+            commit_loss = F.mse_loss(quantize.detach(), x) * self.commitment
+            quantize = x + (quantize - x).detach()

        quantize = self.project_out(quantize)
        return quantize, embed_ind, commit_loss
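
Usage sketch (not part of the diff): a minimal example of the refactored interface, assuming the package exposes VectorQuantize at the top level as vector_quantize_pytorch; the dimensions below are illustrative only. Setting use_cosine_sim = True routes quantization through CosineSimCodebook instead of EuclideanCodebook, and kmeans_init = True defers codebook initialization to k-means on the first training batch.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,              # input feature dimension (illustrative)
    codebook_size = 512,    # number of codes (illustrative)
    use_cosine_sim = True,  # select CosineSimCodebook instead of EuclideanCodebook
    kmeans_init = True,     # initialize codes via kmeans(..., use_cosine_sim = True) on the first batch
    kmeans_iters = 10
)

x = torch.randn(1, 1024, 256)
quantize, embed_ind, commit_loss = vq(x)
# quantize: (1, 1024, 256), embed_ind: (1, 1024)
# commit_loss is 0. unless the module is in training mode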