import math

import pytest
import torch

from vector_quantize_pytorch import LFQ

"""
Testing strategy:
subdivisions: using masks, using frac_per_sample_entropy < 1
"""

torch.manual_seed(0)

@pytest.mark.parametrize('frac_per_sample_entropy', (1., 0.5))
def test_masked_lfq(
    frac_per_sample_entropy
):
    # you can specify either dim or codebook_size
    # if both are specified, they will be validated against each other

    quantizer = LFQ(
        codebook_size = 65536,      # codebook size, must be a power of 2
        dim = 16,                   # input feature dimension, defaults to log2(codebook_size) if not specified
        entropy_loss_weight = 0.1,  # how much weight to place on the entropy loss
        diversity_gamma = 1.,       # within the entropy loss, how much weight to give to diversity
        frac_per_sample_entropy = frac_per_sample_entropy
    )

    image_feats = torch.randn(2, 16, 32, 32)

    ret, loss_breakdown = quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True)  # you may want to experiment with temperature

    quantized, indices, _ = ret
    assert (quantized == quantizer.indices_to_codes(indices)).all()
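
# a minimal sketch of masked usage, mirroring the mask parameter exercised in the test
# below (a boolean mask over the batch dimension); the exact semantics of masked-out
# samples are an assumption here, not something this file verifies:
#
#   mask = torch.tensor([True, False])
#   ret, loss_breakdown = quantizer(image_feats, mask = mask, return_loss_breakdown = True)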

@pytest.mark.parametrize('frac_per_sample_entropy', (0.1,))
@pytest.mark.parametrize('iters', (10,))
@pytest.mark.parametrize('mask', (None, torch.tensor([True, False])))
def test_lfq_bruteforce_frac_per_sample_entropy(frac_per_sample_entropy, iters, mask):
    image_feats = torch.randn(2, 16, 32, 32)

    full_per_sample_entropy_quantizer = LFQ(
        codebook_size = 65536,      # codebook size, must be a power of 2
        dim = 16,                   # input feature dimension, defaults to log2(codebook_size) if not specified
        entropy_loss_weight = 0.1,  # how much weight to place on the entropy loss
        diversity_gamma = 1.,       # within the entropy loss, how much weight to give to diversity
        frac_per_sample_entropy = 1
    )

    # identical configuration, except the per-sample entropy is estimated from a subset
    partial_per_sample_entropy_quantizer = LFQ(
        codebook_size = 65536,
        dim = 16,
        entropy_loss_weight = 0.1,
        diversity_gamma = 1.,
        frac_per_sample_entropy = frac_per_sample_entropy
    )

    ret, loss_breakdown = full_per_sample_entropy_quantizer(
        image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask)
    true_per_sample_entropy = loss_breakdown.per_sample_entropy

    per_sample_losses = torch.zeros(iters)
    for i in range(iters):
        ret, loss_breakdown = partial_per_sample_entropy_quantizer(
            image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask)  # you may want to experiment with temperature

        quantized, indices, _ = ret
        assert (quantized == partial_per_sample_entropy_quantizer.indices_to_codes(indices)).all()
        per_sample_losses[i] = loss_breakdown.per_sample_entropy

    # the mean of the subsampled estimates should fall within a 95% confidence interval of the true value
    assert abs(per_sample_losses.mean() - true_per_sample_entropy) \
        < (1.96 * (per_sample_losses.std() / math.sqrt(iters)))

    print("difference: ", abs(per_sample_losses.mean() - true_per_sample_entropy))
    print("std error:", (1.96 * (per_sample_losses.std() / math.sqrt(iters))))