@@ -137,38 +137,21 @@ def symmetry_preserving_bound(self, z):
 
     def quantize(self, z):
         """ Quantizes z, returns quantized zhat, same shape as z. """
+        shape, device, noise_dropout, preserve_symmetry, half_width = z.shape[0], z.device, self.noise_dropout, self.preserve_symmetry, (self._levels // 2)
 
-        preserve_symmetry = self.preserve_symmetry
-        half_width = self._levels // 2
+        # determine where to add a random offset elementwise
+        # if using noise dropout
+
+        if self.training and noise_dropout > 0.:
+            offset_mask = torch.bernoulli(torch.full_like(z, noise_dropout)).bool()
+            offset = (torch.rand_like(z) - 0.5) / half_width
+            z = torch.where(offset_mask, z + offset, z)
 
         if preserve_symmetry:
             quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width
         else:
             quantized = round_ste(self.bound(z)) / half_width
 
-        if not self.training:
-            return quantized
-
-        batch, device, noise_dropout = z.shape[0], z.device, self.noise_dropout
-        unquantized = z
-
-        # determine where to quantize elementwise
-
-        quantize_mask = torch.bernoulli(
-            torch.full((batch,), noise_dropout, device = device)
-        ).bool()
-
-        quantized = einx.where('b, b ..., b ...', quantize_mask, unquantized, quantized)
-
-        # determine where to add a random offset elementwise
-
-        offset_mask = torch.bernoulli(
-            torch.full((batch,), noise_dropout, device = device)
-        ).bool()
-
-        offset = (torch.rand_like(z) - 0.5) / half_width
-        quantized = einx.where('b, b ..., b ...', offset_mask, unquantized + offset, quantized)
-
         return quantized
 
     def _scale_and_shift(self, zhat_normalized):
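Read as a whole, the change moves noise dropout from a per-sample decision (masks of shape `(batch,)` mixed in with `einx.where`, sometimes returning entirely unquantized samples) to a per-element random offset applied to `z` before bounding and rounding, and it only runs during training. A minimal standalone sketch of just that offset step, using plain `torch` with an illustrative helper name and example `noise_dropout` / `half_width` values (none of the class's `bound` / `round_ste` machinery), might look like this:

```python
import torch

def add_noise_dropout_offset(z, noise_dropout = 0.5, half_width = 4, training = True):
    # at eval time, or with dropout disabled, leave z untouched
    if not training or noise_dropout == 0.:
        return z

    # pick elements independently with probability `noise_dropout`
    offset_mask = torch.bernoulli(torch.full_like(z, noise_dropout)).bool()

    # uniform offset in [-0.5, 0.5), scaled by the quantizer's level half-width
    offset = (torch.rand_like(z) - 0.5) / half_width

    # perturb only the selected elements
    return torch.where(offset_mask, z + offset, z)

z = torch.randn(2, 4)
print(add_noise_dropout_offset(z))                    # roughly half the elements nudged
print(add_noise_dropout_offset(z, training = False))  # unchanged at eval time
```

In the updated method the perturbed `z` then flows through the existing `symmetry_preserving_bound` / `bound` and `round_ste` path, so the separate eval-time early return and the post-quantization mixing of the old version are no longer needed.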