@@ -62,8 +62,9 @@ def __init__(
        straight_through_activation = nn.Identity(),
        num_codebooks = 1,
        keep_num_codebooks_dim = None,
-        codebook_scale = 1.,            # for residual LFQ, codebook scaled down by 2x at each layer
-        frac_per_sample_entropy = 1.    # make less than 1. to only use a random fraction of the probs for per sample entropy
+        codebook_scale = 1.,            # for residual LFQ, codebook scaled down by 2x at each layer
+        frac_per_sample_entropy = 1.,   # make less than 1. to only use a random fraction of the probs for per sample entropy
+        use_code_agnostic_commit_loss = False
    ):
        super().__init__()
@@ -110,6 +111,7 @@ def __init__(
        # commitment loss

        self.commitment_loss_weight = commitment_loss_weight
+        self.use_code_agnostic_commit_loss = use_code_agnostic_commit_loss

        # for no auxiliary loss, during inference
@@ -259,8 +261,19 @@ def forward(
        # commit loss

-        if self.training:
-            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
+        if self.training and self.commitment_loss_weight > 0.:
+
+            if self.use_code_agnostic_commit_loss:
+                # credit goes to @MattMcPartlon for sharing this in https://github.com/lucidrains/vector-quantize-pytorch/issues/120#issuecomment-2095089337
+
+                commit_loss = F.mse_loss(
+                    original_input ** 2,
+                    codebook_value ** 2,
+                    reduction = 'none'
+                )
+
+            else:
+                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')

        if exists(mask):
            commit_loss = commit_loss[mask]
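
For context, a minimal usage sketch of the new flag (not part of this commit). It assumes the `LFQ` constructor exposes `dim`, `codebook_size`, `commitment_loss_weight` and the newly added `use_code_agnostic_commit_loss`, as in the existing README example; the specific values are illustrative only. With the flag enabled, the commitment term penalizes the squared input against the squared codebook value, so it no longer depends on which code was selected, and the whole commit loss branch is now skipped when `commitment_loss_weight` is zero.

```python
# hedged usage sketch, assuming the constructor arguments below exist as named
import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    dim = 16,                                 # feature dimension, codebook_size must be 2 ** dim
    codebook_size = 2 ** 16,
    commitment_loss_weight = 0.25,            # commit loss branch is skipped entirely when this is 0.
    use_code_agnostic_commit_loss = True      # new flag introduced by this commit
)

image_feats = torch.randn(1, 16, 32, 32)      # (batch, dim, height, width)

# during training, the returned auxiliary loss includes the commitment term
quantized, indices, aux_loss = quantizer(image_feats)
```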