Commit 7c9f1dc

fix a bug with straight-through, thanks to @goldbird5
1 parent e682e59 commit 7c9f1dc

File tree

2 files changed: +16 additions, −5 deletions

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.5.14',
+  version = '1.5.15',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 15 additions & 4 deletions
@@ -61,7 +61,7 @@ def gumbel_sample(
     assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'
 
     if not straight_through:
-        return ind, one_hot
+        return ind, one_hot, None
 
     # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
     # algorithm 2
@@ -78,7 +78,9 @@ def gumbel_sample(
         π1 = (logits / temperature).softmax(dim = dim)
         one_hot = one_hot + π1 - π1.detach()
 
-    return ind, one_hot
+    st_mult = one_hot.gather(-1, rearrange(ind, '... -> ... 1')) # multiplier for straight-through
+
+    return ind, one_hot, st_mult
 
 def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
     denom = x.sum(dim = dim, keepdim = True)
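
For context: when straight-through is active, gumbel_sample now also returns the value of the (differentiable) one-hot gathered at the sampled index, which callers can multiply onto the looked-up code. Below is a minimal, self-contained sketch of that idea; the function name straight_through_sample and the shapes are assumptions for illustration, not the library's exact code.

# Illustrative sketch only -- not the library's exact implementation.
import torch
import torch.nn.functional as F

def straight_through_sample(logits, temperature = 1.):
    # gumbel noise: -log(E) with E ~ Exponential(1) is Gumbel(0, 1) distributed
    gumbel = -torch.empty_like(logits).exponential_().log()
    ind = ((logits / temperature) + gumbel).argmax(dim = -1)

    one_hot = F.one_hot(ind, logits.shape[-1]).type(logits.dtype)
    soft = (logits / temperature).softmax(dim = -1)

    # straight-through estimator: hard one-hot on the forward pass,
    # softmax gradient on the backward pass
    one_hot = one_hot + soft - soft.detach()

    # entry of the differentiable one-hot at the sampled index:
    # forward value is exactly 1, backward carries the gradient of soft[ind] w.r.t. logits
    st_mult = one_hot.gather(-1, ind.unsqueeze(-1))
    return ind, one_hot, st_mult
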
@@ -333,11 +335,16 @@ def forward(self, x):
 
         dist = -torch.cdist(flatten, embed, p = 2)
 
-        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)
+        embed_ind, embed_onehot, straight_through_mult = self.gumbel_sample(dist, dim = -1)
+
         embed_ind = unpack_one(embed_ind, ps, 'h *')
 
         quantize = batched_embedding(embed_ind, self.embed)
 
+        if exists(straight_through_mult):
+            mult = unpack_one(straight_through_mult, ps, 'h * d')
+            quantize = quantize * mult
+
         if self.training:
             cluster_size = embed_onehot.sum(dim = 1)
 
@@ -476,11 +483,15 @@ def forward(self, x):
         embed = self.embed if not self.learnable_codebook else self.embed.detach()
 
         dist = einsum('h n d, h c d -> h n c', flatten, embed)
-        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)
+        embed_ind, embed_onehot, straight_through_mult = self.gumbel_sample(dist, dim = -1)
         embed_ind = unpack_one(embed_ind, ps, 'h *')
 
         quantize = batched_embedding(embed_ind, self.embed)
 
+        if exists(straight_through_mult):
+            mult = unpack_one(straight_through_mult, ps, 'h * d')
+            quantize = quantize * mult
+
         if self.training:
             bins = embed_onehot.sum(dim = 1)
             self.all_reduce_fn(bins)
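
A toy check (hypothetical shapes, reusing the straight_through_sample sketch above) of why multiplying the looked-up codes by the new multiplier matters: a plain index lookup into the codebook has no gradient path back to the logits, while the multiplication restores one without changing the forward value.

logits = torch.randn(2, 8, requires_grad = True)    # e.g. (negative) distances to 8 codes
codebook = torch.randn(8, 4)                        # 8 codes of dimension 4

ind, one_hot, st_mult = straight_through_sample(logits)

quantize = codebook[ind]            # index lookup alone: detached from the logits
quantize = quantize * st_mult       # forward value unchanged (st_mult == 1), gradient path restored

quantize.sum().backward()
assert logits.grad is not None      # gradient now reaches the logits
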
