
Commit 5aab91c

address first two points #78
1 parent 8d50c02 commit 5aab91c

2 files changed: +20 -3 lines changed


setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.9.7',
+  version = '1.9.9',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 19 additions & 2 deletions
@@ -11,6 +11,7 @@
 
 import torch
 from torch import nn, Tensor
+import torch.nn.functional as F
 from torch.nn import Module
 
 from einops import rearrange, reduce, pack, unpack
@@ -53,8 +54,9 @@ def __init__(
         dim = None,
         codebook_size = None,
         entropy_loss_weight = 0.1,
+        commitment_loss_weight = 1.,
         diversity_gamma = 2.5,
-        straight_through_activation = nn.Tanh(),
+        straight_through_activation = nn.Identity(),
         num_codebooks = 1,
         keep_num_codebooks_dim = None
     ):
@@ -91,6 +93,10 @@ def __init__(
         self.diversity_gamma = diversity_gamma
         self.entropy_loss_weight = entropy_loss_weight
 
+        # commitment loss
+
+        self.commitment_loss_weight = commitment_loss_weight
+
         # for no auxiliary loss, during inference
 
         self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
@@ -157,6 +163,8 @@ def forward(
 
         # quantize by eq 3.
 
+        original_input = x
+
         ones = torch.ones_like(x)
         quantized = torch.where(x > 0, ones, -ones)
 
@@ -190,7 +198,12 @@ def forward(
             # if not training, just return dummy 0
             entropy_aux_loss = self.zero
 
-        entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight
+        # commit loss
+
+        if self.training:
+            commit_loss = F.mse_loss(original_input, quantized.detach())
+        else:
+            commit_loss = self.zero
 
         # merge back codebook dim
 
@@ -213,4 +226,8 @@ def forward(
         if not self.keep_num_codebooks_dim:
             indices = rearrange(indices, '... 1 -> ...')
 
+        # complete aux loss
+
+        aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
+
         return Return(x, indices, entropy_aux_loss)
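
Taken together, the lookup_free_quantization.py changes add a commitment loss to LFQ: the pre-quantization activations are cached as original_input, an MSE against the detached quantized values is taken during training, and the term is weighted by the new commitment_loss_weight argument (the straight-through activation default also moves from nn.Tanh() to nn.Identity()). Below is a minimal usage sketch, assuming LFQ is exported at the package level as in the repository README; the chosen dim, codebook_size, and input shape are illustrative only and not taken from this commit.

import torch
from vector_quantize_pytorch import LFQ   # assumes the package-level export shown in the README

# illustrative hyperparameters: dim equals log2(codebook_size), so no input projection is needed
quantizer = LFQ(
    dim = 16,
    codebook_size = 2 ** 16,
    entropy_loss_weight = 0.1,        # weight on the entropy auxiliary loss
    commitment_loss_weight = 1.       # introduced in this commit: weight on the commitment (MSE) term
)

x = torch.randn(1, 1024, 16)          # (batch, sequence, feature) activations to be binarized; shape assumed

quantized, indices, aux_loss = quantizer(x)   # Return namedtuple: quantized output, codes, auxiliary loss

As the diff stands, the combined aux_loss (entropy term plus commitment term) is computed just before the return, while the returned third value is still entropy_aux_loss.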
