
Commit 976335f

address #215
1 parent 9d2594b commit 976335f

File tree

2 files changed: +44 -34 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.22.15"
+version = "1.22.16"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 43 additions & 33 deletions
@@ -11,7 +11,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import Module
-from torch import Tensor, int32
+from torch import tensor, Tensor, int32
 from torch.amp import autocast

 import einx
@@ -47,11 +47,12 @@ def unpack_one(t, ps, pattern):
 # tensor helpers

 def round_ste(z):
-    """Round with straight through gradients."""
+    """ round with straight through gradients. """
     zhat = z.round()
     return z + (zhat - z).detach()

 def floor_ste(z):
+    """ floor with straight through gradients. """
     zhat = z.floor()
     return z + (zhat - z).detach()

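The two helpers above are straight-through estimators: the forward pass produces the discretized value while gradients flow back as if no rounding had happened. A minimal standalone sketch of that behavior, using round_ste exactly as defined in the diff:

import torch

def round_ste(z):
    # forward: rounded values; backward: identity gradient
    zhat = z.round()
    return z + (zhat - z).detach()

z = torch.tensor([0.3, 1.7], requires_grad = True)
out = round_ste(z)     # tensor([0., 2.], ...)
out.sum().backward()
print(z.grad)          # tensor([1., 1.]) -- gradient passes straight through the rounding
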
@@ -60,26 +61,26 @@ def floor_ste(z):
 class FSQ(Module):
     def __init__(
         self,
-        levels: List[int],
+        levels: list[int],
         dim: int | None = None,
         num_codebooks = 1,
         keep_num_codebooks_dim: bool | None = None,
         scale: float | None = None,
-        allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
-        channel_first: bool = False,
-        projection_has_bias: bool = True,
+        allowed_dtypes: tuple[torch.dtype, ...] = (torch.float32, torch.float64),
+        channel_first = False,
+        projection_has_bias = True,
         return_indices = True,
         force_quantization_f32 = True,
-        preserve_symmetry: bool = False,
-        noise_dropout = 0.0,
+        preserve_symmetry = False,
+        noise_dropout = 0.,
     ):
         super().__init__()

-        _levels = torch.tensor(levels, dtype=int32)
-        self.register_buffer("_levels", _levels, persistent = False)
+        _levels = tensor(levels, dtype = int32)
+        self.register_buffer('_levels', _levels, persistent = False)

-        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
-        self.register_buffer("_basis", _basis, persistent = False)
+        _basis = torch.cumprod(tensor([1] + levels[:-1]), dim = 0, dtype = int32)
+        self.register_buffer('_basis', _basis, persistent = False)

         self.scale = scale

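For context on the _basis buffer touched above: it is the cumulative product of the levels, giving each dimension a place value so that per-dimension digits can be packed into a single codebook index. A small worked example (the levels value here is illustrative, not taken from this commit):

import torch
from torch import tensor, int32

levels = [8, 5, 5, 5]
_basis = torch.cumprod(tensor([1] + levels[:-1]), dim = 0, dtype = int32)
print(_basis)                      # tensor([  1,   8,  40, 200], dtype=torch.int32)

# per-dimension digits (0, 3, 2, 4) pack into one index
digits = tensor([0, 3, 2, 4])
print((digits * _basis).sum())     # 0*1 + 3*8 + 2*40 + 4*200 = 904
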
@@ -108,56 +109,65 @@ def __init__(
         self.has_projections = has_projections

         self.return_indices = return_indices
+
         if return_indices:
             self.codebook_size = self._levels.prod().item()
             implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
-            self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
+            self.register_buffer('implicit_codebook', implicit_codebook, persistent = False)

         self.allowed_dtypes = allowed_dtypes
         self.force_quantization_f32 = force_quantization_f32

-    def bound(self, z, eps: float = 1e-3):
+    def bound(self, z, eps = 1e-3):
         """ Bound `z`, an array of shape (..., d). """
         half_l = (self._levels - 1) * (1 + eps) / 2
         offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
         shift = (offset / half_l).atanh()
-        return (z + shift).tanh() * half_l - offset
+        bounded_z = (z + shift).tanh() * half_l - offset
+        half_width = self._levels // 2
+        return round_ste(bounded_z) / half_width

     # symmetry-preserving and noise-approximated quantization, section 3.2 in https://arxiv.org/abs/2411.19842

     def symmetry_preserving_bound(self, z):
-        """
-        QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1
-        """
+        """ QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1 """
         levels_minus_1 = (self._levels - 1)
-        scale = 2.0 / levels_minus_1
-        bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5
+        scale = 2. / levels_minus_1
+        bracket = (levels_minus_1 * (z.tanh() + 1) / 2.) + 0.5
         bracket = floor_ste(bracket)
-        return scale * bracket - 1.0
+        return scale * bracket - 1.

     def quantize(self, z):
         """ Quantizes z, returns quantized zhat, same shape as z. """

-        shape, device, noise_dropout, preserve_symmetry, half_width = z.shape[0], z.device, self.noise_dropout, self.preserve_symmetry, (self._levels // 2)
+        shape, device, noise_dropout, preserve_symmetry = z.shape[0], z.device, self.noise_dropout, self.preserve_symmetry
         bound_fn = self.symmetry_preserving_bound if preserve_symmetry else self.bound

         bounded_z = bound_fn(z)

         # determine where to add a random offset elementwise
         # if using noise dropout

-        if self.training and noise_dropout > 0.:
-            offset_mask = torch.bernoulli(torch.full_like(bounded_z, noise_dropout)).bool()
-            offset = torch.rand_like(bounded_z) - 0.5
-            bounded_z = torch.where(offset_mask, bounded_z + offset, bounded_z)
+        if not self.training or noise_dropout == 0.:
+            return bounded_z

-        return round_ste(bounded_z) / half_width
+        offset_mask = torch.bernoulli(torch.full_like(bounded_z, noise_dropout)).bool()
+        offset = torch.rand_like(bounded_z) - 0.5
+        bounded_z = torch.where(offset_mask, bounded_z + offset, bounded_z)
+
+        return bounded_z

     def _scale_and_shift(self, zhat_normalized):
+        if self.preserve_symmetry:
+            return (zhat_normalized + 1.) / (2. / (self._levels - 1))
+
         half_width = self._levels // 2
         return (zhat_normalized * half_width) + half_width

     def _scale_and_shift_inverse(self, zhat):
+        if self.preserve_symmetry:
+            return zhat * (2. / (self._levels - 1)) - 1.
+
         half_width = self._levels // 2
         return (zhat - half_width) / half_width

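With this refactor, rounding happens inside bound and symmetry_preserving_bound, so quantize only decides whether to inject the training-time noise dropout; in eval mode (or with noise_dropout left at 0.) the bounded output is already the final code. A quick sanity-check sketch, assuming the package-level FSQ usage shown in this repo's README:

import torch
from vector_quantize_pytorch import FSQ

quantizer = FSQ(levels = [8, 5, 5, 5]).eval()

x = torch.randn(1, 1024, 4)
xhat, indices = quantizer(x)

# outputs should land exactly on the implicit codebook entries and roundtrip through indices
assert xhat.shape == x.shape
assert torch.all(xhat == quantizer.indices_to_codes(indices))
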
@@ -166,18 +176,18 @@ def _indices_to_codes(self, indices):
         codes = self._scale_and_shift_inverse(level_indices)
         return codes

-    def codes_to_indices(self, zhat):
-        """ Converts a `code` to an index in the codebook. """
-        assert zhat.shape[-1] == self.codebook_dim
-        zhat = self._scale_and_shift(zhat)
-        return (zhat * self._basis).sum(dim=-1).to(int32)
-
     def indices_to_level_indices(self, indices):
         """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
         indices = rearrange(indices, '... -> ... 1')
         codes_non_centered = (indices // self._basis) % self._levels
         return codes_non_centered

+    def codes_to_indices(self, zhat):
+        """ Converts a `code` to an index in the codebook. """
+        assert zhat.shape[-1] == self.codebook_dim
+        zhat = self._scale_and_shift(zhat)
+        return (zhat * self._basis).sum(dim = -1).round().to(int32)
+
     def indices_to_codes(self, indices):
         """ Inverse of `codes_to_indices`. """
         assert exists(indices)

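One more note on the relocated codes_to_indices: the new .round() before the int32 cast matters because .to(int32) truncates toward zero, so a scaled code that lands just below an integer due to float error would otherwise map to the wrong index. A tiny illustration with a made-up value:

import torch
from torch import int32

val = torch.tensor([3.9999998])
print(val.to(int32))           # tensor([3], dtype=torch.int32) -- truncation picks the wrong bucket
print(val.round().to(int32))   # tensor([4], dtype=torch.int32)
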