Commit ec2f4f6

address #188
1 parent 59a30b6 commit ec2f4f6


vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 9 additions & 9 deletions
@@ -337,29 +337,29 @@ def forward(
 
         # whether to only use a fraction of probs, for reducing memory
 
+        if exists(mask):
+            input_for_entropy = original_input[mask]
+
+        input_for_entropy = rearrange(input_for_entropy, 'b n ... -> (b n) ...')
+
         if self.frac_per_sample_entropy < 1.:
             # account for mask
-            if exists(mask):
-                original_input = original_input[mask]
-            original_input = rearrange(original_input, 'b n ... -> (b n) ...')
 
-            num_tokens = original_input.size(0)
+            num_tokens = input_for_entropy.size(0)
             num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
             rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
 
-            sampled_input = original_input[rand_mask]
+            sampled_input = input_for_entropy[rand_mask]
 
             sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
 
             sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
 
             per_sample_probs = sampled_prob
         else:
-            if exists(mask):
-                original_input = original_input[mask]
-            original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+
             # the same as euclidean distance up to a constant
-            distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
+            distance = -2 * einsum('... i d, j d -> ... i j', input_for_entropy, codebook)
 
             prob = (-distance * inv_temperature).softmax(dim = -1)
 
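For context, here is a minimal, self-contained sketch of the code path after this change: the mask handling that previously lived inside both the frac_per_sample_entropy branch and the else branch is hoisted into a single input_for_entropy tensor that the two branches then share. The entropy_probs wrapper, the argument defaults, and the else fallback for the unmasked case (flattening explicitly when no mask is given) are assumptions added so the sketch runs standalone; they are not shown in the commit itself.

import torch
from torch import einsum
from einops import rearrange

def exists(v):
    return v is not None

def entropy_probs(original_input, codebook, mask = None, frac_per_sample_entropy = 0.5, inv_temperature = 100.):
    # hoisted mask handling, as in the commit: boolean indexing with a (b, n)
    # mask on a (b, n, d) tensor already yields a flattened (num_kept, d) tensor
    if exists(mask):
        input_for_entropy = original_input[mask]
    else:
        # assumed fallback for the unmasked case, so the sketch is runnable
        input_for_entropy = rearrange(original_input, 'b n d -> (b n) d')

    if frac_per_sample_entropy < 1.:
        # subsample a fraction of tokens to bound the memory of the (tokens x codes) prob matrix
        num_tokens = input_for_entropy.size(0)
        num_sampled_tokens = int(num_tokens * frac_per_sample_entropy)
        rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens

        sampled_input = input_for_entropy[rand_mask]
        distance = -2 * einsum('i d, j d -> i j', sampled_input, codebook)
    else:
        # same as euclidean distance up to a constant, since LFQ codes all have equal norm
        distance = -2 * einsum('i d, j d -> i j', input_for_entropy, codebook)

    return (-distance * inv_temperature).softmax(dim = -1)

# usage
probs = entropy_probs(
    torch.randn(2, 8, 4),
    codebook = torch.randn(16, 4),
    mask = torch.rand(2, 8) > 0.5
)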
