Skip to content

Commit 1de96c0

Browse files
committed
fix inplace ema
1 parent 8a5a159 commit 1de96c0

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.6.14',
6+
version = '1.6.16',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,13 @@ def l2norm(t):
2727
def log(t, eps = 1e-20):
2828
return torch.log(t.clamp(min = eps))
2929

30-
def lerp(old, new, decay):
31-
is_mps = getattr(old, 'is_mps', False)
30+
def ema_inplace(old, new, decay):
31+
is_mps = str(old.device) == 'mps'
3232

3333
if not is_mps:
3434
old.lerp_(new, 1 - decay)
35-
return old
36-
37-
return old * decay + new * (1 - decay)
35+
else:
36+
old.mul_(decay).add_(new * (1 - decay))
3837

3938
def pack_one(t, pattern):
4039
return pack([t], pattern)
@@ -450,11 +449,11 @@ def forward(
450449
cluster_size = embed_onehot.sum(dim = 1)
451450

452451
self.all_reduce_fn(cluster_size)
453-
self.cluster_size = lerp(self.cluster_size, cluster_size, self.decay)
452+
ema_inplace(self.cluster_size, cluster_size, self.decay)
454453

455454
embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
456455
self.all_reduce_fn(embed_sum.contiguous())
457-
self.embed_avg = lerp(self.embed_avg, embed_sum, self.decay)
456+
ema_inplace(self.embed_avg, embed_sum, self.decay)
458457

459458
cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
460459

@@ -607,11 +606,11 @@ def forward(
607606
bins = embed_onehot.sum(dim = 1)
608607
self.all_reduce_fn(bins)
609608

610-
self.cluster_size = lerp(self.cluster_size, bins, self.decay)
609+
ema_inplace(self.cluster_size, bins, self.decay)
611610

612611
embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
613612
self.all_reduce_fn(embed_sum.contiguous())
614-
self.embed_avg = lerp(self.embed_avg, embed_sum, self.decay)
613+
ema_inplace(self.embed_avg, embed_sum, self.decay)
615614

616615
cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
617616

0 commit comments

Comments (0)