diff --git a/vector_quantize_pytorch/finite_scalar_quantization.py b/vector_quantize_pytorch/finite_scalar_quantization.py
index 005908e..971012b 100644
--- a/vector_quantize_pytorch/finite_scalar_quantization.py
+++ b/vector_quantize_pytorch/finite_scalar_quantization.py
@@ -66,7 +66,7 @@ def __init__(
         return_indices = True,
         force_quantization_f32 = True,
         preserve_symmetry: bool = False,
-        noise_approx_prob = 0.0,
+        noise_dropout = 0.0,
     ):
         super().__init__()
 
@@ -79,7 +79,7 @@ def __init__(
         self.scale = scale
 
         self.preserve_symmetry = preserve_symmetry
-        self.noise_approx_prob = noise_approx_prob
+        self.noise_dropout = noise_dropout
 
         codebook_dim = len(levels)
         self.codebook_dim = codebook_dim
@@ -129,24 +129,53 @@ def symmetry_preserving_bound(self, z):
         bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
         return scale * bracket - 1.0
 
-    def noise_approx_bound(self, z):
-        """
-        simulates quantization using noise -> Q_L(x) ~= tanh(x) + U{-1,1} / (L-1)
-        """
-        noise = torch.empty_like(z).uniform_(-1, 1)
-        return torch.tanh(z) + noise / (self._levels - 1)
-
     def quantize(self, z, preserve_symmetry = False):
-        """ Quantizes z, returns quantized zhat, same shape as z. """
-        if self.training and random.random() < self.noise_approx_prob:
-            bounded = self.noise_approx_bound(z)
-        elif preserve_symmetry:
-            bounded = self.symmetry_preserving_bound(z)
-        else:
-            bounded = self.bound(z)
-        quantized = round_ste(bounded)
-        half_width = self._levels // 2 # Renormalize to [-1, 1].
-        return quantized / half_width
+        """
+        Quantizes z, returns quantized zhat, same shape as z.
+
+        When training with noise_dropout > 0, two independent bernoulli masks are
+        drawn per entry of the leading dimension of z, each with probability
+        noise_dropout. Where the first mask is set the unquantized z is passed
+        through; where the second is set z plus uniform noise in
+        [-0.5, 0.5) / half_width is used instead (the second mask takes priority).
+        """
+
+        half_width = self._levels // 2
+
+        bound_fn = self.symmetry_preserving_bound if preserve_symmetry else self.bound
+        quantized = round_ste(bound_fn(z)) / half_width
+
+        # with a zero dropout probability both masks would be all False, so skip the sampling
+
+        if not self.training or self.noise_dropout <= 0.:
+            return quantized
+
+        # NOTE(review): the pass-through below uses the raw z, not bound_fn(z) - confirm this is intended
+
+        unquantized = z
+
+        # one draw per entry of the leading dimension, broadcast over the remaining dimensions of z
+
+        mask_shape = [z.shape[0]] + [1] * (z.ndim - 1)
+
+        # determine which entries skip quantization
+
+        quantize_mask = torch.bernoulli(
+            torch.full(mask_shape, self.noise_dropout, device = z.device)
+        ).bool().expand_as(z)
+
+        quantized = torch.where(quantize_mask, unquantized, quantized)
+
+        # determine which entries get a random offset in place of quantization
+
+        offset_mask = torch.bernoulli(
+            torch.full(mask_shape, self.noise_dropout, device = z.device)
+        ).bool().expand_as(z)
+
+        offset = (torch.rand_like(z) - 0.5) / half_width
+        quantized = torch.where(offset_mask, unquantized + offset, quantized)
+
+        return quantized
 
     def _scale_and_shift(self, zhat_normalized):
         half_width = self._levels // 2