Skip to content

Commit 41d3b1d

Browse files
authored
Merge pull request #121 from lucidrains/fix-lfq-distance
calculate full euclidean distance in lfq
2 parents 665f4b6 + f87bf54 commit 41d3b1d

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.14.9"
3+
version = "1.14.10"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ def log(t, eps = 1e-5):
4848
def entropy(prob):
4949
return (-prob * log(prob)).sum(dim=-1)
5050

51+
# distance
52+
53+
def euclidean_distance_squared(x, y):
54+
x2 = reduce(x ** 2, '... n d -> ... n 1', 'sum')
55+
y2 = reduce(y ** 2, 'n d -> n', 'sum')
56+
xy = einsum('... i d, j d -> ... i j', x, y) * -2
57+
return x2 + xy + y2
58+
5159
# class
5260

5361
class LFQ(Module):
@@ -218,8 +226,7 @@ def forward(
218226
# entropy aux loss
219227

220228
if self.training:
221-
# the same as euclidean distance up to a constant
222-
distance = -2 * einsum('... i d, j d -> ... i j', original_input, self.codebook)
229+
distance = euclidean_distance_squared(original_input, self.codebook)
223230

224231
prob = (-distance * inv_temperature).softmax(dim = -1)
225232

0 commit comments

Comments
 (0)