Skip to content

Commit 708f3c9

Browse files
committed
Use element-wise selection for noise dropout.
1 parent fa2211d commit 708f3c9

File tree

1 file changed: +32 −16 lines

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
return_indices = True,
6767
force_quantization_f32 = True,
6868
preserve_symmetry: bool = False,
69-
noise_approx_prob = 0.0,
69+
noise_dropout = 0.0,
7070
):
7171
super().__init__()
7272

@@ -79,7 +79,7 @@ def __init__(
7979
self.scale = scale
8080

8181
self.preserve_symmetry = preserve_symmetry
82-
self.noise_approx_prob = noise_approx_prob
82+
self.noise_dropout = noise_dropout
8383

8484
codebook_dim = len(levels)
8585
self.codebook_dim = codebook_dim
@@ -129,24 +129,40 @@ def symmetry_preserving_bound(self, z):
129129
bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
130130
return scale * bracket - 1.0
131131

132-
def noise_approx_bound(self, z):
133-
"""
134-
simulates quantization using noise -> Q_L(x) ~= tanh(x) + U{-1,1} / (L-1)
135-
"""
136-
noise = torch.empty_like(z).uniform_(-1, 1)
137-
return torch.tanh(z) + noise / (self._levels - 1)
138-
139132
def quantize(self, z, preserve_symmetry = False):
    """ Quantizes z, returns quantized zhat, same shape as z.

    In training mode, "noise dropout" is applied per batch sample: with
    probability `self.noise_dropout` a sample keeps its unquantized values,
    and — independently, with the same probability — a sample is replaced by
    the input plus a uniform offset of up to half a quantizer bin. In eval
    mode the input is always hard-quantized (round with straight-through
    gradients via `round_ste`).

    Args:
        z: input tensor; batch dim first, any rank >= 1.
        preserve_symmetry: use the symmetry-preserving bound instead of the
            default bound before rounding.
    """

    half_width = self._levels // 2  # renormalizes rounded values to [-1, 1]

    if self.training:
        unquantized = z

        # per-sample mask shape, broadcastable over all non-batch dims.
        # (the previous hardcoded [B, 1, 1, 1] only worked for 4-d inputs)
        mask_shape = [z.shape[0], *([1] * (z.ndim - 1))]

        # samples where quantization is dropped (input passed through unchanged)
        dropout_mask = torch.bernoulli(
            torch.full(mask_shape, self.noise_dropout, device = z.device)
        ).bool().expand_as(z)

        if preserve_symmetry:
            quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width
        else:
            quantized = round_ste(self.bound(z)) / half_width

        quantized = torch.where(dropout_mask, unquantized, quantized)

        # samples where, instead, a random sub-bin offset is added to the input
        offset_mask = torch.bernoulli(
            torch.full(mask_shape, self.noise_dropout, device = z.device)
        ).bool().expand_as(z)

        offset = (torch.rand_like(z) - 0.5) / half_width
        quantized = torch.where(offset_mask, unquantized + offset, quantized)
    elif preserve_symmetry:
        quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width
    else:
        quantized = round_ste(self.bound(z)) / half_width

    return quantized
150166

151167
def _scale_and_shift(self, zhat_normalized):
152168
half_width = self._levels // 2

0 commit comments

Comments
 (0)