@@ -61,7 +61,7 @@ def gumbel_sample(
     assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'
 
     if not straight_through:
-        return ind, one_hot
+        return ind, one_hot, None
 
     # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
     # algorithm 2
@@ -78,7 +78,9 @@ def gumbel_sample(
         π1 = (logits / temperature).softmax(dim = dim)
         one_hot = one_hot + π1 - π1.detach()
 
-    return ind, one_hot
+    st_mult = one_hot.gather(-1, rearrange(ind, '... -> ... 1')) # multiplier for straight-through
+
+    return ind, one_hot, st_mult
 
 def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
     denom = x.sum(dim = dim, keepdim = True)
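
The new third return value is the straight-through one-hot gathered at the sampled index: its forward value is exactly 1 (the hard one-hot contributes 1 at `ind`, and the `π1 - π1.detach()` residual is numerically zero), but it carries the soft softmax gradient. A minimal sketch of that idea, with hypothetical toy shapes:

import torch
import torch.nn.functional as F
from einops import rearrange

logits = torch.randn(2, 4, requires_grad = True) # toy: 2 positions over 4 codes
temperature = 1.

ind = logits.argmax(dim = -1)
one_hot = F.one_hot(ind, 4).float()

# straight-through: hard one-hot in the forward pass, soft gradient in the backward pass
π1 = (logits / temperature).softmax(dim = -1)
one_hot = one_hot + π1 - π1.detach()

# the new third return value: the one-hot gathered at the sampled index
st_mult = one_hot.gather(-1, rearrange(ind, '... -> ... 1'))

assert torch.allclose(st_mult, torch.ones_like(st_mult)) # forward value is exactly 1
st_mult.sum().backward()
assert logits.grad is not None                           # but gradient reaches the logits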
@@ -333,11 +335,16 @@ def forward(self, x):
 
         dist = -torch.cdist(flatten, embed, p = 2)
 
-        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)
+        embed_ind, embed_onehot, straight_through_mult = self.gumbel_sample(dist, dim = -1)
+
         embed_ind = unpack_one(embed_ind, ps, 'h *')
 
         quantize = batched_embedding(embed_ind, self.embed)
 
+        if exists(straight_through_mult):
+            mult = unpack_one(straight_through_mult, ps, 'h * d')
+            quantize = quantize * mult
+
         if self.training:
             cluster_size = embed_onehot.sum(dim = 1)
 
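
Because the multiplier's forward value is exactly 1, `quantize * mult` leaves the quantized output numerically unchanged while routing gradient into the logits (here the negative pairwise distances) that produced the sample. A rough sketch of the effect, with hypothetical shapes and a stand-in codebook:

import torch

codes = torch.randn(4, 8)                     # stand-in codebook: 4 codes of dim 8
dist = -torch.cdist(torch.randn(2, 8), codes) # negative distances act as logits
dist.requires_grad_()

ind = dist.argmax(dim = -1)
soft = dist.softmax(dim = -1)

# value 1 at the sampled index, gradient of the softmax (matches one_hot.gather above,
# since the hard one-hot term carries no gradient)
st_mult = (soft - soft.detach() + 1.).gather(-1, ind.unsqueeze(-1))

quantize = codes[ind] * st_mult               # forward output identical to a plain lookup

assert torch.allclose(quantize, codes[ind])
quantize.sum().backward()
assert dist.grad is not None                  # distances now receive gradient through the sample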
@@ -476,11 +483,15 @@ def forward(self, x):
         embed = self.embed if not self.learnable_codebook else self.embed.detach()
 
         dist = einsum('h n d, h c d -> h n c', flatten, embed)
-        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)
+        embed_ind, embed_onehot, straight_through_mult = self.gumbel_sample(dist, dim = -1)
 
         embed_ind = unpack_one(embed_ind, ps, 'h *')
 
         quantize = batched_embedding(embed_ind, self.embed)
 
+        if exists(straight_through_mult):
+            mult = unpack_one(straight_through_mult, ps, 'h * d')
+            quantize = quantize * mult
+
         if self.training:
             bins = embed_onehot.sum(dim = 1)
             self.all_reduce_fn(bins)
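
Both codebook variants guard the multiply with `exists(...)`, so when `straight_through = False` the third return value is `None` and the forward pass is untouched. For context, a usage sketch, assuming this diff is against lucidrains' vector-quantize-pytorch and using the gumbel-sampling flags its README documents:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    stochastic_sample_codes = True, # gumbel-noise the code logits before argmax
    sample_codebook_temp = 0.1,
    straight_through = True         # the setting that yields the non-None multiplier above
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x) # quantized: (1, 1024, 256)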