@@ -679,6 +679,8 @@ def __init__(
679
679
self .commitment_weight = commitment_weight
680
680
self .commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss
681
681
682
+ self .learnable_codebook = learnable_codebook
683
+
682
684
has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
683
685
self .has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
684
686
self .orthogonal_reg_weight = orthogonal_reg_weight
@@ -806,7 +808,14 @@ def forward(
806
808
quantize , embed_ind , distances = self ._codebook (x , sample_codebook_temp = sample_codebook_temp )
807
809
808
810
if self .training :
809
- orig_quantize = quantize
811
+ # determine code to use for commitment loss
812
+
813
+ if not self .learnable_codebook :
814
+ commit_quantize = quantize .detach ()
815
+ else :
816
+ commit_quantize = quantize
817
+
818
+ # straight through
810
819
811
820
quantize = x + (quantize - x ).detach ()
812
821
@@ -869,14 +878,14 @@ def calculate_ce_loss(codes):
869
878
else :
870
879
if exists (mask ):
871
880
# with variable lengthed sequences
872
- commit_loss = F .mse_loss (orig_quantize , x , reduction = 'none' )
881
+ commit_loss = F .mse_loss (commit_quantize , x , reduction = 'none' )
873
882
874
883
if is_multiheaded :
875
884
mask = repeat (mask , 'b n -> c (b h) n' , c = commit_loss .shape [0 ], h = commit_loss .shape [1 ] // mask .shape [0 ])
876
885
877
886
commit_loss = commit_loss [mask ].mean ()
878
887
else :
879
- commit_loss = F .mse_loss (orig_quantize , x )
888
+ commit_loss = F .mse_loss (commit_quantize , x )
880
889
881
890
loss = loss + commit_loss * self .commitment_weight
882
891
0 commit comments