Commit b48946f

Merge pull request #176 from EmmettBicker/master
Changed frac_per_sample_entropy to take up less memory
2 parents 76ec1de + cdd0141 commit b48946f

2 files changed: 99 additions & 14 deletions

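The change to lookup_free_quantization.py below boils down to this: previously the distance from every token to every codebook entry was computed up front and the resulting probability rows were subsampled afterwards; now a random fraction of the tokens is drawn first and distances are computed only for those, so the full (num_tokens, codebook_size) matrix is never materialized. A minimal, self-contained sketch of the idea (the helper name, shapes, and values are illustrative, not the library's internals):

import torch

def per_sample_probs_subsampled(tokens, codebook, frac, inv_temperature = 100.):
    # tokens: (num_tokens, dim), codebook: (codebook_size, dim)
    num_tokens = tokens.size(0)
    num_sampled = max(int(num_tokens * frac), 1)

    # draw a random subset of tokens first ...
    rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled
    sampled = tokens[rand_mask]

    # ... and only then compute distances (euclidean up to a constant) and probabilities
    distance = -2 * sampled @ codebook.t()
    return (-distance * inv_temperature).softmax(dim = -1)

probs = per_sample_probs_subsampled(torch.randn(2048, 16), torch.randn(65536, 16), frac = 0.1)
print(probs.shape)  # torch.Size([204, 65536]) instead of [2048, 65536]

With the shapes used in the tests (2 x 32 x 32 = 2048 tokens against a 65536-entry codebook), a full fp32 probability matrix is roughly 0.5 GB, while a 10% subsample needs about 50 MB; exact figures depend on any extra batch or codebook dimensions the library carries internally.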

tests/test_lfq.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
import torch
import pytest
from vector_quantize_pytorch import LFQ
import math
"""
testing_strategy:
subdivisions: using masks, using frac_per_sample_entropy < 1
"""

torch.manual_seed(0)

@pytest.mark.parametrize('frac_per_sample_entropy', (1., 0.5))
@pytest.mark.parametrize('mask', (torch.tensor([False, False]),
                                  torch.tensor([True, False]),
                                  torch.tensor([True, True])))
def test_masked_lfq(
    frac_per_sample_entropy,
    mask
):
    # you can specify either dim or codebook_size
    # if both specified, will be validated against each other

    quantizer = LFQ(
        codebook_size = 65536,      # codebook size, must be a power of 2
        dim = 16,                   # this is the input feature dimension, defaults to log2(codebook_size) if not defined
        entropy_loss_weight = 0.1,  # how much weight to place on entropy loss
        diversity_gamma = 1.,       # within entropy loss, how much weight to give to diversity
        frac_per_sample_entropy = frac_per_sample_entropy
    )

    image_feats = torch.randn(2, 16, 32, 32)

    ret, loss_breakdown = quantizer(image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature

    quantized, indices, _ = ret
    assert (quantized == quantizer.indices_to_codes(indices)).all()

@pytest.mark.parametrize('frac_per_sample_entropy', (0.1,))
@pytest.mark.parametrize('iters', (10,))
@pytest.mark.parametrize('mask', (None, torch.tensor([True, False])))
def test_lfq_bruteforce_frac_per_sample_entropy(frac_per_sample_entropy, iters, mask):
    image_feats = torch.randn(2, 16, 32, 32)

    full_per_sample_entropy_quantizer = LFQ(
        codebook_size = 65536,      # codebook size, must be a power of 2
        dim = 16,                   # this is the input feature dimension, defaults to log2(codebook_size) if not defined
        entropy_loss_weight = 0.1,  # how much weight to place on entropy loss
        diversity_gamma = 1.,       # within entropy loss, how much weight to give to diversity
        frac_per_sample_entropy = 1
    )

    partial_per_sample_entropy_quantizer = LFQ(
        codebook_size = 65536,      # codebook size, must be a power of 2
        dim = 16,                   # this is the input feature dimension, defaults to log2(codebook_size) if not defined
        entropy_loss_weight = 0.1,  # how much weight to place on entropy loss
        diversity_gamma = 1.,       # within entropy loss, how much weight to give to diversity
        frac_per_sample_entropy = frac_per_sample_entropy
    )

    ret, loss_breakdown = full_per_sample_entropy_quantizer(
        image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask)
    true_per_sample_entropy = loss_breakdown.per_sample_entropy

    per_sample_losses = torch.zeros(iters)
    for iter in range(iters):
        ret, loss_breakdown = partial_per_sample_entropy_quantizer(
            image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature

        quantized, indices, _ = ret
        assert (quantized == partial_per_sample_entropy_quantizer.indices_to_codes(indices)).all()
        per_sample_losses[iter] = loss_breakdown.per_sample_entropy
    # 95% confidence interval
    assert abs(per_sample_losses.mean() - true_per_sample_entropy) \
        < (1.96*(per_sample_losses.std() / math.sqrt(iters)))

    print("difference: ", abs(per_sample_losses.mean() - true_per_sample_entropy))
    print("std error:", (1.96*(per_sample_losses.std() / math.sqrt(iters))))

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 22 additions & 14 deletions
@@ -335,26 +335,34 @@ def forward(
 
         codebook = self.maybe_l2norm(codebook)
 
-        # the same as euclidean distance up to a constant
-        distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
-
-        prob = (-distance * inv_temperature).softmax(dim = -1)
-
-        # account for mask
-
-        if exists(mask):
-            prob = prob[mask]
-        else:
-            prob = rearrange(prob, 'b n ... -> (b n) ...')
-
         # whether to only use a fraction of probs, for reducing memory
 
         if self.frac_per_sample_entropy < 1.:
-            num_tokens = prob.shape[0]
+            # account for mask
+            if exists(mask):
+                original_input = original_input[mask]
+            original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+
+            num_tokens = original_input.size(0)
             num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
             rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
-            per_sample_probs = prob[rand_mask]
+
+            sampled_input = original_input[rand_mask]
+
+            sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
+
+            sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
+
+            per_sample_probs = sampled_prob
         else:
+            if exists(mask):
+                original_input = original_input[mask]
+            original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+            # the same as euclidean distance up to a constant
+            distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
+
+            prob = (-distance * inv_temperature).softmax(dim = -1)
+
             per_sample_probs = prob
 
         # calculate per sample entropy
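Mirroring the tests above, a short usage sketch of the new code path; the frac_per_sample_entropy value of 0.25 is illustrative, and the constructor and call arguments are the ones already exercised by tests/test_lfq.py:

import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    codebook_size = 65536,          # must be a power of 2
    dim = 16,
    entropy_loss_weight = 0.1,
    diversity_gamma = 1.,
    frac_per_sample_entropy = 0.25  # estimate the entropy loss on 25% of tokens to save memory
)

image_feats = torch.randn(2, 16, 32, 32)

ret, loss_breakdown = quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True)

quantized, indices, _ = ret
assert (quantized == quantizer.indices_to_codes(indices)).all()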
