Skip to content

Commit 09a778f

Browse files
committed
address issue #45
1 parent 5838df7 commit 09a778f

File tree

2 files changed: 15 additions (+15) and 3 deletions (-3)

2 files changed: 15 additions (+15) and 3 deletions (-3)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.2.2',
6+
version = '1.2.3',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -193,6 +193,7 @@ def __init__(
193193
decay = 0.8,
194194
eps = 1e-5,
195195
threshold_ema_dead_code = 2,
196+
reset_cluster_size = None,
196197
use_ddp = False,
197198
learnable_codebook = False,
198199
sample_codebook_temp = 0
@@ -208,6 +209,7 @@ def __init__(
208209
self.kmeans_iters = kmeans_iters
209210
self.eps = eps
210211
self.threshold_ema_dead_code = threshold_ema_dead_code
212+
self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
211213
self.sample_codebook_temp = sample_codebook_temp
212214

213215
assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'
@@ -250,7 +252,12 @@ def replace(self, batch_samples, batch_mask):
250252
continue
251253

252254
sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
253-
self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...')
255+
sampled = rearrange(sampled, '1 ... -> ...')
256+
257+
self.embed.data[ind][mask] = sampled
258+
259+
self.cluster_size.data[ind][mask] = self.reset_cluster_size
260+
self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
254261

255262
def expire_codes_(self, batch_samples):
256263
if self.threshold_ema_dead_code == 0:
@@ -323,6 +330,7 @@ def __init__(
323330
decay = 0.8,
324331
eps = 1e-5,
325332
threshold_ema_dead_code = 2,
333+
reset_cluster_size = None,
326334
use_ddp = False,
327335
learnable_codebook = False,
328336
sample_codebook_temp = 0.
@@ -341,6 +349,7 @@ def __init__(
341349
self.kmeans_iters = kmeans_iters
342350
self.eps = eps
343351
self.threshold_ema_dead_code = threshold_ema_dead_code
352+
self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
344353
self.sample_codebook_temp = sample_codebook_temp
345354

346355
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
@@ -382,7 +391,10 @@ def replace(self, batch_samples, batch_mask):
382391
continue
383392

384393
sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
385-
self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...')
394+
sampled = rearrange(sampled, '1 ... -> ...')
395+
396+
self.embed.data[ind][mask] = sampled
397+
self.cluster_size.data[ind][mask] = self.reset_cluster_size
386398

387399
def expire_codes_(self, batch_samples):
388400
if self.threshold_ema_dead_code == 0:

0 commit comments

Comments
 (0)