
Commit 3e09df7

committed
addressed the issue with how entropy is calculated for LFQ #78
1 parent 5aab91c commit 3e09df7

File tree

2 files changed, +28 -10 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.9.9',
+  version = '1.9.10',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 27 additions & 9 deletions
@@ -10,7 +10,7 @@
 from collections import namedtuple

 import torch
-from torch import nn, Tensor
+from torch import nn, Tensor, einsum
 import torch.nn.functional as F
 from torch.nn import Module

@@ -37,13 +37,21 @@ def pack_one(t, pattern):
 def unpack_one(t, ps, pattern):
     return unpack(t, ps, pattern)[0]

+# distance
+
+def euclidean_distance_squared(x, y):
+    x2 = reduce(x ** 2, '... n d -> ... n', 'sum')
+    y2 = reduce(y ** 2, 'n d -> n', 'sum')
+    xy = einsum('... i d, j d -> ... i j', x, y) * -2
+    return rearrange(x2, '... i -> ... i 1') + y2 + xy
+
 # entropy

 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()

-def binary_entropy(prob):
-    return -prob * log(prob) - (1 - prob) * log(1 - prob)
+def entropy(prob):
+    return -prob * log(prob)

 # class
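
For reference, the new euclidean_distance_squared helper is the standard expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, broadcast against a flat codebook. A standalone sanity check, with illustrative tensor shapes that are not taken from the library:

import torch
from torch import einsum
from einops import rearrange, reduce

def euclidean_distance_squared(x, y):
    x2 = reduce(x ** 2, '... n d -> ... n', 'sum')
    y2 = reduce(y ** 2, 'n d -> n', 'sum')
    xy = einsum('... i d, j d -> ... i j', x, y) * -2
    return rearrange(x2, '... i -> ... i 1') + y2 + xy

x = torch.randn(2, 4, 8)    # (batch, vectors, dim) - illustrative shapes
y = torch.randn(16, 8)      # (codes, dim) flat codebook

dist = euclidean_distance_squared(x, y)            # (2, 4, 16) pairwise squared distances
ref  = torch.cdist(x, y.expand(2, -1, -1)) ** 2    # independent reference for the same quantity

assert torch.allclose(dist, ref, atol = 1e-4)

torch.cdist serves only as a cross-check here; the einsum formulation computes the same distances through dot products without materialising an expanded difference tensor.
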

@@ -102,6 +110,14 @@ def __init__(
         self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
         self.register_buffer('zero', torch.zeros(1,), persistent = False)

+        # codes
+
+        all_codes = torch.arange(codebook_size)
+        bits = ((all_codes[..., None].int() & self.mask) != 0).float()
+        codebook = bits * 2 - 1
+
+        self.register_buffer('codebook', codebook, persistent = False)
+
     def indices_to_codes(
         self,
         indices,
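
The new codebook buffer enumerates every integer index as its ±1 bit pattern using the existing mask buffer. A minimal sketch of that expansion, assuming codebook_dim = 3 purely for illustration (the real value comes from the module's configuration):

import torch

# illustrative expansion of the codebook buffer, assuming codebook_dim = 3
# (so codebook_size = 2 ** 3 = 8); these values are examples, not library defaults

codebook_dim  = 3
codebook_size = 2 ** codebook_dim

mask = 2 ** torch.arange(codebook_dim - 1, -1, -1)         # tensor([4, 2, 1]), as in the existing buffer

all_codes = torch.arange(codebook_size)                    # integer code indices 0 .. 7
bits = ((all_codes[..., None].int() & mask) != 0).float()  # (8, 3) binary expansion of each index
codebook = bits * 2 - 1                                    # map bits {0, 1} -> values {-1, +1}

print(codebook[0], codebook[5], codebook[7])
# tensor([-1., -1., -1.]) tensor([ 1., -1.,  1.]) tensor([1., 1., 1.])
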
@@ -183,17 +199,19 @@ def forward(
         # entropy aux loss

         if self.training:
-            prob = (x * inv_temperature).sigmoid()
+            distance = euclidean_distance_squared(original_input, self.codebook)
+
+            prob = (-distance * inv_temperature).softmax(dim = -1)

-            bit_entropy = binary_entropy(prob).mean()
+            per_sample_entropy = entropy(prob).mean()

             avg_prob = reduce(prob, 'b n c d -> b c d', 'mean')
-            codebook_entropy = binary_entropy(avg_prob).mean()
+            codebook_entropy = entropy(avg_prob).mean()

-            # 1. entropy will be nudged to be low for each bit, so each scalar commits to one latent binary bit or the other
-            # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used
+            # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
+            # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch

-            entropy_aux_loss = bit_entropy - self.diversity_gamma * codebook_entropy
+            entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
         else:
             # if not training, just return dummy 0
             entropy_aux_loss = self.zero
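
Taken together, the aux loss now uses softmax-over-codes entropy instead of per-bit binary entropy: per-sample entropy is pushed down (confident assignment to a single code) while the entropy of the averaged distribution is pushed up (uniform code usage). A minimal end-to-end sketch with assumed shapes and hyperparameters, using torch.cdist ** 2 as a stand-in for the module's euclidean_distance_squared helper:

import torch
from einops import reduce

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

def entropy(prob):
    return -prob * log(prob)

inv_temperature = 100.   # illustrative value
diversity_gamma = 1.     # illustrative value

inputs   = torch.randn(2, 4, 1, 8)   # (batch, tokens, num codebooks, codebook_dim) - assumed shape
codebook = torch.randn(16, 8)        # (codebook_size, codebook_dim) - random stand-in for the ±1 codebook

# squared euclidean distance from every input vector to every code
distance = (torch.cdist(inputs.flatten(0, 2), codebook) ** 2).view(2, 4, 1, 16)

prob = (-distance * inv_temperature).softmax(dim = -1)   # soft assignment over the 16 codes

per_sample_entropy = entropy(prob).mean()                # nudged low -> confident assignments
avg_prob = reduce(prob, 'b n c d -> b c d', 'mean')      # code usage averaged over tokens
codebook_entropy = entropy(avg_prob).mean()              # nudged high -> all codes get used

entropy_aux_loss = per_sample_entropy - diversity_gamma * codebook_entropy
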
