Skip to content

Commit 46816f3

Browse files
committed
automatically sync codebook, if distributed is initialized and world size greater than 1
1 parent 0c6cea2 commit 46816f3

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.7.0',
6+
version = '1.7.1',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,7 @@ def __init__(
710710
sample_codebook_temp = 1.,
711711
straight_through = False,
712712
reinmax = False, # using reinmax for improved straight-through, assuming straight through helps at all
713-
sync_codebook = False,
713+
sync_codebook = None,
714714
sync_affine_param = False,
715715
ema_update = True,
716716
learnable_codebook = False,
@@ -760,6 +760,9 @@ def __init__(
760760
straight_through = straight_through
761761
)
762762

763+
if not exists(sync_codebook):
764+
sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
765+
763766
codebook_kwargs = dict(
764767
dim = codebook_dim,
765768
num_codebooks = heads if separate_codebook_per_head else 1,

0 commit comments

Comments
 (0)