@@ -87,44 +87,46 @@ def __init__(
87
87
88
88
self .allowed_dtypes = allowed_dtypes
89
89
90
- def bound (self , z : Tensor , eps : float = 1e-3 ) -> Tensor :
91
- """Bound `z`, an array of shape (..., d)."""
90
+ def bound (self , z , eps : float = 1e-3 ):
91
+ """ Bound `z`, an array of shape (..., d). """
92
92
half_l = (self ._levels - 1 ) * (1 + eps ) / 2
93
93
offset = torch .where (self ._levels % 2 == 0 , 0.5 , 0.0 )
94
94
shift = (offset / half_l ).atanh ()
95
95
return (z + shift ).tanh () * half_l - offset
96
96
97
- def quantize (self , z : Tensor ) -> Tensor :
98
- """Quantizes z, returns quantized zhat, same shape as z."""
97
+ def quantize (self , z ) :
98
+ """ Quantizes z, returns quantized zhat, same shape as z. """
99
99
quantized = round_ste (self .bound (z ))
100
100
half_width = self ._levels // 2 # Renormalize to [-1, 1].
101
101
return quantized / half_width
102
102
103
- def _scale_and_shift (self , zhat_normalized : Tensor ) -> Tensor :
103
+ def _scale_and_shift (self , zhat_normalized ) :
104
104
half_width = self ._levels // 2
105
105
return (zhat_normalized * half_width ) + half_width
106
106
107
- def _scale_and_shift_inverse (self , zhat : Tensor ) -> Tensor :
107
+ def _scale_and_shift_inverse (self , zhat ) :
108
108
half_width = self ._levels // 2
109
109
return (zhat - half_width ) / half_width
110
110
111
- def _indices_to_codes (self , indices : Tensor ):
112
- indices = rearrange (indices , '... -> ... 1' )
113
- codes_non_centered = (indices // self ._basis ) % self ._levels
114
- codes = self ._scale_and_shift_inverse (codes_non_centered )
111
+ def _indices_to_codes (self , indices ):
112
+ level_indices = self .indices_to_level_indices (indices )
113
+ codes = self ._scale_and_shift_inverse (level_indices )
115
114
return codes
116
115
117
- def codes_to_indices (self , zhat : Tensor ) -> Tensor :
118
- """Converts a `code` to an index in the codebook."""
116
+ def codes_to_indices (self , zhat ) :
117
+ """ Converts a `code` to an index in the codebook. """
119
118
assert zhat .shape [- 1 ] == self .codebook_dim
120
119
zhat = self ._scale_and_shift (zhat )
121
120
return (zhat * self ._basis ).sum (dim = - 1 ).to (int32 )
122
121
123
- def indices_to_codes (
124
- self ,
125
- indices : Tensor
126
- ) -> Tensor :
127
- """Inverse of `codes_to_indices`."""
122
+ def indices_to_level_indices (self , indices ):
123
+ """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
124
+ indices = rearrange (indices , '... -> ... 1' )
125
+ codes_non_centered = (indices // self ._basis ) % self ._levels
126
+ return codes_non_centered
127
+
128
+ def indices_to_codes (self , indices ):
129
+ """ Inverse of `codes_to_indices`. """
128
130
129
131
is_img_or_video = indices .ndim >= (3 + int (self .keep_num_codebooks_dim ))
130
132
@@ -141,7 +143,7 @@ def indices_to_codes(
141
143
return codes
142
144
143
145
@autocast (enabled = False )
144
- def forward (self , z : Tensor ) -> Tensor :
146
+ def forward (self , z ) :
145
147
"""
146
148
einstein notation
147
149
b - batch
0 commit comments