@@ -16,6 +16,9 @@ def default(val, d):
 def noop(*args, **kwargs):
     pass

+def identity(t):
+    return t
+
 def l2norm(t):
     return F.normalize(t, p = 2, dim = -1)

@@ -200,6 +203,8 @@ def __init__(
         sample_codebook_temp = 0
     ):
         super().__init__()
+        self.transform_input = identity
+
         self.decay = decay
         init_fn = uniform_init if not kmeans_init else torch.zeros
         embed = init_fn(num_codebooks, codebook_size, dim)
@@ -294,6 +299,8 @@ def forward(self, x):
         embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
         embed_ind = unpack_one(embed_ind, ps, 'h *')

+        quantize = batched_embedding(embed_ind, self.embed)
+
         if self.training:
             cluster_size = embed_onehot.sum(dim = 1)
             self.all_reduce_fn(cluster_size)
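Note that `quantize` is now computed right after the indices are unpacked, before the EMA branch. `batched_embedding` is the repo's per-head codebook lookup; the body below is only an assumed sketch of the shape contract it satisfies (the helper name comes from this diff, the implementation here does not):

import torch
from einops import repeat

def batched_embedding_sketch(indices, embeds):
    # indices: (heads, batch, n) integer codes
    # embeds:  (heads, codebook_size, dim) codebooks
    d, b = embeds.shape[-1], indices.shape[1]
    indices = repeat(indices, 'h b n -> h b n d', d = d)
    embeds = repeat(embeds, 'h c d -> h b c d', b = b)
    return embeds.gather(2, indices)  # one code vector per index: (heads, batch, n, dim)

codes = torch.randint(0, 512, (1, 2, 1024))
codebook = torch.randn(1, 512, 64)
assert batched_embedding_sketch(codes, codebook).shape == (1, 2, 1024, 64)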
@@ -310,11 +317,6 @@ def forward(self, x):
             self.embed.data.copy_(embed_normalized)
             self.expire_codes_(x)

-        quantize = batched_embedding(embed_ind, self.embed)
-
-        if self.training:
-            quantize = x + (quantize - x).detach()
-
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))

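The straight-through step deleted here is not gone: it moves up into the wrapper's forward (last hunk below), so it is applied once for both codebook types. A minimal demonstration of the two properties of `x + (quantize - x).detach()`:

import torch

x = torch.randn(4, 8, requires_grad = True)
q = torch.randn(4, 8)          # stand-in for the quantized output

st = x + (q - x).detach()      # forward value is exactly q
assert torch.allclose(st, q)

st.sum().backward()            # but the gradient passes straight through to x
assert torch.allclose(x.grad, torch.ones_like(x))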
@@ -340,6 +342,8 @@ def __init__(
         sample_codebook_temp = 0.
     ):
         super().__init__()
+        self.transform_input = l2norm
+
         self.decay = decay

         if not kmeans_init:
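Both codebook classes now expose `transform_input` (identity for the Euclidean codebook, `l2norm` for the cosine-similarity one), so the wrapper can preprocess its input without knowing which codebook it holds. A stripped-down sketch of the pattern, with placeholder class names:

import torch
import torch.nn.functional as F

def identity(t):
    return t

def l2norm(t):
    return F.normalize(t, p = 2, dim = -1)

class EuclideanSketch:
    def __init__(self):
        self.transform_input = identity

class CosineSketch:
    def __init__(self):
        self.transform_input = l2norm

# the wrapper applies whichever transform its codebook installed:
for codebook in (EuclideanSketch(), CosineSketch()):
    x = codebook.transform_input(torch.randn(2, 8))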
@@ -427,18 +431,18 @@ def forward(self, x):
         dtype = x.dtype

         flatten, ps = pack_one(x, 'h * d')
-        flatten = l2norm(flatten)

         self.init_embed_(flatten)

         embed = self.embed if not self.learnable_codebook else self.embed.detach()
-        embed = l2norm(embed)

         dist = einsum('h n d, h c d -> h n c', flatten, embed)
         embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
         embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
         embed_ind = unpack_one(embed_ind, ps, 'h *')

+        quantize = batched_embedding(embed_ind, self.embed)
+
         if self.training:
             bins = embed_onehot.sum(dim = 1)
             self.all_reduce_fn(bins)
@@ -457,12 +461,6 @@ def forward(self, x):
             self.embed.data.copy_(l2norm(embed_normalized))
             self.expire_codes_(x)

-        quantize = batched_embedding(embed_ind, self.embed)
-
-        if self.training:
-            l2norm_x = l2norm(x)
-            quantize = l2norm_x + (quantize - l2norm_x).detach()
-
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))

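With the input already l2-normalized upstream by `transform_input`, the explicit `l2norm(flatten)` call is redundant and dropped; the codebook itself stays unit-norm because the EMA update writes back `l2norm(embed_normalized)`. For unit vectors the `einsum` inner product is exactly cosine similarity, as a quick check:

import torch
import torch.nn.functional as F
from torch import einsum

a = F.normalize(torch.randn(1, 10, 4), dim = -1)   # inputs, already unit norm
b = F.normalize(torch.randn(1, 7, 4), dim = -1)    # codebook, kept unit norm

dist = einsum('h n d, h c d -> h n c', a, b)

# for unit vectors, the inner product equals the cosine similarity
ref = F.cosine_similarity(a.unsqueeze(2), b.unsqueeze(1), dim = -1)
assert torch.allclose(dist, ref, atol = 1e-6)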
@@ -489,6 +487,7 @@ def __init__(
         channel_last = True,
         accept_image_fmap = False,
         commitment_weight = 1.,
+        commitment_use_cross_entropy_loss = False,
         orthogonal_reg_weight = 0.,
         orthogonal_reg_active_codes_only = False,
         orthogonal_reg_max_codes = None,
@@ -509,6 +508,7 @@ def __init__(

         self.eps = eps
         self.commitment_weight = commitment_weight
+        self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss

         has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
         self.orthogonal_reg_weight = orthogonal_reg_weight
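For reference, a hypothetical construction with the new flag enabled. This assumes the surrounding class is the library's `VectorQuantize` module with its usual `(quantize, indices, loss)` return signature; the import path and shapes below are assumptions, not confirmed by this diff:

import torch
from vector_quantize_pytorch import VectorQuantize  # assumed import path

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    commitment_weight = 1.,
    commitment_use_cross_entropy_loss = True  # CE against the codebook instead of MSE
)

x = torch.randn(1, 1024, 256)
quantize, indices, loss = vq(x)  # in training mode, loss includes the CE commitment term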
@@ -588,39 +588,58 @@ def forward(

         x = self.project_in(x)

+        x = self._codebook.transform_input(x)
+
         if is_multiheaded:
             ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
             x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads)

         quantize, embed_ind, distances = self._codebook(x)

-        if return_loss:
+        if self.training:
+            quantize = x + (quantize - x).detach()
+
+        # function for calculating cross entropy loss to distance matrix
+        # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
+
+        def calculate_ce_loss(codes):
             if not is_multiheaded:
                 dist_einops_eq = '1 b n l -> b l n'
             elif self.separate_codebook_per_head:
                 dist_einops_eq = 'c b n l -> b l n c'
             else:
                 dist_einops_eq = '1 (b h) n l -> b l n h'

-            distances = rearrange(distances, dist_einops_eq, b = shape[0])
-            return quantize, F.cross_entropy(distances, indices, ignore_index = -1)
+            ce_loss = F.cross_entropy(
+                rearrange(distances, dist_einops_eq, b = shape[0]),
+                codes,
+                ignore_index = -1
+            )
+
+            return ce_loss
+
+        if return_loss:
+            return quantize, calculate_ce_loss(indices)

         loss = torch.tensor([0.], device = device, requires_grad = self.training)

         if self.training:
             if self.commitment_weight > 0:
-                detached_quantize = quantize.detach()
+                if self.commitment_use_cross_entropy_loss:
+                    commit_loss = calculate_ce_loss(embed_ind)
+                else:
+                    detached_quantize = quantize.detach()

-                if exists(mask):
-                    # with variable lengthed sequences
-                    commit_loss = F.mse_loss(detached_quantize, x, reduction = 'none')
+                    if exists(mask):
+                        # with variable lengthed sequences
+                        commit_loss = F.mse_loss(detached_quantize, x, reduction = 'none')

-                    if is_multiheaded:
-                        mask = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
+                        if is_multiheaded:
+                            mask = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])

-                    commit_loss = commit_loss[mask].mean()
-                else:
-                    commit_loss = F.mse_loss(detached_quantize, x)
+                        commit_loss = commit_loss[mask].mean()
+                    else:
+                        commit_loss = F.mse_loss(detached_quantize, x)

                 loss = loss + commit_loss * self.commitment_weight

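The cross-entropy commitment loss treats each row of the rearranged distance matrix as logits over codebook entries and the assigned indices as targets, pushing each input's similarity distribution toward the code it was actually given. A standalone sketch of the same computation, with made-up shapes:

import torch
import torch.nn.functional as F

b, n, c = 2, 16, 512
distances = torch.randn(b, n, c)    # similarity of each input vector to each code
codes = distances.argmax(dim = -1)  # (b, n) indices actually selected

# F.cross_entropy wants the class dim second, hence the 'b n l -> b l n' style rearrange above
ce = F.cross_entropy(distances.transpose(1, 2), codes, ignore_index = -1)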