Skip to content

Commit dcbfc30

Browse files
committed
make sure code expiry feature works with cosine sim
1 parent bec307d commit dcbfc30

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '0.3.4',
6+
version = '0.3.5',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
author = 'Phil Wang',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,10 @@ def expire_codes_(self, batch_samples):
104104
return
105105

106106
expired_codes = self.cluster_size < self.threshold_ema_dead_code
107-
if torch.any(expired_codes):
108-
batch_samples = rearrange(batch_samples, '... d -> (...) d')
109-
self.replace(batch_samples, mask = expired_codes)
107+
if not torch.any(expired_codes):
108+
return
109+
batch_samples = rearrange(batch_samples, '... d -> (...) d')
110+
self.replace(batch_samples, mask = expired_codes)
110111

111112
def forward(self, x):
112113
shape, dtype = x.shape, x.dtype
@@ -163,6 +164,7 @@ def __init__(
163164
self.threshold_ema_dead_code = threshold_ema_dead_code
164165

165166
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
167+
self.register_buffer('cluster_size', torch.zeros(codebook_size))
166168
self.register_buffer('embed', embed)
167169

168170
def init_embed_(self, data):
@@ -185,9 +187,10 @@ def expire_codes_(self, batch_samples):
185187
return
186188

187189
expired_codes = self.cluster_size < self.threshold_ema_dead_code
188-
if torch.any(expired_codes):
189-
batch_samples = rearrange(batch_samples, '... d -> (...) d')
190-
self.replace(batch_samples, mask = expired_codes)
190+
if not torch.any(expired_codes):
191+
return
192+
batch_samples = rearrange(batch_samples, '... d -> (...) d')
193+
self.replace(batch_samples, mask = expired_codes)
191194

192195
def forward(self, x):
193196
shape, dtype = x.shape, x.dtype
@@ -207,6 +210,8 @@ def forward(self, x):
207210

208211
if self.training:
209212
bins = embed_onehot.sum(0)
213+
ema_inplace(self.cluster_size, bins, self.decay)
214+
210215
zero_mask = (bins == 0)
211216
bins = bins.masked_fill(zero_mask, 1.)
212217

0 commit comments

Comments
 (0)