Skip to content

Commit 46bd7af

Browse files
committed
it had its chance
1 parent ca90db2 commit 46bd7af

File tree

2 files changed: +2 additions, −27 deletions

README.md

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -611,16 +611,6 @@ assert loss.item() >= 0
611611
}
612612
```
613613

614-
```bibtex
615-
@article{Liu2023BridgingDA,
616-
title = {Bridging Discrete and Backpropagation: Straight-Through and Beyond},
617-
author = {Liyuan Liu and Chengyu Dong and Xiaodong Liu and Bin Yu and Jianfeng Gao},
618-
journal = {ArXiv},
619-
year = {2023},
620-
volume = {abs/2304.08612}
621-
}
622-
```
623-
624614
```bibtex
625615
@inproceedings{huh2023improvedvqste,
626616
title = {Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks},

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 2 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -103,7 +103,6 @@ def gumbel_sample(
103103
temperature = 1.,
104104
stochastic = False,
105105
straight_through = False,
106-
reinmax = False,
107106
dim = -1,
108107
training = True
109108
):
@@ -117,23 +116,11 @@ def gumbel_sample(
117116
ind = sampling_logits.argmax(dim = dim)
118117
one_hot = F.one_hot(ind, size).type(dtype)
119118

120-
assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'
121-
122119
if not straight_through or temperature <= 0. or not training:
123120
return ind, one_hot
124121

125-
# use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
126-
# algorithm 2
127-
128-
if reinmax:
129-
π0 = logits.softmax(dim = dim)
130-
π1 = (one_hot + (logits / temperature).softmax(dim = dim)) / 2
131-
π1 = ((log(π1) - logits).detach() + logits).softmax(dim = 1)
132-
π2 = 2 * π1 - 0.5 * π0
133-
one_hot = π2 - π2.detach() + one_hot
134-
else:
135-
π1 = (logits / temperature).softmax(dim = dim)
136-
one_hot = one_hot + π1 - π1.detach()
122+
π1 = (logits / temperature).softmax(dim = dim)
123+
one_hot = one_hot + π1 - π1.detach()
137124

138125
return ind, one_hot
139126

@@ -828,7 +815,6 @@ def __init__(
828815
sample_codebook_temp = 1.,
829816
straight_through = False,
830817
rotation_trick = True, # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424 by @cfifty
831-
reinmax = False, # using reinmax for improved straight-through, assuming straight through helps at all
832818
sync_codebook = None,
833819
sync_affine_param = False,
834820
ema_update = True,
@@ -895,7 +881,6 @@ def __init__(
895881
gumbel_sample_fn = partial(
896882
gumbel_sample,
897883
stochastic = stochastic_sample_codes,
898-
reinmax = reinmax,
899884
straight_through = straight_through
900885
)
901886

0 commit comments

Comments (0)