from collections import namedtuple

import torch
- from torch import nn, Tensor
+ from torch import nn, Tensor, einsum
import torch.nn.functional as F
from torch.nn import Module
@@ -37,13 +37,21 @@ def pack_one(t, pattern):
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

+ # distance
+
+ def euclidean_distance_squared(x, y):
+     x2 = reduce(x ** 2, '... n d -> ... n', 'sum')
+     y2 = reduce(y ** 2, 'n d -> n', 'sum')
+     xy = einsum('... i d, j d -> ... i j', x, y) * -2
+     return rearrange(x2, '... i -> ... i 1') + y2 + xy
+
# entropy

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

- def binary_entropy(prob):
-     return -prob * log(prob) - (1 - prob) * log(1 - prob)
+ def entropy(prob):
+     return -prob * log(prob)

# class
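Note: the new euclidean_distance_squared helper is just the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y> written with einsum. A minimal, self-contained sketch (shapes here are illustrative, not the ones used in forward) checking it against torch.cdist:

import torch
from torch import einsum
from einops import reduce, rearrange

# mirrors the helper added above
def euclidean_distance_squared(x, y):
    x2 = reduce(x ** 2, '... n d -> ... n', 'sum')
    y2 = reduce(y ** 2, 'n d -> n', 'sum')
    xy = einsum('... i d, j d -> ... i j', x, y) * -2
    return rearrange(x2, '... i -> ... i 1') + y2 + xy

x = torch.randn(4, 8)    # (n, d)
y = torch.randn(16, 8)   # (codebook_size, d)

assert torch.allclose(euclidean_distance_squared(x, y), torch.cdist(x, y) ** 2, atol = 1e-4)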
@@ -102,6 +110,14 @@ def __init__(
        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
        self.register_buffer('zero', torch.zeros(1,), persistent = False)

+         # codes
+
+         all_codes = torch.arange(codebook_size)
+         bits = ((all_codes[..., None].int() & self.mask) != 0).float()
+         codebook = bits * 2 - 1
+
+         self.register_buffer('codebook', codebook, persistent = False)
+
    def indices_to_codes(
        self,
        indices,
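Note: the new codebook buffer enumerates every code index and expands it through the existing bit mask into a vector of -1s and 1s. A small self-contained sketch with an illustrative codebook_dim of 3 (so codebook_size = 8):

import torch

codebook_dim = 3                    # illustrative; the module derives this from codebook_size
codebook_size = 2 ** codebook_dim

mask = 2 ** torch.arange(codebook_dim - 1, -1, -1)    # tensor([4, 2, 1])

all_codes = torch.arange(codebook_size)
bits = ((all_codes[..., None].int() & mask) != 0).float()
codebook = bits * 2 - 1

print(codebook[5])    # tensor([ 1., -1.,  1.]), since 5 is 0b101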
@@ -183,17 +199,19 @@ def forward(
        # entropy aux loss

        if self.training:
-             prob = (x * inv_temperature).sigmoid()
+             distance = euclidean_distance_squared(original_input, self.codebook)
+
+             prob = (-distance * inv_temperature).softmax(dim = -1)

-             bit_entropy = binary_entropy(prob).mean()
+             per_sample_entropy = entropy(prob).mean()

            avg_prob = reduce(prob, 'b n c d -> b c d', 'mean')
-             codebook_entropy = binary_entropy(avg_prob).mean()
+             codebook_entropy = entropy(avg_prob).mean()

-             # 1. entropy will be nudged to be low for each bit, so each scalar commits to one latent binary bit or the other
-             # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used
+             # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
+             # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch

-             entropy_aux_loss = bit_entropy - self.diversity_gamma * codebook_entropy
+             entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
        else:
            # if not training, just return dummy 0
            entropy_aux_loss = self.zero
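Note: with this change the aux loss is computed over a categorical distribution across all codes rather than over independent per-bit probabilities. A minimal sketch of the two terms outside the module, with illustrative shapes and an illustrative diversity_gamma of 1.:

import torch
from einops import reduce

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

def entropy(prob):
    return -prob * log(prob)

# illustrative shapes: (batch, seq, num codebooks, codebook_size)
prob = torch.randn(2, 4, 1, 8).softmax(dim = -1)

per_sample_entropy = entropy(prob).mean()              # pushed low: each position should commit to one code
avg_prob = reduce(prob, 'b n c d -> b c d', 'mean')
codebook_entropy = entropy(avg_prob).mean()            # pushed high: codes should be used uniformly over the batch

diversity_gamma = 1.                                   # illustrative value
entropy_aux_loss = per_sample_entropy - diversity_gamma * codebook_entropy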