@@ -85,7 +85,11 @@ def __init__(
         self.register_buffer('embed', embed)
         self.register_buffer('embed_avg', embed.clone())
 
+    @torch.jit.ignore
     def init_embed_(self, data):
+        if self.initted:
+            return
+
         embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
         self.embed.data.copy_(embed)
         self.embed_avg.data.copy_(embed.clone())
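
Both changes serve TorchScript compatibility: the `initted` guard moves inside `init_embed_`, and `@torch.jit.ignore` keeps the data-dependent kmeans path out of compilation while still letting a scripted `forward` call it in eager Python. A minimal sketch of the pattern, with a hypothetical `Codebook` class and a random-sample stand-in for the real `kmeans` helper:

```python
import torch
from torch import nn

class Codebook(nn.Module):
    def __init__(self, codebook_size, dim):
        super().__init__()
        self.codebook_size = codebook_size
        # buffers so the init state and codes survive state_dict round-trips
        self.register_buffer('initted', torch.tensor(False))
        self.register_buffer('embed', torch.randn(codebook_size, dim))

    @torch.jit.ignore
    def init_embed_(self, data):
        # data-dependent init: skipped by the TorchScript compiler,
        # executed in Python when called from the scripted forward
        if self.initted:
            return
        # stand-in for the kmeans call in the real code; assumes
        # data has at least codebook_size rows
        perm = torch.randperm(data.shape[0])[:self.codebook_size]
        self.embed.data.copy_(data[perm])
        self.initted.data.copy_(torch.tensor(True))

    def forward(self, x):
        flatten = x.reshape(-1, x.shape[-1])
        self.init_embed_(flatten)  # no-op after the first call
        dist = torch.cdist(flatten, self.embed)
        return dist.argmin(dim = -1)
```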
@@ -115,8 +119,7 @@ def forward(self, x):
         flatten = rearrange(x, '... d -> (...) d')
         embed = self.embed.t()
 
-        if not self.initted:
-            self.init_embed_(flatten)
+        self.init_embed_(flatten)
 
         dist = -(
             flatten.pow(2).sum(1, keepdim = True)
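
The hunk's last context line truncates the distance expression; the remaining terms follow the standard expansion ‖x − e‖² = ‖x‖² − 2x·e + ‖e‖², negated so that the nearest code maximizes `dist`. A sketch of the full computation under that assumption, with `embed` already transposed to `(dim, codebook_size)` as above:

```python
# negative squared euclidean distance between each flattened input row
# and every codebook column: -(||x||^2 - 2 x.e + ||e||^2)
dist = -(
    flatten.pow(2).sum(1, keepdim = True)  # ||x||^2, shape (n, 1)
    - 2 * flatten @ embed                  # -2 x.e, shape (n, codebook_size)
    + embed.pow(2).sum(0, keepdim = True)  # ||e||^2, shape (1, codebook_size)
)
embed_ind = dist.max(dim = -1).indices     # nearest code per input row
```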
@@ -168,7 +171,11 @@ def __init__(
         self.register_buffer('cluster_size', torch.zeros(codebook_size))
         self.register_buffer('embed', embed)
 
+    @torch.jit.ignore
     def init_embed_(self, data):
+        if self.initted:
+            return
+
         embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters,
             use_cosine_sim = True)
         self.embed.data.copy_(embed)
@@ -199,8 +206,7 @@ def forward(self, x):
         flatten = rearrange(x, '... d -> (...) d')
         flatten = l2norm(flatten)
 
-        if not self.initted:
-            self.init_embed_(flatten)
+        self.init_embed_(flatten)
 
         embed = l2norm(self.embed)
         dist = flatten @ embed.t()
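
In this cosine codebook, both `flatten` and `embed` are unit-normalized, so the matmul already yields cosine similarity and the nearest code is a plain argmax, with no negation or squared-norm terms. A self-contained sketch with made-up shapes, assuming `l2norm` is the usual `F.normalize` wrapper:

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    # unit-normalize along the last dimension
    return F.normalize(t, dim = -1)

flatten = l2norm(torch.randn(1024, 64))  # hypothetical (n, dim) inputs
embed = l2norm(torch.randn(512, 64))     # hypothetical (codebook_size, dim) codebook

# for unit vectors, x . e = cos(x, e); larger means closer
dist = flatten @ embed.t()
embed_ind = dist.argmax(dim = -1)        # nearest code per input row
```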