@@ -559,9 +559,18 @@ def forward(
559
559
560
560
else :
561
561
if exists (codebook_transform_fn ):
562
- quantize = einx .get_at ('h b n [c] d, h b n -> h b n d' , transformed_embed , embed_ind )
562
+ # quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
563
+
564
+ repeated_embed_ind = repeat (embed_ind , 'h b n -> h b n 1 d' , d = transformed_embed .shape [- 1 ])
565
+ quantize = transformed_embed .gather (- 2 , repeated_embed_ind )
566
+ quantize = rearrange (quantize , 'h b n 1 d -> h b n d' )
567
+
563
568
else :
564
- quantize = einx .get_at ('h [c] d, h b n -> h b n d' , embed , embed_ind )
569
+ # quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)
570
+
571
+ repeated_embed = repeat (embed , 'h c d -> h b c d' , b = embed_ind .shape [1 ])
572
+ repeated_embed_ind = repeat (embed_ind , 'h b n -> h b n d' , d = embed .shape [- 1 ])
573
+ quantize = repeated_embed .gather (- 2 , repeated_embed_ind )
565
574
566
575
if self .training and self .ema_update and not freeze_codebook :
567
576
@@ -767,9 +776,18 @@ def forward(
767
776
768
777
else :
769
778
if exists (codebook_transform_fn ):
770
- quantize = einx .get_at ('h b n [c] d, h b n -> h b n d' , transformed_embed , embed_ind )
779
+ # quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
780
+
781
+ repeated_embed_ind = repeat (embed_ind , 'h b n -> h b n 1 d' , d = transformed_embed .shape [- 1 ])
782
+ quantize = transformed_embed .gather (- 2 , repeated_embed_ind )
783
+ quantize = rearrange (quantize , 'h b n 1 d -> h b n d' )
784
+
771
785
else :
772
- quantize = einx .get_at ('h [c] d, h b n -> h b n d' , embed , embed_ind )
786
+ # quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)
787
+
788
+ repeated_embed = repeat (embed , 'h c d -> h b c d' , b = embed_ind .shape [1 ])
789
+ repeated_embed_ind = repeat (embed_ind , 'h b n -> h b n d' , d = embed .shape [- 1 ])
790
+ quantize = repeated_embed .gather (- 2 , repeated_embed_ind )
773
791
774
792
if self .training and self .ema_update and not freeze_codebook :
775
793
if exists (mask ):
0 commit comments