Commit 9251915

committed
Changed implementation of lfq's frac_per_sample_entropy
1 parent 76ec1de commit 9251915

File tree

tests/test_lfq.py
vector_quantize_pytorch/lookup_free_quantization.py

2 files changed: +95 -14 lines changed

tests/test_lfq.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+import torch
+import pytest
+from vector_quantize_pytorch import LFQ
+import math
+"""
+testing_strategy:
+subdivisions: using masks, using frac_per_sample_entropy < 1
+"""
+
+torch.manual_seed(0)
+
+@pytest.mark.parametrize('frac_per_sample_entropy', (1., 0.5))
+def test_masked_lfq(
+    frac_per_sample_entropy
+):
+    # you can specify either dim or codebook_size
+    # if both specified, will be validated against each other
+
+    quantizer = LFQ(
+        codebook_size = 65536,              # codebook size, must be a power of 2
+        dim = 16,                           # this is the input feature dimension, defaults to log2(codebook_size) if not defined
+        entropy_loss_weight = 0.1,          # how much weight to place on entropy loss
+        diversity_gamma = 1.,               # within entropy loss, how much weight to give to diversity
+        frac_per_sample_entropy = frac_per_sample_entropy
+    )
+
+    image_feats = torch.randn(2, 16, 32, 32)
+
+    ret, loss_breakdown = quantizer(image_feats, inv_temperature=100., return_loss_breakdown=True)  # you may want to experiment with temperature
+
+    quantized, indices, _ = ret
+    assert (quantized == quantizer.indices_to_codes(indices)).all()
+
+@pytest.mark.parametrize('frac_per_sample_entropy', (0.1,))
+@pytest.mark.parametrize('iters', (10,))
+@pytest.mark.parametrize('mask', (None, torch.tensor([True, False])))
+def test_lfq_bruteforce_frac_per_sample_entropy(frac_per_sample_entropy, iters, mask):
+    image_feats = torch.randn(2, 16, 32, 32)
+
+    full_per_sample_entropy_quantizer = LFQ(
+        codebook_size = 65536,              # codebook size, must be a power of 2
+        dim = 16,                           # this is the input feature dimension, defaults to log2(codebook_size) if not defined
+        entropy_loss_weight = 0.1,          # how much weight to place on entropy loss
+        diversity_gamma = 1.,               # within entropy loss, how much weight to give to diversity
+        frac_per_sample_entropy = 1
+    )
+
+    partial_per_sample_entropy_quantizer = LFQ(
+        codebook_size = 65536,              # codebook size, must be a power of 2
+        dim = 16,                           # this is the input feature dimension, defaults to log2(codebook_size) if not defined
+        entropy_loss_weight = 0.1,          # how much weight to place on entropy loss
+        diversity_gamma = 1.,               # within entropy loss, how much weight to give to diversity
+        frac_per_sample_entropy = frac_per_sample_entropy
+    )
+
+    ret, loss_breakdown = full_per_sample_entropy_quantizer(
+        image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask)
+    true_per_sample_entropy = loss_breakdown.per_sample_entropy
+
+    per_sample_losses = torch.zeros(iters)
+    for iter in range(iters):
+        ret, loss_breakdown = partial_per_sample_entropy_quantizer(
+            image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask)  # you may want to experiment with temperature
+
+        quantized, indices, _ = ret
+        assert (quantized == partial_per_sample_entropy_quantizer.indices_to_codes(indices)).all()
+        per_sample_losses[iter] = loss_breakdown.per_sample_entropy
+    # 95% confidence interval
+    assert abs(per_sample_losses.mean() - true_per_sample_entropy) \
+        < (1.96*(per_sample_losses.std() / math.sqrt(iters)))
+
+    print("difference: ", abs(per_sample_losses.mean() - true_per_sample_entropy))
+    print("std error:", (1.96*(per_sample_losses.std() / math.sqrt(iters))))
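Note on the statistical check above: the test treats the per-sample entropy computed from a random fraction of tokens as a noisy estimate of the full per-sample entropy, averages it over iters runs, and asserts that the full value lies within a 95% confidence interval of that average (mean +/- 1.96 * std / sqrt(iters)). Below is a minimal, library-free sketch of the same style of check; the probability matrix and the mean_row_entropy helper are hypothetical stand-ins for the quantizer's internals, not part of vector_quantize_pytorch.

import math
import torch

torch.manual_seed(0)

def mean_row_entropy(probs):
    # average entropy of the rows of a (num_tokens, codebook_size) probability matrix
    return -(probs * probs.clamp(min = 1e-9).log()).sum(dim = -1).mean()

probs = torch.softmax(torch.randn(1024, 256), dim = -1)  # stand-in per-token codebook probabilities
true_entropy = mean_row_entropy(probs)

iters, frac = 10, 0.1
estimates = torch.zeros(iters)

for i in range(iters):
    # entropy over a random 10% of the rows, analogous to frac_per_sample_entropy = 0.1
    rand_mask = torch.randn(probs.shape[0]).argsort(dim = -1) < int(frac * probs.shape[0])
    estimates[i] = mean_row_entropy(probs[rand_mask])

# the test asserts the analogue of: |mean - true| < 1.96 * std / sqrt(iters),
# which holds with roughly 95% probability for an unbiased estimator
print("difference:", abs(estimates.mean() - true_entropy).item())
print("std error: ", (1.96 * estimates.std() / math.sqrt(iters)).item())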

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 22 additions & 14 deletions
@@ -335,26 +335,34 @@ def forward(
 
         codebook = self.maybe_l2norm(codebook)
 
-        # the same as euclidean distance up to a constant
-        distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
-
-        prob = (-distance * inv_temperature).softmax(dim = -1)
-
-        # account for mask
-
-        if exists(mask):
-            prob = prob[mask]
-        else:
-            prob = rearrange(prob, 'b n ... -> (b n) ...')
-
         # whether to only use a fraction of probs, for reducing memory
 
         if self.frac_per_sample_entropy < 1.:
-            num_tokens = prob.shape[0]
+            # account for mask
+            if exists(mask):
+                original_input = original_input[mask]
+            original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+
+            num_tokens = original_input.size(0)
             num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
             rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
-            per_sample_probs = prob[rand_mask]
+
+            sampled_input = original_input[rand_mask]
+
+            sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
+
+            sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
+
+            per_sample_probs = sampled_prob
         else:
+            if exists(mask):
+                original_input = original_input[mask]
+            original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+            # the same as euclidean distance up to a constant
+            distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
+
+            prob = (-distance * inv_temperature).softmax(dim = -1)
+
             per_sample_probs = prob
 
         # calculate per sample entropy
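What the change amounts to: before this commit, the full (num_tokens, codebook_size) probability matrix was always materialized and then subsampled; now, when frac_per_sample_entropy < 1, the random subset of tokens is drawn first and the distance/softmax is computed only for that subset, so the large matrix is never built. A rough standalone sketch of the two orderings (sizes and variable names here are illustrative, not the module's internals), showing they select the same probabilities:

import torch

torch.manual_seed(0)

num_tokens, dim, codebook_size = 2048, 16, 4096
frac, inv_temperature = 0.25, 1.

tokens   = torch.randn(num_tokens, dim)       # flattened per-token features
codebook = torch.randn(codebook_size, dim)    # stand-in codebook

num_sampled = int(num_tokens * frac)
rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled

# old order: build the full (num_tokens, codebook_size) matrix, then keep a fraction of its rows
distance = -2 * torch.einsum('id,jd->ij', tokens, codebook)   # euclidean distance up to a constant
full_prob = (-distance * inv_temperature).softmax(dim = -1)
old_per_sample_probs = full_prob[rand_mask]

# new order: keep a fraction of the tokens first, so only (num_sampled, codebook_size) is ever built
sampled_tokens = tokens[rand_mask]
sampled_distance = -2 * torch.einsum('id,jd->ij', sampled_tokens, codebook)
new_per_sample_probs = (-sampled_distance * inv_temperature).softmax(dim = -1)

assert torch.allclose(old_per_sample_probs, new_per_sample_probs, atol = 1e-5)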
