Skip to content

Commit 4a643eb

Browse files
committed
revert uneeded fix #120
1 parent 6dbb3ed commit 4a643eb

File tree

2 files changed

+3
-10
lines changed

2 files changed

+3
-10
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.14.11"
3+
version = "1.14.12"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,6 @@ def log(t, eps = 1e-5):
4848
def entropy(prob):
4949
return (-prob * log(prob)).sum(dim=-1)
5050

51-
# distance
52-
53-
def euclidean_distance_squared(x, y):
54-
x2 = reduce(x ** 2, '... n d -> ... n 1', 'sum')
55-
y2 = reduce(y ** 2, 'n d -> n', 'sum')
56-
xy = einsum('... i d, j d -> ... i j', x, y) * -2
57-
return x2 + xy + y2
58-
5951
# class
6052

6153
class LFQ(Module):
@@ -226,7 +218,8 @@ def forward(
226218
# entropy aux loss
227219

228220
if self.training:
229-
distance = euclidean_distance_squared(original_input, self.codebook)
221+
# the same as euclidean distance up to a constant
222+
distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook)
230223

231224
prob = (-distance * inv_temperature).softmax(dim = -1)
232225

0 commit comments

Comments
 (0)