Skip to content

Commit b1ea624

Browse files
authored
make cdist output safe to differentiate through
thanks @Boltzmachine
1 parent 976335f commit b1ea624

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ def Sequential(*modules):
4949

5050
return nn.Sequential(*modules)
5151

52-
def cdist(x, y):
52+
def cdist(x, y, eps = 1e-8):
5353
x2 = reduce(x ** 2, 'b n d -> b n', 'sum')
5454
y2 = reduce(y ** 2, 'b n d -> b n', 'sum')
5555
xy = einsum('b i d, b j d -> b i j', x, y) * -2
56-
return (rearrange(x2, 'b i -> b i 1') + rearrange(y2, 'b j -> b 1 j') + xy).clamp(min = 0).sqrt()
56+
return (rearrange(x2, 'b i -> b i 1') + rearrange(y2, 'b j -> b 1 j') + xy).clamp(min = eps).sqrt()
5757

5858
def log(t, eps = 1e-20):
5959
return torch.log(t.clamp(min = eps))

0 commit comments

Comments
 (0)