@@ -104,9 +104,10 @@ def expire_codes_(self, batch_samples):
104
104
return
105
105
106
106
expired_codes = self .cluster_size < self .threshold_ema_dead_code
107
- if torch .any (expired_codes ):
108
- batch_samples = rearrange (batch_samples , '... d -> (...) d' )
109
- self .replace (batch_samples , mask = expired_codes )
107
+ if not torch .any (expired_codes ):
108
+ return
109
+ batch_samples = rearrange (batch_samples , '... d -> (...) d' )
110
+ self .replace (batch_samples , mask = expired_codes )
110
111
111
112
def forward (self , x ):
112
113
shape , dtype = x .shape , x .dtype
@@ -163,6 +164,7 @@ def __init__(
163
164
self .threshold_ema_dead_code = threshold_ema_dead_code
164
165
165
166
self .register_buffer ('initted' , torch .Tensor ([not kmeans_init ]))
167
+ self .register_buffer ('cluster_size' , torch .zeros (codebook_size ))
166
168
self .register_buffer ('embed' , embed )
167
169
168
170
def init_embed_ (self , data ):
@@ -185,9 +187,10 @@ def expire_codes_(self, batch_samples):
185
187
return
186
188
187
189
expired_codes = self .cluster_size < self .threshold_ema_dead_code
188
- if torch .any (expired_codes ):
189
- batch_samples = rearrange (batch_samples , '... d -> (...) d' )
190
- self .replace (batch_samples , mask = expired_codes )
190
+ if not torch .any (expired_codes ):
191
+ return
192
+ batch_samples = rearrange (batch_samples , '... d -> (...) d' )
193
+ self .replace (batch_samples , mask = expired_codes )
191
194
192
195
def forward (self , x ):
193
196
shape , dtype = x .shape , x .dtype
@@ -207,6 +210,8 @@ def forward(self, x):
207
210
208
211
if self .training :
209
212
bins = embed_onehot .sum (0 )
213
+ ema_inplace (self .cluster_size , bins , self .decay )
214
+
210
215
zero_mask = (bins == 0 )
211
216
bins = bins .masked_fill (zero_mask , 1. )
212
217
0 commit comments