
Commit 6c3f9a7

remove the hack, as it does not work
1 parent c14fa4d commit 6c3f9a7

File tree

2 files changed: +4 −16 lines

pyproject.toml
vector_quantize_pytorch/vector_quantize_pytorch.py

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.43"
+version = "1.14.44"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
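The only change in pyproject.toml is the version bump. If you pick this release up from PyPI, a quick way to confirm which build you have installed (assuming the package was installed under its published name, e.g. pip install vector-quantize-pytorch) is:

    # prints the installed package version; "1.14.44" corresponds to this commit
    from importlib.metadata import version
    print(version("vector-quantize-pytorch"))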

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 3 additions & 15 deletions
@@ -280,7 +280,6 @@ def __init__(
         threshold_ema_dead_code = 2,
         reset_cluster_size = None,
         use_ddp = False,
-        distributed_replace_codes = True,
         learnable_codebook = False,
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
@@ -315,8 +314,7 @@ def __init__(

         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors

-        self.distributed_replace_codes = distributed_replace_codes
-        self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans and distributed_replace_codes else batched_sample_vectors
+        self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors

         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
         self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
@@ -448,9 +446,6 @@ def replace(self, batch_samples, batch_mask):
             sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
             sampled = rearrange(sampled, '1 ... -> ...')

-            if not self.distributed_replace_codes:
-                sampled = maybe_distributed_mean(sampled)
-
             self.embed.data[ind][mask] = sampled
             self.cluster_size.data[ind][mask] = self.reset_cluster_size
             self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
@@ -559,7 +554,6 @@ def __init__(
         threshold_ema_dead_code = 2,
         reset_cluster_size = None,
         use_ddp = False,
-        distributed_replace_codes = True,
         learnable_codebook = False,
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
@@ -590,8 +584,7 @@ def __init__(

         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors

-        self.distributed_replace_codes = distributed_replace_codes
-        self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans and distributed_replace_codes else batched_sample_vectors
+        self.replace_sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors

         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
         self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
@@ -638,9 +631,6 @@ def replace(self, batch_samples, batch_mask):
             sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
             sampled = rearrange(sampled, '1 ... -> ...')

-            if not self.distributed_replace_codes:
-                sampled = maybe_distributed_mean(sampled)
-
             self.embed.data[ind][mask] = sampled
             self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
             self.cluster_size.data[ind][mask] = self.reset_cluster_size
@@ -762,7 +752,6 @@ def __init__(
         stochastic_sample_codes = False,
         sample_codebook_temp = 1.,
         straight_through = False,
-        distributed_replace_codes = True,
         reinmax = False, # using reinmax for improved straight-through, assuming straight through helps at all
         sync_codebook = None,
         sync_affine_param = False,
@@ -845,8 +834,7 @@ def __init__(
             learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
             sample_codebook_temp = sample_codebook_temp,
             gumbel_sample = gumbel_sample_fn,
-            ema_update = ema_update,
-            distributed_replace_codes = distributed_replace_codes
+            ema_update = ema_update
        )

        if affine_param:
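To make the net effect of the Python changes concrete, here is a minimal, self-contained sketch of the post-commit behavior. The helper names batched_sample_vectors and sample_vectors_distributed come from the repo but are stubbed out here; only the selection rule mirrors the diff. With distributed_replace_codes gone, dead-code replacement picks the distributed sampler exactly when use_ddp and sync_kmeans are both set; judging from the removed lines, the old optional path instead let each rank draw locally and then averaged those draws across processes via maybe_distributed_mean, which is the "hack" the commit message refers to.

    import torch

    def batched_sample_vectors(samples, num):
        # stand-in for the repo's helper: draw `num` vectors from the local batch only
        indices = torch.randint(0, samples.shape[1], (num,))
        return samples[:, indices]

    def sample_vectors_distributed(samples, num):
        # stand-in: the real helper coordinates the draw across DDP ranks;
        # here it falls back to local sampling so the sketch runs anywhere
        return batched_sample_vectors(samples, num)

    def pick_replace_sample_fn(use_ddp, sync_kmeans):
        # the rule after this commit: same condition as self.sample_fn, with no extra
        # distributed_replace_codes flag and no maybe_distributed_mean averaging step
        return sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors

    samples = torch.randn(1, 1024, 64)                  # (codebooks, num_vectors, dim)
    replace_sample_fn = pick_replace_sample_fn(use_ddp = False, sync_kmeans = True)
    replacements = replace_sample_fn(samples, 8)        # 8 fresh vectors for dead codes
    print(replacements.shape)                           # torch.Size([1, 8, 64])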
