Skip to content

Commit 4931ee0

Browse files
committed
make a guess as to what noise dropout should be for FSQ (#207)
1 parent d31c050 commit 4931ee0

File tree

2 files changed

+9
-29
lines changed

2 files changed

+9
-29
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.22.6"
3+
version = "1.22.9"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -138,38 +138,18 @@ def symmetry_preserving_bound(self, z):
138138
def quantize(self, z):
139139
""" Quantizes z, returns quantized zhat, same shape as z. """
140140

141-
preserve_symmetry = self.preserve_symmetry
142-
half_width = self._levels // 2
143-
144-
if preserve_symmetry:
145-
quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width
146-
else:
147-
quantized = round_ste(self.bound(z)) / half_width
148-
149-
if not self.training:
150-
return quantized
151-
152-
batch, device, noise_dropout = z.shape[0], z.device, self.noise_dropout
153-
unquantized = z
154-
155-
# determine where to quantize elementwise
156-
157-
quantize_mask = torch.bernoulli(
158-
torch.full((batch,), noise_dropout, device = device)
159-
).bool()
160-
161-
quantized = einx.where('b, b ..., b ...', quantize_mask, unquantized, quantized)
141+
shape, device, noise_dropout, preserve_symmetry, half_width = z.shape[0], z.device, self.noise_dropout, self.preserve_symmetry, (self._levels // 2)
142+
bound_fn = self.symmetry_preserving_bound if preserve_symmetry else self.bound
162143

163144
# determine where to add a random offset elementwise
145+
# if using noise dropout
164146

165-
offset_mask = torch.bernoulli(
166-
torch.full((batch,), noise_dropout, device = device)
167-
).bool()
168-
169-
offset = (torch.rand_like(z) - 0.5) / half_width
170-
quantized = einx.where('b, b ..., b ...', offset_mask, unquantized + offset, quantized)
147+
if self.training and noise_dropout > 0.:
148+
offset_mask = torch.bernoulli(torch.full_like(z, noise_dropout)).bool()
149+
offset = torch.rand_like(z) - 0.5
150+
z = torch.where(offset_mask, z + offset, z)
171151

172-
return quantized
152+
return round_ste(bound_fn(z)) / half_width
173153

174154
def _scale_and_shift(self, zhat_normalized):
175155
half_width = self._levels // 2

0 commit comments

Comments (0)