@@ -826,17 +826,21 @@ def get_codes_from_indices(self, indices):
826
826
is_multiheaded = codebook .ndim > 2
827
827
828
828
if not is_multiheaded :
829
- return codebook [indices ]
829
+ codes = codebook [indices ]
830
+ else :
831
+ indices , ps = pack_one (indices , 'b * h' )
832
+ indices = rearrange (indices , 'b n h -> b h n' )
833
+
834
+ indices = repeat (indices , 'b h n -> b h n d' , d = codebook .shape [- 1 ])
835
+ codebook = repeat (codebook , 'h n d -> b h n d' , b = indices .shape [0 ])
830
836
831
- indices , ps = pack_one (indices , 'b * h' )
832
- indices = rearrange (indices , 'b n h -> b h n' )
837
+ codes = codebook .gather (2 , indices )
838
+ codes = rearrange (codes , 'b h n d -> b n (h d)' )
839
+ codes = unpack_one (codes , ps , 'b * d' )
833
840
834
- indices = repeat ( indices , 'b h n -> b h n d' , d = codebook . shape [ - 1 ])
835
- codebook = repeat ( codebook , 'h n d -> b h n d' , b = indices . shape [ 0 ] )
841
+ if not self . channel_last :
842
+ codes = rearrange ( codes , 'b ... d -> b d ...' )
836
843
837
- codes = codebook .gather (2 , indices )
838
- codes = rearrange (codes , 'b h n d -> b n (h d)' )
839
- codes = unpack_one (codes , ps , 'b * d' )
840
844
return codes
841
845
842
846
def get_output_from_indices (self , indices ):
0 commit comments