@@ -66,7 +66,7 @@ def __init__(
66
66
return_indices = True ,
67
67
force_quantization_f32 = True ,
68
68
preserve_symmetry : bool = False ,
69
- noise_approx_prob = 0.0 ,
69
+ noise_dropout = 0.0 ,
70
70
):
71
71
super ().__init__ ()
72
72
@@ -79,7 +79,7 @@ def __init__(
79
79
self .scale = scale
80
80
81
81
self .preserve_symmetry = preserve_symmetry
82
- self .noise_approx_prob = noise_approx_prob
82
+ self .noise_dropout = noise_dropout
83
83
84
84
codebook_dim = len (levels )
85
85
self .codebook_dim = codebook_dim
@@ -129,24 +129,40 @@ def symmetry_preserving_bound(self, z):
129
129
bracket = (levels_minus_1 * (torch .tanh (z ) + 1 ) / 2.0 ) + 0.5
130
130
return scale * bracket - 1.0
131
131
132
- def noise_approx_bound (self , z ):
133
- """
134
- simulates quantization using noise -> Q_L(x) ~= tanh(x) + U{-1,1} / (L-1)
135
- """
136
- noise = torch .empty_like (z ).uniform_ (- 1 , 1 )
137
- return torch .tanh (z ) + noise / (self ._levels - 1 )
138
-
139
132
def quantize (self , z , preserve_symmetry = False ):
140
133
""" Quantizes z, returns quantized zhat, same shape as z. """
141
- if self .training and random .random () < self .noise_approx_prob :
142
- bounded = self .noise_approx_bound (z )
134
+
135
+ half_width = self ._levels // 2
136
+
137
+ if self .training :
138
+ unquantized = z
139
+
140
+ # determine where to quantize elementwise
141
+
142
+ quantize_mask = torch .bernoulli (
143
+ torch .full ([z .shape [0 ], 1 , 1 , 1 ], self .noise_dropout , device = z .device )
144
+ ).bool ().expand_as (z )
145
+
146
+ if preserve_symmetry :
147
+ quantized = round_ste (self .symmetry_preserving_bound (z )) / half_width
148
+ else :
149
+ quantized = round_ste (self .bound (z )) / half_width
150
+ quantized = torch .where (quantize_mask , unquantized , quantized )
151
+
152
+ # determine where to add a random offset elementwise
153
+
154
+ offset_mask = torch .bernoulli (
155
+ torch .full ([z .shape [0 ], 1 , 1 , 1 ], self .noise_dropout , device = z .device )
156
+ ).bool ().expand_as (z )
157
+
158
+ offset = (torch .rand_like (z ) - 0.5 ) / half_width
159
+ quantized = torch .where (offset_mask , unquantized + offset , quantized )
143
160
elif preserve_symmetry :
144
- bounded = self .symmetry_preserving_bound (z )
161
+ quantized = round_ste ( self .symmetry_preserving_bound (z )) / half_width
145
162
else :
146
- bounded = self .bound (z )
147
- quantized = round_ste (bounded )
148
- half_width = self ._levels // 2 # Renormalize to [-1, 1].
149
- return quantized / half_width
163
+ quantized = round_ste (self .bound (z )) / half_width
164
+
165
+ return quantized
150
166
151
167
def _scale_and_shift (self , zhat_normalized ):
152
168
half_width = self ._levels // 2
0 commit comments