
Commit ea13758

attempt to address lucidrains/audiolm-pytorch#279 again
1 parent f97a37b commit ea13758

File tree

4 files changed: +42 −38 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.18.6"
+version = "1.18.7"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/residual_fsq.py

Lines changed: 14 additions & 13 deletions
@@ -1,6 +1,6 @@
 import random
 from math import log2
-from functools import partial, cache
+from functools import partial
 
 from typing import List
 
@@ -33,10 +33,17 @@ def round_up_multiple(num, mult):
 
 # distributed helpers
 
-@cache
 def is_distributed():
     return dist.is_initialized() and dist.get_world_size() > 1
 
+def get_maybe_sync_seed(max_size = 10_000):
+    rand_int = torch.randint(0, max_size, ())
+
+    if is_distributed():
+        dist.all_reduce(rand_int)
+
+    return rand_int.item()
+
 # main class
 
 class ResidualFSQ(Module):
@@ -175,18 +182,12 @@ def forward(
 
         if should_quantize_dropout:
 
-            if exists(rand_quantize_dropout_fixed_seed):
-                # seed is manually passed in
-                rand = random.Random(rand_quantize_dropout_fixed_seed)
+            # check if seed is manually passed in
 
-            elif is_distributed():
-                # in distributed environment, synchronize a random seed value if not given
-                t = torch.tensor(random.randrange(10_000), device = device)
-                dropout_seed = dist.all_reduce(t).item()
-                rand = random.Random(dropout_seed)
+            if not exists(rand_quantize_dropout_fixed_seed):
+                rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
 
-            else:
-                rand = random
+            rand = random.Random(rand_quantize_dropout_fixed_seed)
 
             rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)
 
@@ -304,7 +305,7 @@ def forward(
 
         forward_kwargs = dict(
            return_all_codes = return_all_codes,
-            rand_quantize_dropout_fixed_seed = random.randint(0, int(1e7))
+            rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
         )
 
         # invoke residual vq on each group
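
For reference, here is a minimal usage sketch (not part of the commit) of the get_maybe_sync_seed helper this diff introduces: each rank draws a local integer, an all-reduce makes the value identical on every rank, and seeding Python's random.Random with it makes the sampled quantizer dropout index agree across ranks. The launch context and the cutoff/quantizer counts below are assumptions for illustration only.

    # minimal sketch, assuming torch.distributed has been initialized (e.g. launched with torchrun);
    # without distributed init it simply falls back to a local random seed
    import random
    import torch
    import torch.distributed as dist

    def is_distributed():
        return dist.is_initialized() and dist.get_world_size() > 1

    def get_maybe_sync_seed(max_size = 10_000):
        # each rank draws its own integer ...
        rand_int = torch.randint(0, max_size, ())

        if is_distributed():
            # ... and an all-reduce (sum) leaves every rank holding the same value
            dist.all_reduce(rand_int)

        return rand_int.item()

    seed = get_maybe_sync_seed()
    rand = random.Random(seed)

    # because the seed agrees across ranks, so does the sampled dropout index
    # (cutoff index 1 and 8 quantizers are placeholder values for illustration)
    rand_quantize_dropout_index = rand.randrange(1, 8)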

vector_quantize_pytorch/residual_lfq.py

Lines changed: 13 additions & 12 deletions
@@ -28,10 +28,17 @@ def round_up_multiple(num, mult):
 
 # distributed helpers
 
-@cache
 def is_distributed():
     return dist.is_initialized() and dist.get_world_size() > 1
 
+def get_maybe_sync_seed(max_size = 10_000):
+    rand_int = torch.randint(0, max_size, ())
+
+    if is_distributed():
+        dist.all_reduce(rand_int)
+
+    return rand_int.item()
+
 # main class
 
 class ResidualLFQ(Module):
@@ -152,18 +159,12 @@ def forward(
 
         if should_quantize_dropout:
 
-            if exists(rand_quantize_dropout_fixed_seed):
-                # seed is manually passed in
-                rand = random.Random(rand_quantize_dropout_fixed_seed)
+            # check if seed is manually passed in
 
-            elif is_distributed():
-                # in distributed environment, synchronize a random seed value if not given
-                t = torch.tensor(random.randrange(10_000), device = device)
-                dropout_seed = dist.all_reduce(t).item()
-                rand = random.Random(dropout_seed)
+            if not exists(rand_quantize_dropout_fixed_seed):
+                rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
 
-            else:
-                rand = random
+            rand = random.Random(rand_quantize_dropout_fixed_seed)
 
             rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)
 
@@ -271,7 +272,7 @@ def forward(
         forward_kwargs = dict(
            mask = mask,
            return_all_codes = return_all_codes,
-            rand_quantize_dropout_fixed_seed = random.randint(0, int(1e7))
+            rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
         )
 
         # invoke residual vq on each group

vector_quantize_pytorch/residual_vq.py

Lines changed: 14 additions & 12 deletions
@@ -33,10 +33,17 @@ def round_up_multiple(num, mult):
 
 # distributed helpers
 
-@cache
 def is_distributed():
     return dist.is_initialized() and dist.get_world_size() > 1
 
+def get_maybe_sync_seed(max_size = 10_000):
+    rand_int = torch.randint(0, max_size, ())
+
+    if is_distributed():
+        dist.all_reduce(rand_int)
+
+    return rand_int.item()
+
 # the mlp for generating the neural implicit codebook
 # from Huijben et al. https://arxiv.org/abs/2401.14732
 
@@ -286,18 +293,13 @@ def forward(
 
         if should_quantize_dropout:
 
-            if exists(rand_quantize_dropout_fixed_seed):
-                # seed is manually passed in
-                rand = random.Random(rand_quantize_dropout_fixed_seed)
+            # check if seed is manually passed in
+
+            if not exists(rand_quantize_dropout_fixed_seed):
+                rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
 
-            elif is_distributed():
-                # in distributed environment, synchronize a random seed value if not given
-                t = torch.tensor(random.randrange(10_000), device = device)
-                dropout_seed = dist.all_reduce(t).item()
-                rand = random.Random(dropout_seed)
+            rand = random.Random(rand_quantize_dropout_fixed_seed)
 
-            else:
-                rand = random
 
             rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)
 
@@ -466,7 +468,7 @@ def forward(
            sample_codebook_temp = sample_codebook_temp,
            mask = mask,
            freeze_codebook = freeze_codebook,
-            rand_quantize_dropout_fixed_seed = random.randint(0, int(1e7))
+            rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
        )
 
        # invoke residual vq on each group
