Skip to content

Commit 35a8a41

Browse files
committed
try to fix sync seed again
1 parent ea13758 commit 35a8a41

File tree

4 files changed

+14
-14
lines changed

4 files changed

+14
-14
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.18.7"
3+
version = "1.18.8"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/residual_fsq.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def forward(
185185
# check if seed is manually passed in
186186

187187
if not exists(rand_quantize_dropout_fixed_seed):
188-
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
188+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
189189

190190
rand = random.Random(rand_quantize_dropout_fixed_seed)
191191

@@ -296,7 +296,7 @@ def forward(
296296
x,
297297
return_all_codes = False
298298
):
299-
shape, split_dim = x.shape, self.split_dim
299+
shape, split_dim, device = x.shape, self.split_dim, x.device
300300
assert shape[split_dim] == self.dim
301301

302302
# split the feature dimension into groups
@@ -305,7 +305,7 @@ def forward(
305305

306306
forward_kwargs = dict(
307307
return_all_codes = return_all_codes,
308-
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
308+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
309309
)
310310

311311
# invoke residual vq on each group

vector_quantize_pytorch/residual_lfq.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def round_up_multiple(num, mult):
3131
def is_distributed():
3232
return dist.is_initialized() and dist.get_world_size() > 1
3333

34-
def get_maybe_sync_seed(max_size = 10_000):
35-
rand_int = torch.randint(0, max_size, ())
34+
def get_maybe_sync_seed(device, max_size = 10_000):
35+
rand_int = torch.randint(0, max_size, (), device = device)
3636

3737
if is_distributed():
3838
dist.all_reduce(rand_int)
@@ -162,7 +162,7 @@ def forward(
162162
# check if seed is manually passed in
163163

164164
if not exists(rand_quantize_dropout_fixed_seed):
165-
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
165+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
166166

167167
rand = random.Random(rand_quantize_dropout_fixed_seed)
168168

@@ -262,7 +262,7 @@ def forward(
262262
mask = None,
263263
return_all_codes = False
264264
):
265-
shape, split_dim = x.shape, self.split_dim
265+
shape, split_dim, device = x.shape, self.split_dim, x.device
266266
assert shape[split_dim] == self.dim
267267

268268
# split the feature dimension into groups
@@ -272,7 +272,7 @@ def forward(
272272
forward_kwargs = dict(
273273
mask = mask,
274274
return_all_codes = return_all_codes,
275-
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
275+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
276276
)
277277

278278
# invoke residual vq on each group

vector_quantize_pytorch/residual_vq.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def round_up_multiple(num, mult):
3636
def is_distributed():
3737
return dist.is_initialized() and dist.get_world_size() > 1
3838

39-
def get_maybe_sync_seed(max_size = 10_000):
40-
rand_int = torch.randint(0, max_size, ())
39+
def get_maybe_sync_seed(device, max_size = 10_000):
40+
rand_int = torch.randint(0, max_size, (), device = device)
4141

4242
if is_distributed():
4343
dist.all_reduce(rand_int)
@@ -296,7 +296,7 @@ def forward(
296296
# check if seed is manually passed in
297297

298298
if not exists(rand_quantize_dropout_fixed_seed):
299-
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
299+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
300300

301301
rand = random.Random(rand_quantize_dropout_fixed_seed)
302302

@@ -452,7 +452,7 @@ def forward(
452452
freeze_codebook = False,
453453
mask = None,
454454
):
455-
shape, split_dim = x.shape, self.split_dim
455+
shape, split_dim, device = x.shape, self.split_dim, x.device
456456
assert shape[split_dim] == self.dim
457457

458458
# split the feature dimension into groups
@@ -468,7 +468,7 @@ def forward(
468468
sample_codebook_temp = sample_codebook_temp,
469469
mask = mask,
470470
freeze_codebook = freeze_codebook,
471-
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
471+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
472472
)
473473

474474
# invoke residual vq on each group

0 commit comments

Comments
 (0)