Skip to content

FSQ: Use element-wise selection for noise dropout #186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions vector_quantize_pytorch/finite_scalar_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
return_indices = True,
force_quantization_f32 = True,
preserve_symmetry: bool = False,
noise_approx_prob = 0.0,
noise_dropout = 0.0,
):
super().__init__()

Expand All @@ -79,7 +79,7 @@ def __init__(
self.scale = scale

self.preserve_symmetry = preserve_symmetry
self.noise_approx_prob = noise_approx_prob
self.noise_dropout = noise_dropout

codebook_dim = len(levels)
self.codebook_dim = codebook_dim
Expand Down Expand Up @@ -129,24 +129,40 @@ def symmetry_preserving_bound(self, z):
bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
return scale * bracket - 1.0

def noise_approx_bound(self, z):
"""
simulates quantization using noise -> Q_L(x) ~= tanh(x) + U{-1,1} / (L-1)
"""
# Purpose: return a noisy, tanh-squashed version of z instead of a rounded one.
# NOTE(review): this method has no counterpart among the newer `noise_dropout`
# lines of this diff, so it appears to be on the removed side - confirm.
#
# `noise` has the same shape, dtype and device as z (empty_like) and is filled
# in place with samples from a continuous uniform distribution over [-1, 1).
noise = torch.empty_like(z).uniform_(-1, 1)
# The noise is scaled by 1 / (self._levels - 1) before being added to tanh(z).
# NOTE(review): presumably self._levels holds the per-dimension level counts and
# broadcasts against the last dim of z - verify against __init__.
return torch.tanh(z) + noise / (self._levels - 1)

def quantize(self, z, preserve_symmetry = False):
""" Quantizes z, returns quantized zhat, same shape as z. """
# NOTE(review): this span is a diff rendering with the +/- markers and the
# indentation stripped, so pre-change and post-change lines are interleaved.
# The two lines directly below (random.random() / noise_approx_prob) look like
# the pre-change code path - confirm against the merged file.
if self.training and random.random() < self.noise_approx_prob:
bounded = self.noise_approx_bound(z)

# Bounded-and-rounded values are divided by half_width in every branch below.
half_width = self._levels // 2

if self.training:
# Keep a handle on the raw input; it is what the masks below select.
unquantized = z

# determine where to quantize elementwise
# NOTE(review): the Bernoulli draw has shape [z.shape[0], 1, 1, 1] and is then
# expanded to z's shape, so there is one draw per index along dim 0, not one
# per element. The hard-coded four dims also look like they assume z is 4-D -
# TODO confirm with the caller.

quantize_mask = torch.bernoulli(
torch.full([z.shape[0], 1, 1, 1], self.noise_dropout, device = z.device)
).bool().expand_as(z)

# Bound and round with the straight-through rounding helper, choosing the
# bounding function from the preserve_symmetry argument.
if preserve_symmetry:
quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width
else:
quantized = round_ste(self.bound(z)) / half_width
# Where quantize_mask is True (probability self.noise_dropout) the raw,
# unbounded z is passed through; elsewhere the quantized value is kept.
quantized = torch.where(quantize_mask, unquantized, quantized)

# determine where to add a random offset elementwise
# NOTE(review): same per-dim-0 sampling as quantize_mask above, drawn
# independently of it.

offset_mask = torch.bernoulli(
torch.full([z.shape[0], 1, 1, 1], self.noise_dropout, device = z.device)
).bool().expand_as(z)

# rand_like is uniform on [0, 1), so offset lies in [-0.5, 0.5) / half_width.
offset = (torch.rand_like(z) - 0.5) / half_width
# Where offset_mask is True the output is the raw z plus the offset; this
# overrides whatever the previous torch.where selected at those positions.
quantized = torch.where(offset_mask, unquantized + offset, quantized)
elif preserve_symmetry:
# NOTE(review): the `bounded = ...` line below appears to be a pre-change diff
# line; the line after it recomputes the same bound inline.
bounded = self.symmetry_preserving_bound(z)
quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width
else:
# NOTE(review): the next four lines (through `return quantized / half_width`)
# appear to be the pre-change implementation of this branch - confirm.
bounded = self.bound(z)
quantized = round_ste(bounded)
half_width = self._levels // 2 # Renormalize to [-1, 1].
return quantized / half_width
quantized = round_ste(self.bound(z)) / half_width

# `quantized` is already divided by half_width in each branch above.
return quantized

def _scale_and_shift(self, zhat_normalized):
half_width = self._levels // 2
Expand Down
Loading