Skip to content

Commit 4a44819

Browse files
committed
make sure library can work on mac m1/m2 chip - #55
1 parent ae5db01 commit 4a44819

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.6.9',
6+
version = '1.6.12',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ def l2norm(t):
2727
def log(t, eps = 1e-20):
2828
return torch.log(t.clamp(min = eps))
2929

30+
def lerp(old, new, decay):
    """Exponential-moving-average blend: decay * old + (1 - decay) * new.

    Prefers the in-place ``Tensor.lerp_`` update (mutating and returning
    ``old``). On Apple-silicon MPS tensors — where the in-place op is the
    reason for this helper's existence — it falls back to an out-of-place
    computation and returns a fresh tensor instead.
    """
    weight = 1 - decay

    if getattr(old, 'is_mps', False):
        # out-of-place fallback for the M1/M2 (MPS) backend
        return old * decay + new * weight

    old.lerp_(new, weight)
    return old
38+
3039
def pack_one(t, pattern):
3140
return pack([t], pattern)
3241

@@ -441,11 +450,11 @@ def forward(
441450
cluster_size = embed_onehot.sum(dim = 1)
442451

443452
self.all_reduce_fn(cluster_size)
444-
self.cluster_size.data.lerp_(cluster_size, 1 - self.decay)
453+
self.cluster_size = lerp(self.cluster_size, cluster_size, self.decay)
445454

446455
embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
447456
self.all_reduce_fn(embed_sum.contiguous())
448-
self.embed_avg.data.lerp_(embed_sum, 1 - self.decay)
457+
self.embed_avg = lerp(self.embed_avg, embed_sum, self.decay)
449458

450459
cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
451460

@@ -598,11 +607,11 @@ def forward(
598607
bins = embed_onehot.sum(dim = 1)
599608
self.all_reduce_fn(bins)
600609

601-
self.cluster_size.data.lerp_(bins, 1 - self.decay)
610+
self.cluster_size = lerp(self.cluster_size, bins, self.decay)
602611

603612
embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
604613
self.all_reduce_fn(embed_sum.contiguous())
605-
self.embed_avg.data.lerp_(embed_sum, 1 - self.decay)
614+
self.embed_avg = lerp(self.embed_avg, embed_sum, self.decay)
606615

607616
cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
608617

0 commit comments

Comments (0)