Commit 1ea2ef6

throw in an option to use a code-agnostic commit loss for LFQ, found empirically to work well by @MattMcPartlon
1 parent 190ac99
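In brief: LFQ quantizes each coordinate to ± codebook_scale, so the standard commitment loss pulls the input toward the exact code it was assigned, while the new code-agnostic variant only matches squared magnitudes and ignores which code was chosen. A minimal sketch of the difference in plain torch (tensor names are illustrative stand-ins for the module's internals, and the codebook value is assumed to equal codebook_scale):

    import torch
    import torch.nn.functional as F

    codebook_scale = 1.
    x = torch.randn(2, 8)                    # stand-in for original_input, pre-quantization
    quantized = torch.where(x > 0, codebook_scale, -codebook_scale)  # LFQ sends each coordinate to +/- codebook_scale

    # standard commitment loss - pulls x toward its assigned code
    standard = F.mse_loss(x, quantized.detach(), reduction = 'none')

    # code-agnostic commitment loss - pulls x ** 2 toward codebook_scale ** 2,
    # i.e. toward the code magnitude, regardless of sign / code identity
    target = torch.full_like(x, codebook_scale)
    code_agnostic = F.mse_loss(x ** 2, target ** 2, reduction = 'none')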

2 files changed: +18 −5 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.12"
+version = "1.14.15"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 17 additions & 4 deletions

@@ -62,8 +62,9 @@ def __init__(
         straight_through_activation = nn.Identity(),
         num_codebooks = 1,
         keep_num_codebooks_dim = None,
-        codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
-        frac_per_sample_entropy = 1. # make less than 1. to only use a random fraction of the probs for per sample entropy
+        codebook_scale = 1.,                    # for residual LFQ, codebook scaled down by 2x at each layer
+        frac_per_sample_entropy = 1.,           # make less than 1. to only use a random fraction of the probs for per sample entropy
+        use_code_agnostic_commit_loss = False
     ):
         super().__init__()

@@ -110,6 +111,7 @@ def __init__(
         # commitment loss

         self.commitment_loss_weight = commitment_loss_weight
+        self.use_code_agnostic_commit_loss = use_code_agnostic_commit_loss

         # for no auxiliary loss, during inference

@@ -259,8 +261,19 @@ def forward(

         # commit loss

-        if self.training:
-            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
+        if self.training and self.commitment_loss_weight > 0.:
+
+            if self.use_code_agnostic_commit_loss:
+                # credit goes to @MattMcPartlon for sharing this in https://github.com/lucidrains/vector-quantize-pytorch/issues/120#issuecomment-2095089337
+
+                commit_loss = F.mse_loss(
+                    original_input ** 2,
+                    codebook_value ** 2,
+                    reduction = 'none'
+                )
+
+            else:
+                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')

         if exists(mask):
             commit_loss = commit_loss[mask]
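To try the new option, it should be enough to pass the flag at construction; a hedged usage sketch (the other constructor arguments follow the package's README-style LFQ example and are illustrative):

    import torch
    from vector_quantize_pytorch import LFQ

    quantizer = LFQ(
        codebook_size = 65536,                   # 2 ** 16 - LFQ wants a power of 2
        dim = 16,                                # matches log2(codebook_size), so no input projection needed
        use_code_agnostic_commit_loss = True     # the option added in this commit
    )

    image_feats = torch.randn(1, 16, 32, 32)

    # during training, the returned aux loss should now fold in the code-agnostic commit term
    quantized, indices, aux_loss = quantizer(image_feats)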
