@@ -1,3 +1,5 @@
+ from functools import partial
+
import torch
from torch import nn, einsum
import torch.nn.functional as F
@@ -40,11 +42,43 @@ def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

- def gumbel_sample(t, temperature = 1., dim = -1):
-     if temperature == 0:
-         return t.argmax(dim = dim)
+ def gumbel_sample(
+     logits,
+     temperature = 1.,
+     stochastic = False,
+     straight_through = False,
+     reinmax = False,
+     dim = -1
+ ):
+     dtype, size = logits.dtype, logits.shape[dim]
+
+     if stochastic:
+         logits = logits + gumbel_noise(logits)
+
+     ind = logits.argmax(dim = dim)
+     one_hot = F.one_hot(ind, size).type(dtype)
+
+     assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'
+
+     if not straight_through:
+         return ind, one_hot
+
+     # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
+     # algorithm 2

-     return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
+     temperature = max(temperature, 1e-2)
+
+     if reinmax:
+         π0 = logits.softmax(dim = dim)
+         π1 = (one_hot + (logits / temperature).softmax(dim = dim)) / 2
+         π1 = ((π1.log() - logits).detach() + logits).softmax(dim = dim)
+         π2 = 2 * π1 - 0.5 * π0
+         one_hot = π2 - π2.detach() + one_hot
+     else:
+         π1 = (logits / temperature).softmax(dim = dim)
+         one_hot = one_hot + π1 - π1.detach()
+
+     return ind, one_hot

def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
    denom = x.sum(dim = dim, keepdim = True)
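A minimal sketch of how the reworked sampler behaves, assuming the diff above is applied and `gumbel_sample` is importable from this module; the toy shapes and the random weight tensor are illustrative only:

```python
import torch

# (heads, n, codebook_size) logits, as produced by the codebook distance computation
logits = torch.randn(1, 8, 16, requires_grad = True)

ind, one_hot = gumbel_sample(
    logits,
    temperature = 1.,
    stochastic = True,        # perturb logits with gumbel noise before the argmax
    straight_through = True   # hard one-hot forward, softmax gradients backward
)

assert ind.shape == (1, 8) and one_hot.shape == (1, 8, 16)

# the one-hot is a hard sample in the forward pass, but carries gradients
# through the tempered softmax (or the reinmax correction) in the backward pass
loss = (one_hot * torch.randn(16)).sum()
loss.backward()
assert logits.grad is not None
```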
@@ -200,7 +234,9 @@ def __init__(
        reset_cluster_size = None,
        use_ddp = False,
        learnable_codebook = False,
-         sample_codebook_temp = 0
+         sample_codebook_temp = 0,
+         straight_through = False,
+         gumbel_sample = gumbel_sample
    ):
        super().__init__()
        self.transform_input = identity
@@ -216,7 +252,9 @@ def __init__(
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
-         self.sample_codebook_temp = sample_codebook_temp
+
+         assert callable(gumbel_sample)
+         self.gumbel_sample = gumbel_sample

        assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'

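Because the codebook now receives its sampling routine as an injected callable (only `callable` is asserted), anything with the same `(logits, dim) -> (indices, one_hot)` contract can be dropped in. A small illustrative sketch; the helper name `hard_argmax_sample` is hypothetical and not part of this commit:

```python
import torch.nn.functional as F

def hard_argmax_sample(logits, dim = -1):
    # same return contract as gumbel_sample: indices plus a one-hot of matching dtype
    ind = logits.argmax(dim = dim)
    one_hot = F.one_hot(ind, logits.shape[dim]).type(logits.dtype)
    return ind, one_hot

# e.g. EuclideanCodebook(..., gumbel_sample = hard_argmax_sample)
```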
@@ -295,8 +333,7 @@ def forward(self, x):

        dist = -torch.cdist(flatten, embed, p = 2)

-         embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
-         embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+         embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)
        embed_ind = unpack_one(embed_ind, ps, 'h *')

        quantize = batched_embedding(embed_ind, self.embed)
@@ -339,7 +376,7 @@ def __init__(
        reset_cluster_size = None,
        use_ddp = False,
        learnable_codebook = False,
-         sample_codebook_temp = 0.
+         gumbel_sample = gumbel_sample
    ):
        super().__init__()
        self.transform_input = l2norm
@@ -358,7 +395,9 @@ def __init__(
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
-         self.sample_codebook_temp = sample_codebook_temp
+
+         assert callable(gumbel_sample)
+         self.gumbel_sample = gumbel_sample

        self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
        self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
@@ -437,8 +476,7 @@ def forward(self, x):
        embed = self.embed if not self.learnable_codebook else self.embed.detach()

        dist = einsum('h n d, h c d -> h n c', flatten, embed)
-         embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
-         embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+         embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)
        embed_ind = unpack_one(embed_ind, ps, 'h *')

        quantize = batched_embedding(embed_ind, self.embed)
@@ -491,8 +529,11 @@ def __init__(
        orthogonal_reg_weight = 0.,
        orthogonal_reg_active_codes_only = False,
        orthogonal_reg_max_codes = None,
+         stochastic_sample_codes = False,
        sample_codebook_temp = 0.,
-         sync_codebook = False
+         straight_through = False,
+         reinmax = False,  # using reinmax for improved straight-through, assuming straight through helps at all
+         sync_codebook = False,
    ):
        super().__init__()
        self.dim = dim
@@ -517,6 +558,14 @@ def __init__(

        codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook

+         gumbel_sample_fn = partial(
+             gumbel_sample,
+             stochastic = stochastic_sample_codes,
+             temperature = sample_codebook_temp,
+             reinmax = reinmax,
+             straight_through = straight_through
+         )
+
        self._codebook = codebook_class(
            dim = codebook_dim,
            num_codebooks = heads if separate_codebook_per_head else 1,
@@ -529,7 +578,7 @@ def __init__(
            threshold_ema_dead_code = threshold_ema_dead_code,
            use_ddp = sync_codebook,
            learnable_codebook = has_codebook_orthogonal_loss,
-             sample_codebook_temp = sample_codebook_temp
+             gumbel_sample = gumbel_sample_fn
        )

        self.codebook_size = codebook_size
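Putting it together at the `VectorQuantize` level, a hedged usage sketch (assuming this is the `vector_quantize_pytorch` package and the forward signature `quantized, indices, commit_loss = vq(x)` is unchanged by this commit):

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    stochastic_sample_codes = True,   # add gumbel noise to the distances before argmax
    sample_codebook_temp = 0.1,       # temperature of the softmax used by the ST estimator
    straight_through = True,          # hard indices forward, soft gradients backward
    reinmax = True                    # second-order correction from https://arxiv.org/abs/2304.08612
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```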