Skip to content

Commit a72d217

Browse files
committed
another change to noise dropout fsq
1 parent 4931ee0 commit a72d217

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.22.9"
3+
version = "1.22.10"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,17 @@ def quantize(self, z):
141141
shape, device, noise_dropout, preserve_symmetry, half_width = z.shape[0], z.device, self.noise_dropout, self.preserve_symmetry, (self._levels // 2)
142142
bound_fn = self.symmetry_preserving_bound if preserve_symmetry else self.bound
143143

144+
bounded_z = bound_fn(z)
145+
144146
# determine where to add a random offset elementwise
145147
# if using noise dropout
146148

147149
if self.training and noise_dropout > 0.:
148-
offset_mask = torch.bernoulli(torch.full_like(z, noise_dropout)).bool()
149-
offset = torch.rand_like(z) - 0.5
150-
z = torch.where(offset_mask, z + offset, z)
150+
offset_mask = torch.bernoulli(torch.full_like(bounded_z, noise_dropout)).bool()
151+
offset = torch.rand_like(bounded_z) - 0.5
152+
bounded_z = torch.where(offset_mask, bounded_z + offset, bounded_z)
151153

152-
return round_ste(bound_fn(z)) / half_width
154+
return round_ste(bounded_z) / half_width
153155

154156
def _scale_and_shift(self, zhat_normalized):
155157
half_width = self._levels // 2

0 commit comments

Comments
 (0)