 
 import torch
 from torch import nn, Tensor
+import torch.nn.functional as F
 from torch.nn import Module
 
 from einops import rearrange, reduce, pack, unpack
@@ -53,8 +54,9 @@ def __init__(
         dim = None,
         codebook_size = None,
         entropy_loss_weight = 0.1,
+        commitment_loss_weight = 1.,
         diversity_gamma = 2.5,
-        straight_through_activation = nn.Tanh(),
+        straight_through_activation = nn.Identity(),
         num_codebooks = 1,
         keep_num_codebooks_dim = None
     ):
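The signature change above adds `commitment_loss_weight` (default 1.) and swaps the default straight-through activation from `nn.Tanh()` to `nn.Identity()`. A minimal usage sketch, assuming the class is the `LFQ` quantizer exposed by this package; the import path, shapes, and argument values here are illustrative, not prescribed by the diff:

import torch
from vector_quantize_pytorch import LFQ    # assumed import path and class name

quantizer = LFQ(
    dim = 256,                       # input feature dimension
    codebook_size = 2 ** 8,          # implicit codebook of 256 sign patterns
    entropy_loss_weight = 0.1,
    commitment_loss_weight = 1.,     # argument introduced by this commit
    diversity_gamma = 2.5
)

x = torch.randn(1, 1024, 256)        # (batch, sequence, dim)
quantized, indices, aux_loss = quantizer(x)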
@@ -91,6 +93,10 @@ def __init__(
         self.diversity_gamma = diversity_gamma
         self.entropy_loss_weight = entropy_loss_weight
 
+        # commitment loss
+
+        self.commitment_loss_weight = commitment_loss_weight
+
         # for no auxiliary loss, during inference
 
         self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
@@ -157,6 +163,8 @@ def forward(
 
         # quantize by eq 3.
 
+        original_input = x
+
         ones = torch.ones_like(x)
         quantized = torch.where(x > 0, ones, -ones)
 
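`original_input` is stashed before binarization because the commitment loss added later must compare the pre-quantization values against the detached codes. A self-contained sketch of the eq. 3 sign quantization together with the usual straight-through gradient trick; the surrounding code is outside this hunk, so treat it as an approximation:

import torch

x = torch.randn(2, 4, requires_grad = True)
original_input = x                           # kept for the commitment loss

# eq 3: quantize each dimension to +1 / -1 by its sign
ones = torch.ones_like(x)
quantized = torch.where(x > 0, ones, -ones)

# straight-through estimator (assumed, not shown in this hunk):
# forward pass uses the hard codes, backward pass acts as the identity
x_st = x + (quantized - x).detach()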
@@ -190,7 +198,12 @@ def forward(
             # if not training, just return dummy 0
             entropy_aux_loss = self.zero
 
-        entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight
+        # commit loss
+
+        if self.training:
+            commit_loss = F.mse_loss(original_input, quantized.detach())
+        else:
+            commit_loss = self.zero
 
         # merge back codebook dim
 
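The commitment term mirrors the VQ-VAE commitment loss: `quantized.detach()` blocks gradients at the codes, so the penalty only pulls the pre-quantization activations toward the {-1, +1} targets, and it is skipped outside of training. A standalone sketch of what `F.mse_loss(original_input, quantized.detach())` computes:

import torch
import torch.nn.functional as F

original_input = torch.randn(2, 4, requires_grad = True)
ones = torch.ones_like(original_input)
quantized = torch.where(original_input > 0, ones, -ones)

# codes are detached, so only `original_input` receives a gradient
commit_loss = F.mse_loss(original_input, quantized.detach())
commit_loss.backward()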
@@ -213,4 +226,8 @@ def forward(
         if not self.keep_num_codebooks_dim:
             indices = rearrange(indices, '... 1 -> ...')
 
+        # complete aux loss
+
+        aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
+
         return Return(x, indices, entropy_aux_loss)
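With this commit the entropy and commitment penalties are combined into `aux_loss`, each scaled by its own weight. A hedged sketch of folding the quantizer's auxiliary loss into a training step; the encoder, decoder, and shapes below are placeholders rather than code from this repository:

import torch
from torch import nn
from vector_quantize_pytorch import LFQ    # assumed import path and class name

encoder = nn.Linear(32, 256)                # placeholder encoder
decoder = nn.Linear(256, 32)                # placeholder decoder
quantizer = LFQ(dim = 256, codebook_size = 2 ** 8)

data = torch.randn(8, 16, 32)
quantized, indices, aux_loss = quantizer(encoder(data))

# reconstruction objective plus the returned auxiliary loss
loss = nn.functional.mse_loss(decoder(quantized), data) + aux_loss
loss.backward()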