@@ -138,38 +138,18 @@ def symmetry_preserving_bound(self, z):
138
138
def quantize (self , z ):
139
139
""" Quantizes z, returns quantized zhat, same shape as z. """
140
140
141
- preserve_symmetry = self .preserve_symmetry
142
- half_width = self ._levels // 2
143
-
144
- if preserve_symmetry :
145
- quantized = round_ste (self .symmetry_preserving_bound (z )) / half_width
146
- else :
147
- quantized = round_ste (self .bound (z )) / half_width
148
-
149
- if not self .training :
150
- return quantized
151
-
152
- batch , device , noise_dropout = z .shape [0 ], z .device , self .noise_dropout
153
- unquantized = z
154
-
155
- # determine where to quantize elementwise
156
-
157
- quantize_mask = torch .bernoulli (
158
- torch .full ((batch ,), noise_dropout , device = device )
159
- ).bool ()
160
-
161
- quantized = einx .where ('b, b ..., b ...' , quantize_mask , unquantized , quantized )
141
+ shape , device , noise_dropout , preserve_symmetry , half_width = z .shape [0 ], z .device , self .noise_dropout , self .preserve_symmetry , (self ._levels // 2 )
142
+ bound_fn = self .symmetry_preserving_bound if preserve_symmetry else self .bound
162
143
163
144
# determine where to add a random offset elementwise
145
+ # if using noise dropout
164
146
165
- offset_mask = torch .bernoulli (
166
- torch .full ((batch ,), noise_dropout , device = device )
167
- ).bool ()
168
-
169
- offset = (torch .rand_like (z ) - 0.5 ) / half_width
170
- quantized = einx .where ('b, b ..., b ...' , offset_mask , unquantized + offset , quantized )
147
+ if self .training and noise_dropout > 0. :
148
+ offset_mask = torch .bernoulli (torch .full_like (z , noise_dropout )).bool ()
149
+ offset = torch .rand_like (z ) - 0.5
150
+ z = torch .where (offset_mask , z + offset , z )
171
151
172
- return quantized
152
+ return round_ste ( bound_fn ( z )) / half_width
173
153
174
154
def _scale_and_shift (self , zhat_normalized ):
175
155
half_width = self ._levels // 2
0 commit comments