
Commit 8975648

make a guess to what noise dropout should be for FSQ #207
1 parent d31c050 commit 8975648

File tree

2 files changed (+9, -26 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.22.6"
+version = "1.22.7"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 8 additions & 25 deletions
@@ -137,38 +137,21 @@ def symmetry_preserving_bound(self, z):
 
     def quantize(self, z):
         """ Quantizes z, returns quantized zhat, same shape as z. """
+        shape, device, noise_dropout, preserve_symmetry, half_width = z.shape[0], z.device, self.noise_dropout, self.preserve_symmetry, (self._levels // 2)
 
-        preserve_symmetry = self.preserve_symmetry
-        half_width = self._levels // 2
+        # determine where to add a random offset elementwise
+        # if using noise dropout
+
+        if self.training and noise_dropout > 0.:
+            offset_mask = torch.bernoulli(torch.full_like(z, noise_dropout)).bool()
+            offset = (torch.rand_like(z) - 0.5) / half_width
+            z = torch.where(offset_mask, z + offset, z)
 
         if preserve_symmetry:
             quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width
         else:
             quantized = round_ste(self.bound(z)) / half_width
 
-        if not self.training:
-            return quantized
-
-        batch, device, noise_dropout = z.shape[0], z.device, self.noise_dropout
-        unquantized = z
-
-        # determine where to quantize elementwise
-
-        quantize_mask = torch.bernoulli(
-            torch.full((batch,), noise_dropout, device = device)
-        ).bool()
-
-        quantized = einx.where('b, b ..., b ...', quantize_mask, unquantized, quantized)
-
-        # determine where to add a random offset elementwise
-
-        offset_mask = torch.bernoulli(
-            torch.full((batch,), noise_dropout, device = device)
-        ).bool()
-
-        offset = (torch.rand_like(z) - 0.5) / half_width
-        quantized = einx.where('b, b ..., b ...', offset_mask, unquantized + offset, quantized)
-
         return quantized
 
     def _scale_and_shift(self, zhat_normalized):
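
In effect, the old code drew two per-sample Bernoulli masks and let some whole samples bypass quantization (or be replaced with an offset copy) during training; the new code draws a single elementwise mask and only perturbs the selected elements with uniform noise of up to half a quantization step before the straight-through round, so every element still passes through quantization. Below is a minimal, self-contained sketch of the new behavior, assuming a plain tanh bound in place of the class's bound / symmetry_preserving_bound methods and a single shared number of levels; quantize_with_noise_dropout, this round_ste, and the default arguments are illustrative, not the library's API:

import torch

def round_ste(z):
    # round in the forward pass, pass gradients straight through in the backward
    return z + (z.round() - z).detach()

def quantize_with_noise_dropout(z, noise_dropout = 0.5, levels = 5, training = True):
    half_width = levels // 2

    # with probability `noise_dropout`, independently per element, nudge z by
    # uniform noise of up to half a (normalized) quantization step before rounding
    if training and noise_dropout > 0.:
        offset_mask = torch.bernoulli(torch.full_like(z, noise_dropout)).bool()
        offset = (torch.rand_like(z) - 0.5) / half_width
        z = torch.where(offset_mask, z + offset, z)

    # bound to [-half_width, half_width], round, renormalize to roughly [-1, 1]
    return round_ste(torch.tanh(z) * half_width) / half_width

z = torch.randn(2, 4, requires_grad = True)
out = quantize_with_noise_dropout(z)
out.sum().backward()  # gradients flow through the straight-through round

One consequence of the elementwise mask is that the noise no longer depends on batch structure at all, which also removes the need for the einx.where broadcasts used in the old per-sample version.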
