Skip to content

Commit ce4a4fc

Browse files
committed
commit loss weighting for sim vq
1 parent 919a1b8 commit ce4a4fc

File tree

2 files changed

+7
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.20.7"
3+
version = "1.20.8"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/sim_vq.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
accept_image_fmap = False,
4545
rotation_trick = True, # works even better with rotation trick turned on, with no straight through and the commit loss from input to quantize
4646
input_to_quantize_commit_loss_weight = 0.25,
47+
commitment_weight = 1.,
4748
frozen_codebook_dim = None # frozen codebook dim could have different dimensions than projection
4849
):
4950
super().__init__()
@@ -74,6 +75,10 @@ def __init__(
7475

7576
self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
7677

78+
# total commitment loss weight
79+
80+
self.commitment_weight = commitment_weight
81+
7782
@property
7883
def codebook(self):
7984
return self.code_transform(self.frozen_codebook)
@@ -132,7 +137,7 @@ def forward(
132137

133138
indices = inverse_pack(indices, 'b *')
134139

135-
return quantized, indices, commit_loss
140+
return quantized, indices, commit_loss * self.commitment_weight
136141

137142
# main
138143

0 commit comments

Comments (0)