@@ -110,20 +110,24 @@ def forward(
110
110
111
111
should_quantize_dropout = self .training and self .quantize_dropout
112
112
113
+ # sample a layer index at which to dropout further residual quantization
114
+ # also prepare null indices and loss
115
+
113
116
if should_quantize_dropout :
114
117
rand_quantize_dropout_index = randrange (self .quantize_dropout_cutoff_index , num_quant )
115
118
116
119
if quant_dropout_multiple_of != 1 :
117
120
rand_quantize_dropout_index = round_up_multiple (rand_quantize_dropout_index + 1 , quant_dropout_multiple_of ) - 1
118
121
119
- null_indices_shape = (x .shape [0 ], * x .shape [- 2 :]) if self .accept_image_fmap else tuple (x .shape [:2 ])
122
+ null_indices_shape = (x .shape [0 ], * x .shape [- 2 :]) if self .accept_image_fmap else tuple (x .shape [:2 ])
123
+ null_indices = torch .full (null_indices_shape , - 1. , device = device , dtype = torch .long )
124
+ null_loss = torch .full ((1 ,), 0. , device = device , dtype = x .dtype )
125
+
126
+ # go through the layers
120
127
121
128
for quantizer_index , layer in enumerate (self .layers ):
122
129
123
130
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index :
124
- null_indices = torch .full (null_indices_shape , - 1. , device = device , dtype = torch .long )
125
- null_loss = torch .full ((1 ,), 0. , device = device , dtype = x .dtype )
126
-
127
131
all_indices .append (null_indices )
128
132
all_losses .append (null_loss )
129
133
continue
0 commit comments