
Commit dd660ad

add ability to get straight through gradients, as well as use reinmax
1 parent 8dc0b71 commit dd660ad

File tree: 3 files changed, +76 / -17 lines

- README.md
- setup.py
- vector_quantize_pytorch/vector_quantize_pytorch.py

README.md

Lines changed: 12 additions & 2 deletions
````diff
@@ -70,8 +70,9 @@ residual_vq = ResidualVQ(
     dim = 256,
     num_quantizers = 8,
     codebook_size = 1024,
-    sample_codebook_temp = 0.1,         # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic
-    shared_codebook = True              # whether to share the codebooks for all quantizers or not
+    stochastic_sample_codes = True,
+    sample_codebook_temp = 0.1,         # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic
+    shared_codebook = True              # whether to share the codebooks for all quantizers or not
 )

 x = torch.randn(1, 1024, 256)
@@ -406,3 +407,12 @@ if __name__ == '__main__':
 }
 ```

+```bibtex
+@article{Liu2023BridgingDA,
+    title   = {Bridging Discrete and Backpropagation: Straight-Through and Beyond},
+    author  = {Liyuan Liu and Chengyu Dong and Xiaodong Liu and Bin Yu and Jianfeng Gao},
+    journal = {ArXiv},
+    year    = {2023},
+    volume  = {abs/2304.08612}
+}
+```
````

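The new sampling flags also surface on `VectorQuantize` itself (see the changes to `vector_quantize_pytorch/vector_quantize_pytorch.py` below). A minimal usage sketch, assuming the usual `VectorQuantize` constructor and `(quantized, indices, commit_loss)` return documented elsewhere in the README; the `dim` / `codebook_size` values here are arbitrary:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 1024,
    stochastic_sample_codes = True,    # add gumbel noise to the code distances before the argmax
    sample_codebook_temp = 0.1,
    straight_through = True,           # pass gradients straight through the hard code assignment
    reinmax = True                     # second-order straight-through correction (Liu et al. 2023)
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)   # quantized output, code indices, commitment loss
```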
setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.5.8',
+  version = '1.5.10',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',
```

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 63 additions & 14 deletions
```diff
@@ -1,3 +1,5 @@
+from functools import partial
+
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
@@ -40,11 +42,43 @@ def gumbel_noise(t):
     noise = torch.zeros_like(t).uniform_(0, 1)
     return -log(-log(noise))

-def gumbel_sample(t, temperature = 1., dim = -1):
-    if temperature == 0:
-        return t.argmax(dim = dim)
+def gumbel_sample(
+    logits,
+    temperature = 1.,
+    stochastic = False,
+    straight_through = False,
+    reinmax = False,
+    dim = -1
+):
+    dtype, size = logits.dtype, logits.shape[dim]
+
+    if stochastic:
+        logits = logits + gumbel_noise(logits)
+
+    ind = logits.argmax(dim = dim)
+    one_hot = F.one_hot(ind, size).type(dtype)
+
+    assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'
+
+    if not straight_through:
+        return ind, one_hot
+
+    # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
+    # algorithm 2

-    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
+    temperature = max(temperature, 1e-2)
+
+    if reinmax:
+        π0 = logits.softmax(dim = dim)
+        π1 = (one_hot + (logits / temperature).softmax(dim = dim)) / 2
+        π1 = ((π1.log() - logits).detach() + logits).softmax(dim = 1)
+        π2 = 2 * π1 - 0.5 * π0
+        one_hot = π2 - π2.detach() + one_hot
+    else:
+        π1 = (logits / temperature).softmax(dim = dim)
+        one_hot = one_hot + π1 - π1.detach()
+
+    return ind, one_hot

 def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
     denom = x.sum(dim = dim, keepdim = True)
```
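A note on the trick in the new `gumbel_sample`: the substitution `one_hot + π1 - π1.detach()` keeps the forward value equal to the hard one-hot while letting gradients flow through the tempered softmax `π1`; with `reinmax = True` the softmax is first corrected to second order (`π2 = 2 * π1 - 0.5 * π0`, Algorithm 2 of the cited paper) before the same substitution. A small self-contained sketch of the basic straight-through substitution, not library code, with arbitrary shapes:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad = True)
temperature = 0.1

ind = logits.argmax(dim = -1)
one_hot = F.one_hot(ind, 10).float()              # hard assignment, no gradient on its own

soft = (logits / temperature).softmax(dim = -1)   # relaxed assignment, differentiable
st = one_hot + soft - soft.detach()               # value stays (numerically) the hard one-hot

assert torch.allclose(st.detach(), one_hot, atol = 1e-6)

(st * torch.arange(10.)).sum().backward()         # any downstream loss
assert logits.grad is not None                    # gradients reached the logits via `soft`
```

The remaining hunks wire this through: the codebooks accept a pre-configured `gumbel_sample` callable and now get both the indices and the (possibly straight-through) one-hot from a single call.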
```diff
@@ -200,7 +234,9 @@ def __init__(
         reset_cluster_size = None,
         use_ddp = False,
         learnable_codebook = False,
-        sample_codebook_temp = 0
+        sample_codebook_temp = 0,
+        straight_through = False,
+        gumbel_sample = gumbel_sample
     ):
         super().__init__()
         self.transform_input = identity
@@ -216,7 +252,9 @@ def __init__(
         self.eps = eps
         self.threshold_ema_dead_code = threshold_ema_dead_code
         self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
-        self.sample_codebook_temp = sample_codebook_temp
+
+        assert callable(gumbel_sample)
+        self.gumbel_sample = gumbel_sample

         assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'

@@ -295,8 +333,7 @@ def forward(self, x):

         dist = -torch.cdist(flatten, embed, p = 2)

-        embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
-        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)
         embed_ind = unpack_one(embed_ind, ps, 'h *')

         quantize = batched_embedding(embed_ind, self.embed)
@@ -339,7 +376,7 @@ def __init__(
         reset_cluster_size = None,
         use_ddp = False,
         learnable_codebook = False,
-        sample_codebook_temp = 0.
+        gumbel_sample = gumbel_sample
     ):
         super().__init__()
         self.transform_input = l2norm
@@ -358,7 +395,9 @@ def __init__(
         self.eps = eps
         self.threshold_ema_dead_code = threshold_ema_dead_code
         self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
-        self.sample_codebook_temp = sample_codebook_temp
+
+        assert callable(gumbel_sample)
+        self.gumbel_sample = gumbel_sample

         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
@@ -437,8 +476,7 @@ def forward(self, x):
         embed = self.embed if not self.learnable_codebook else self.embed.detach()

         dist = einsum('h n d, h c d -> h n c', flatten, embed)
-        embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
-        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)
         embed_ind = unpack_one(embed_ind, ps, 'h *')

         quantize = batched_embedding(embed_ind, self.embed)
@@ -491,8 +529,11 @@ def __init__(
         orthogonal_reg_weight = 0.,
         orthogonal_reg_active_codes_only = False,
         orthogonal_reg_max_codes = None,
+        stochastic_sample_codes = False,
         sample_codebook_temp = 0.,
-        sync_codebook = False
+        straight_through = False,
+        reinmax = False,    # using reinmax for improved straight-through, assuming straight through helps at all
+        sync_codebook = False,
     ):
         super().__init__()
         self.dim = dim
@@ -517,6 +558,14 @@ def __init__(

         codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook

+        gumbel_sample_fn = partial(
+            gumbel_sample,
+            stochastic = stochastic_sample_codes,
+            temperature = sample_codebook_temp,
+            reinmax = reinmax,
+            straight_through = straight_through
+        )
+
         self._codebook = codebook_class(
             dim = codebook_dim,
             num_codebooks = heads if separate_codebook_per_head else 1,
@@ -529,7 +578,7 @@ def __init__(
             threshold_ema_dead_code = threshold_ema_dead_code,
             use_ddp = sync_codebook,
             learnable_codebook = has_codebook_orthogonal_loss,
-            sample_codebook_temp = sample_codebook_temp
+            gumbel_sample = gumbel_sample_fn
         )

         self.codebook_size = codebook_size
```
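Rather than threading each flag through the codebooks, `VectorQuantize` binds the sampling options once with `functools.partial` and hands the configured callable down, so the codebooks only ever call `self.gumbel_sample(dist, dim = -1)`. A toy sketch of that pattern with illustrative, non-library names:

```python
from functools import partial

import torch

def sample(logits, temperature = 1., stochastic = False, dim = -1):
    # stand-in for the library's gumbel_sample: optionally perturb, then take the argmax
    if stochastic:
        logits = logits / max(temperature, 1e-2) + torch.rand_like(logits)
    return logits.argmax(dim = dim)

# the wrapper fixes the configuration once...
sample_fn = partial(sample, temperature = 0.1, stochastic = True)

# ...and downstream code only supplies the logits, as the codebooks do with `dist`
indices = sample_fn(torch.randn(2, 8), dim = -1)
```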
