@@ -337,29 +337,29 @@ def forward(
 
             # whether to only use a fraction of probs, for reducing memory
 
+            if exists(mask):
+                input_for_entropy = original_input[mask]
+
+            input_for_entropy = rearrange(input_for_entropy, 'b n ... -> (b n) ...')
+
             if self.frac_per_sample_entropy < 1.:
                 # account for mask
-                if exists(mask):
-                    original_input = original_input[mask]
-                original_input = rearrange(original_input, 'b n ... -> (b n) ...')
 
-                num_tokens = original_input.size(0)
+                num_tokens = input_for_entropy.size(0)
                 num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
                 rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
 
-                sampled_input = original_input[rand_mask]
+                sampled_input = input_for_entropy[rand_mask]
 
                 sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
 
                 sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
 
                 per_sample_probs = sampled_prob
             else:
-                if exists(mask):
-                    original_input = original_input[mask]
-                original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+
                 # the same as euclidean distance up to a constant
-                distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
+                distance = -2 * einsum('... i d, j d -> ... i j', input_for_entropy, codebook)
 
                 prob = (-distance * inv_temperature).softmax(dim = -1)
 
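For context: the hunk hoists the mask filtering and `(b n)` flattening out of both branches into a single `input_for_entropy`, so the subsampled path (`frac_per_sample_entropy < 1.`) and the full path both read from one flattened tensor instead of each mutating `original_input`. Below is a minimal, self-contained sketch of the memory-saving idea this code implements, computing the distance softmax over only a random fraction of tokens. The helper name `entropy_probs_sketch` and the standalone shapes are illustrative assumptions, not this repo's API.

import torch
from torch import einsum

def entropy_probs_sketch(flat_input, codebook, frac = 1., inv_temperature = 100.):
    # flat_input: (num_tokens, dim), assumed already masked and flattened
    # codebook:   (num_codes, dim)
    num_tokens = flat_input.size(0)

    if frac < 1.:
        # argsort of gaussian noise yields a random permutation of ranks;
        # keeping ranks below the cutoff picks a uniform random subset
        num_sampled = int(num_tokens * frac)
        rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled
        flat_input = flat_input[rand_mask]

    # negative pairwise distances up to constants, per the diff's own comment
    distance = -2 * einsum('i d, j d -> i j', flat_input, codebook)
    return (-distance * inv_temperature).softmax(dim = -1)

# demo: only a quarter of 1024 tokens enter the (tokens, codes) softmax
probs = entropy_probs_sketch(torch.randn(1024, 16), torch.randn(256, 16), frac = 0.25)
assert probs.shape == (256, 256)

Sampling via `argsort` of noise keeps exactly `num_sampled` tokens without replacement, so the (tokens x codes) probability matrix, the memory hotspot here, shrinks proportionally with the fraction.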