@@ -103,7 +103,6 @@ def gumbel_sample(
103
103
temperature = 1. ,
104
104
stochastic = False ,
105
105
straight_through = False ,
106
- reinmax = False ,
107
106
dim = - 1 ,
108
107
training = True
109
108
):
@@ -117,23 +116,11 @@ def gumbel_sample(
117
116
ind = sampling_logits .argmax (dim = dim )
118
117
one_hot = F .one_hot (ind , size ).type (dtype )
119
118
120
- assert not (reinmax and not straight_through ), 'reinmax can only be turned on if using straight through gumbel softmax'
121
-
122
119
if not straight_through or temperature <= 0. or not training :
123
120
return ind , one_hot
124
121
125
- # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
126
- # algorithm 2
127
-
128
- if reinmax :
129
- π0 = logits .softmax (dim = dim )
130
- π1 = (one_hot + (logits / temperature ).softmax (dim = dim )) / 2
131
- π1 = ((log (π1 ) - logits ).detach () + logits ).softmax (dim = 1 )
132
- π2 = 2 * π1 - 0.5 * π0
133
- one_hot = π2 - π2 .detach () + one_hot
134
- else :
135
- π1 = (logits / temperature ).softmax (dim = dim )
136
- one_hot = one_hot + π1 - π1 .detach ()
122
+ π1 = (logits / temperature ).softmax (dim = dim )
123
+ one_hot = one_hot + π1 - π1 .detach ()
137
124
138
125
return ind , one_hot
139
126
@@ -828,7 +815,6 @@ def __init__(
828
815
sample_codebook_temp = 1. ,
829
816
straight_through = False ,
830
817
rotation_trick = True , # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424 by @cfifty
831
- reinmax = False , # using reinmax for improved straight-through, assuming straight through helps at all
832
818
sync_codebook = None ,
833
819
sync_affine_param = False ,
834
820
ema_update = True ,
@@ -895,7 +881,6 @@ def __init__(
895
881
gumbel_sample_fn = partial (
896
882
gumbel_sample ,
897
883
stochastic = stochastic_sample_codes ,
898
- reinmax = reinmax ,
899
884
straight_through = straight_through
900
885
)
901
886
0 commit comments