
Commit f97a37b (1 parent: 46bd7af)

Showing 3 changed files with 43 additions and 5 deletions.

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.18.5"
+version = "1.18.6"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```

vector_quantize_pytorch/residual_fsq.py

Lines changed: 21 additions & 2 deletions
```diff
@@ -1,6 +1,6 @@
 import random
 from math import log2
-from functools import partial
+from functools import partial, cache

 from typing import List

@@ -9,6 +9,7 @@
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from torch.amp import autocast
+import torch.distributed as dist

 from vector_quantize_pytorch.finite_scalar_quantization import FSQ

@@ -30,6 +31,12 @@ def default(val, d):
 def round_up_multiple(num, mult):
     return ceil(num / mult) * mult

+# distributed helpers
+
+@cache
+def is_distributed():
+    return dist.is_initialized() and dist.get_world_size() > 1
+
 # main class

 class ResidualFSQ(Module):
@@ -167,7 +174,19 @@ def forward(
         # also prepare null indices

         if should_quantize_dropout:
-            rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random
+
+            if exists(rand_quantize_dropout_fixed_seed):
+                # seed is manually passed in
+                rand = random.Random(rand_quantize_dropout_fixed_seed)
+
+            elif is_distributed():
+                # in a distributed environment, synchronize a random seed if not given
+                t = torch.tensor(random.randrange(10_000), device = device)
+                dist.all_reduce(t)  # all_reduce sums in place and returns None
+                rand = random.Random(t.item())
+
+            else:
+                rand = random

             rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)

```
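The new `elif` branch relies on `torch.distributed` collectives mutating their tensor argument in place: each rank draws a local random value, `all_reduce` sums the draws identically on every rank (the call itself returns `None` for a synchronous op, so the seed must be read back out of the tensor), and all ranks then construct the same `random.Random`. A minimal standalone sketch of the pattern, using an illustrative single-process gloo group (real jobs would get rank and world size from their launcher):

```python
import os
import random

import torch
import torch.distributed as dist

def synchronized_seed(device = 'cpu', high = 10_000):
    # each rank draws its own local random value ...
    t = torch.tensor(random.randrange(high), device = device)

    # ... then all_reduce sums the values in place across ranks,
    # leaving the identical total on every rank
    dist.all_reduce(t)

    return t.item()

if __name__ == '__main__':
    # illustrative single-process group for demonstration
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank = 0, world_size = 1)

    rand = random.Random(synchronized_seed())

    # every rank now draws the same dropout index
    print(rand.randrange(0, 8))

    dist.destroy_process_group()
```

One caveat the diff carries along: `is_distributed` is wrapped in `functools.cache`, so its first result is memoized for the life of the process; it should only be called after `init_process_group`, or it will keep returning `False` even once a process group exists.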

vector_quantize_pytorch/residual_lfq.py

Lines changed: 21 additions & 2 deletions
```diff
@@ -1,12 +1,13 @@
 import random
 from math import log2
-from functools import partial
+from functools import partial, cache

 import torch
 from torch import nn
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from torch.amp import autocast
+import torch.distributed as dist

 from vector_quantize_pytorch.lookup_free_quantization import LFQ

@@ -25,6 +26,12 @@ def default(val, d):
 def round_up_multiple(num, mult):
     return ceil(num / mult) * mult

+# distributed helpers
+
+@cache
+def is_distributed():
+    return dist.is_initialized() and dist.get_world_size() > 1
+
 # main class

 class ResidualLFQ(Module):
@@ -144,7 +151,19 @@ def forward(
         # also prepare null indices and loss

         if should_quantize_dropout:
-            rand = random.Random(rand_quantize_dropout_fixed_seed) if exists(rand_quantize_dropout_fixed_seed) else random
+
+            if exists(rand_quantize_dropout_fixed_seed):
+                # seed is manually passed in
+                rand = random.Random(rand_quantize_dropout_fixed_seed)
+
+            elif is_distributed():
+                # in a distributed environment, synchronize a random seed if not given
+                t = torch.tensor(random.randrange(10_000), device = device)
+                dist.all_reduce(t)  # all_reduce sums in place and returns None
+                rand = random.Random(t.item())
+
+            else:
+                rand = random

             rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)

```
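At the call site nothing changes: quantizer dropout is still enabled through the constructor, and `rand_quantize_dropout_fixed_seed` remains available as an explicit override. A sketch of the usage this commit improves (constructor arguments follow the project README; the shapes and values are illustrative):

```python
import torch
from vector_quantize_pytorch import ResidualFSQ

residual_fsq = ResidualFSQ(
    dim = 256,
    levels = [8, 5, 5, 3],
    num_quantizers = 8,
    quantize_dropout = True  # randomly drop later quantizers during training
)

x = torch.randn(1, 1024, 256)

# before this commit, distributed runs had to pin the seed themselves:
# quantized, indices = residual_fsq(x, rand_quantize_dropout_fixed_seed = 42)

# now ranks agree on the dropout index automatically when torch.distributed is initialized
quantized, indices = residual_fsq(x)
```

Summing the per-rank draws is a deliberately simple way to reach agreement: the result is only as random as the individual draws, but for picking a shared dropout index that is all that is required.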
