11
11
import torch
12
12
import torch .nn as nn
13
13
from torch .nn import Module
14
- from torch import Tensor , int32
14
+ from torch import tensor , Tensor , int32
15
15
from torch .amp import autocast
16
16
17
17
import einx
@@ -47,11 +47,12 @@ def unpack_one(t, ps, pattern):
47
47
# tensor helpers
48
48
49
49
def round_ste(z):
    """Round `z` while letting gradients pass through unchanged (straight-through estimator)."""
    # detach() blocks the gradient of the rounding residual, so the
    # backward pass treats this op as the identity function.
    residual = (z.round() - z).detach()
    return z + residual
53
53
54
54
def floor_ste(z):
    """Floor `z` while letting gradients pass through unchanged (straight-through estimator)."""
    # same STE trick as round_ste: forward sees floor(z), backward sees identity
    residual = (z.floor() - z).detach()
    return z + residual
57
58
@@ -60,26 +61,26 @@ def floor_ste(z):
60
61
class FSQ (Module ):
61
62
def __init__ (
62
63
self ,
63
- levels : List [int ],
64
+ levels : list [int ],
64
65
dim : int | None = None ,
65
66
num_codebooks = 1 ,
66
67
keep_num_codebooks_dim : bool | None = None ,
67
68
scale : float | None = None ,
68
- allowed_dtypes : Tuple [torch .dtype , ...] = (torch .float32 , torch .float64 ),
69
- channel_first : bool = False ,
70
- projection_has_bias : bool = True ,
69
+ allowed_dtypes : tuple [torch .dtype , ...] = (torch .float32 , torch .float64 ),
70
+ channel_first = False ,
71
+ projection_has_bias = True ,
71
72
return_indices = True ,
72
73
force_quantization_f32 = True ,
73
- preserve_symmetry : bool = False ,
74
- noise_dropout = 0.0 ,
74
+ preserve_symmetry = False ,
75
+ noise_dropout = 0. ,
75
76
):
76
77
super ().__init__ ()
77
78
78
- _levels = torch . tensor (levels , dtype = int32 )
79
- self .register_buffer (" _levels" , _levels , persistent = False )
79
+ _levels = tensor (levels , dtype = int32 )
80
+ self .register_buffer (' _levels' , _levels , persistent = False )
80
81
81
- _basis = torch .cumprod (torch . tensor ([1 ] + levels [:- 1 ]), dim = 0 , dtype = int32 )
82
- self .register_buffer (" _basis" , _basis , persistent = False )
82
+ _basis = torch .cumprod (tensor ([1 ] + levels [:- 1 ]), dim = 0 , dtype = int32 )
83
+ self .register_buffer (' _basis' , _basis , persistent = False )
83
84
84
85
self .scale = scale
85
86
@@ -108,56 +109,65 @@ def __init__(
108
109
self .has_projections = has_projections
109
110
110
111
self .return_indices = return_indices
112
+
111
113
if return_indices :
112
114
self .codebook_size = self ._levels .prod ().item ()
113
115
implicit_codebook = self ._indices_to_codes (torch .arange (self .codebook_size ))
114
- self .register_buffer (" implicit_codebook" , implicit_codebook , persistent = False )
116
+ self .register_buffer (' implicit_codebook' , implicit_codebook , persistent = False )
115
117
116
118
self .allowed_dtypes = allowed_dtypes
117
119
self .force_quantization_f32 = force_quantization_f32
118
120
119
- def bound (self , z , eps : float = 1e-3 ):
121
+ def bound (self , z , eps = 1e-3 ):
120
122
""" Bound `z`, an array of shape (..., d). """
121
123
half_l = (self ._levels - 1 ) * (1 + eps ) / 2
122
124
offset = torch .where (self ._levels % 2 == 0 , 0.5 , 0.0 )
123
125
shift = (offset / half_l ).atanh ()
124
- return (z + shift ).tanh () * half_l - offset
126
+ bounded_z = (z + shift ).tanh () * half_l - offset
127
+ half_width = self ._levels // 2
128
+ return round_ste (bounded_z ) / half_width
125
129
126
130
# symmetry-preserving and noise-approximated quantization, section 3.2 in https://arxiv.org/abs/2411.19842
127
131
128
132
def symmetry_preserving_bound (self , z ):
129
- """
130
- QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1
131
- """
133
+ """ QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1 """
132
134
levels_minus_1 = (self ._levels - 1 )
133
- scale = 2.0 / levels_minus_1
134
- bracket = (levels_minus_1 * (torch .tanh (z ) + 1 ) / 2.0 ) + 0.5
135
+ scale = 2. / levels_minus_1
136
+ bracket = (levels_minus_1 * (z .tanh () + 1 ) / 2. ) + 0.5
135
137
bracket = floor_ste (bracket )
136
- return scale * bracket - 1.0
138
+ return scale * bracket - 1.
137
139
138
140
def quantize (self , z ):
139
141
""" Quantizes z, returns quantized zhat, same shape as z. """
140
142
141
- shape , device , noise_dropout , preserve_symmetry , half_width = z .shape [0 ], z .device , self .noise_dropout , self .preserve_symmetry , ( self . _levels // 2 )
143
+ shape , device , noise_dropout , preserve_symmetry = z .shape [0 ], z .device , self .noise_dropout , self .preserve_symmetry
142
144
bound_fn = self .symmetry_preserving_bound if preserve_symmetry else self .bound
143
145
144
146
bounded_z = bound_fn (z )
145
147
146
148
# determine where to add a random offset elementwise
147
149
# if using noise dropout
148
150
149
- if self .training and noise_dropout > 0. :
150
- offset_mask = torch .bernoulli (torch .full_like (bounded_z , noise_dropout )).bool ()
151
- offset = torch .rand_like (bounded_z ) - 0.5
152
- bounded_z = torch .where (offset_mask , bounded_z + offset , bounded_z )
151
+ if not self .training or noise_dropout == 0. :
152
+ return bounded_z
153
153
154
- return round_ste (bounded_z ) / half_width
154
+ offset_mask = torch .bernoulli (torch .full_like (bounded_z , noise_dropout )).bool ()
155
+ offset = torch .rand_like (bounded_z ) - 0.5
156
+ bounded_z = torch .where (offset_mask , bounded_z + offset , bounded_z )
157
+
158
+ return bounded_z
155
159
156
160
def _scale_and_shift (self , zhat_normalized ):
161
+ if self .preserve_symmetry :
162
+ return (zhat_normalized + 1. ) / (2. / (self ._levels - 1 ))
163
+
157
164
half_width = self ._levels // 2
158
165
return (zhat_normalized * half_width ) + half_width
159
166
160
167
def _scale_and_shift_inverse (self , zhat ):
168
+ if self .preserve_symmetry :
169
+ return zhat * (2. / (self ._levels - 1 )) - 1.
170
+
161
171
half_width = self ._levels // 2
162
172
return (zhat - half_width ) / half_width
163
173
@@ -166,18 +176,18 @@ def _indices_to_codes(self, indices):
166
176
codes = self ._scale_and_shift_inverse (level_indices )
167
177
return codes
168
178
169
- def codes_to_indices (self , zhat ):
170
- """ Converts a `code` to an index in the codebook. """
171
- assert zhat .shape [- 1 ] == self .codebook_dim
172
- zhat = self ._scale_and_shift (zhat )
173
- return (zhat * self ._basis ).sum (dim = - 1 ).to (int32 )
174
-
175
179
def indices_to_level_indices (self , indices ):
176
180
""" Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
177
181
indices = rearrange (indices , '... -> ... 1' )
178
182
codes_non_centered = (indices // self ._basis ) % self ._levels
179
183
return codes_non_centered
180
184
185
+ def codes_to_indices (self , zhat ):
186
+ """ Converts a `code` to an index in the codebook. """
187
+ assert zhat .shape [- 1 ] == self .codebook_dim
188
+ zhat = self ._scale_and_shift (zhat )
189
+ return (zhat * self ._basis ).sum (dim = - 1 ).round ().to (int32 )
190
+
181
191
def indices_to_codes (self , indices ):
182
192
""" Inverse of `codes_to_indices`. """
183
193
assert exists (indices )
0 commit comments