Skip to content

Commit 7f26a57

Browse files
committed
make library jittable
1 parent e562e0f commit 7f26a57

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '0.3.10',
6+
version = '0.3.11',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
author = 'Phil Wang',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ def __init__(
8585
self.register_buffer('embed', embed)
8686
self.register_buffer('embed_avg', embed.clone())
8787

88+
@torch.jit.ignore
8889
def init_embed_(self, data):
90+
if self.initted:
91+
return
92+
8993
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
9094
self.embed.data.copy_(embed)
9195
self.embed_avg.data.copy_(embed.clone())
@@ -115,8 +119,7 @@ def forward(self, x):
115119
flatten = rearrange(x, '... d -> (...) d')
116120
embed = self.embed.t()
117121

118-
if not self.initted:
119-
self.init_embed_(flatten)
122+
self.init_embed_(flatten)
120123

121124
dist = -(
122125
flatten.pow(2).sum(1, keepdim=True)
@@ -168,7 +171,11 @@ def __init__(
168171
self.register_buffer('cluster_size', torch.zeros(codebook_size))
169172
self.register_buffer('embed', embed)
170173

174+
@torch.jit.ignore
171175
def init_embed_(self, data):
176+
if self.initted:
177+
return
178+
172179
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters,
173180
use_cosine_sim = True)
174181
self.embed.data.copy_(embed)
@@ -199,8 +206,7 @@ def forward(self, x):
199206
flatten = rearrange(x, '... d -> (...) d')
200207
flatten = l2norm(flatten)
201208

202-
if not self.initted:
203-
self.init_embed_(flatten)
209+
self.init_embed_(flatten)
204210

205211
embed = l2norm(self.embed)
206212
dist = flatten @ embed.t()

0 commit comments

Comments
 (0)