Skip to content

Commit c724e8c

Browse files
committed
cleanup
1 parent 71b259e commit c724e8c

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.0.6',
6+
version = '1.0.7',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/residual_vq.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,20 +110,24 @@ def forward(
110110

111111
should_quantize_dropout = self.training and self.quantize_dropout
112112

113+
# sample a layer index at which to dropout further residual quantization
114+
# also prepare null indices and loss
115+
113116
if should_quantize_dropout:
114117
rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)
115118

116119
if quant_dropout_multiple_of != 1:
117120
rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1
118121

119-
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
122+
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
123+
null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
124+
null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)
125+
126+
# go through the layers
120127

121128
for quantizer_index, layer in enumerate(self.layers):
122129

123130
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
124-
null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
125-
null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)
126-
127131
all_indices.append(null_indices)
128132
all_losses.append(null_loss)
129133
continue

0 commit comments

Comments
 (0)