1
+ import torch
2
+ import pytest
3
+ from vector_quantize_pytorch import LFQ
4
+ import math
5
+ """
6
+ testing_strategy:
7
+ subdivisions: using masks, using frac_per_sample_entropy < 1
8
+ """
9
+
10
+ torch .manual_seed (0 )
11
+
12
+ @pytest .mark .parametrize ('frac_per_sample_entropy' , (1. , 0.5 ))
13
+ @pytest .mark .parametrize ('mask' , (torch .tensor ([False , False ]),
14
+ torch .tensor ([True , False ]),
15
+ torch .tensor ([True , True ])))
16
+ def test_masked_lfq (
17
+ frac_per_sample_entropy ,
18
+ mask
19
+ ):
20
+ # you can specify either dim or codebook_size
21
+ # if both specified, will be validated against each other
22
+
23
+ quantizer = LFQ (
24
+ codebook_size = 65536 , # codebook size, must be a power of 2
25
+ dim = 16 , # this is the input feature dimension, defaults to log2(codebook_size) if not defined
26
+ entropy_loss_weight = 0.1 , # how much weight to place on entropy loss
27
+ diversity_gamma = 1. , # within entropy loss, how much weight to give to diversity
28
+ frac_per_sample_entropy = frac_per_sample_entropy
29
+ )
30
+
31
+ image_feats = torch .randn (2 , 16 , 32 , 32 )
32
+
33
+ ret , loss_breakdown = quantizer (image_feats , inv_temperature = 100. , return_loss_breakdown = True , mask = mask ) # you may want to experiment with temperature
34
+
35
+ quantized , indices , _ = ret
36
+ assert (quantized == quantizer .indices_to_codes (indices )).all ()
37
+
38
+ @pytest .mark .parametrize ('frac_per_sample_entropy' , (0.1 ,))
39
+ @pytest .mark .parametrize ('iters' , (10 ,))
40
+ @pytest .mark .parametrize ('mask' , (None , torch .tensor ([True , False ])))
41
+ def test_lfq_bruteforce_frac_per_sample_entropy (frac_per_sample_entropy , iters , mask ):
42
+ image_feats = torch .randn (2 , 16 , 32 , 32 )
43
+
44
+ full_per_sample_entropy_quantizer = LFQ (
45
+ codebook_size = 65536 , # codebook size, must be a power of 2
46
+ dim = 16 , # this is the input feature dimension, defaults to log2(codebook_size) if not defined
47
+ entropy_loss_weight = 0.1 , # how much weight to place on entropy loss
48
+ diversity_gamma = 1. , # within entropy loss, how much weight to give to diversity
49
+ frac_per_sample_entropy = 1
50
+ )
51
+
52
+ partial_per_sample_entropy_quantizer = LFQ (
53
+ codebook_size = 65536 , # codebook size, must be a power of 2
54
+ dim = 16 , # this is the input feature dimension, defaults to log2(codebook_size) if not defined
55
+ entropy_loss_weight = 0.1 , # how much weight to place on entropy loss
56
+ diversity_gamma = 1. , # within entropy loss, how much weight to give to diversity
57
+ frac_per_sample_entropy = frac_per_sample_entropy
58
+ )
59
+
60
+ ret , loss_breakdown = full_per_sample_entropy_quantizer (
61
+ image_feats , inv_temperature = 100. , return_loss_breakdown = True , mask = mask )
62
+ true_per_sample_entropy = loss_breakdown .per_sample_entropy
63
+
64
+ per_sample_losses = torch .zeros (iters )
65
+ for iter in range (iters ):
66
+ ret , loss_breakdown = partial_per_sample_entropy_quantizer (
67
+ image_feats , inv_temperature = 100. , return_loss_breakdown = True , mask = mask ) # you may want to experiment with temperature
68
+
69
+ quantized , indices , _ = ret
70
+ assert (quantized == partial_per_sample_entropy_quantizer .indices_to_codes (indices )).all ()
71
+ per_sample_losses [iter ] = loss_breakdown .per_sample_entropy
72
+ # 95% confidence interval
73
+ assert abs (per_sample_losses .mean () - true_per_sample_entropy ) \
74
+ < (1.96 * (per_sample_losses .std () / math .sqrt (iters )))
75
+
76
+ print ("difference: " , abs (per_sample_losses .mean () - true_per_sample_entropy ))
77
+ print ("std error:" , (1.96 * (per_sample_losses .std () / math .sqrt (iters ))))
0 commit comments