@@ -47,7 +47,8 @@ def __init__(
47
47
num_codebooks = 1 ,
48
48
keep_num_codebooks_dim : Optional [bool ] = None ,
49
49
scale : Optional [float ] = None ,
50
- allowed_dtypes : Tuple [torch .dtype , ...] = (torch .float32 , torch .float64 )
50
+ allowed_dtypes : Tuple [torch .dtype , ...] = (torch .float32 , torch .float64 ),
51
+ channel_first : bool = False
51
52
):
52
53
super ().__init__ ()
53
54
_levels = torch .tensor (levels , dtype = int32 )
@@ -71,14 +72,17 @@ def __init__(
71
72
72
73
self .dim = default (dim , len (_levels ) * num_codebooks )
73
74
75
+ self .channel_first = channel_first
76
+
74
77
has_projections = self .dim != effective_codebook_dim
75
78
self .project_in = nn .Linear (self .dim , effective_codebook_dim ) if has_projections else nn .Identity ()
76
79
self .project_out = nn .Linear (effective_codebook_dim , self .dim ) if has_projections else nn .Identity ()
80
+
77
81
self .has_projections = has_projections
78
82
79
83
self .codebook_size = self ._levels .prod ().item ()
80
84
81
- implicit_codebook = self .indices_to_codes (torch .arange (self .codebook_size ), project_out = False )
85
+ implicit_codebook = self ._indices_to_codes (torch .arange (self .codebook_size ))
82
86
self .register_buffer ("implicit_codebook" , implicit_codebook , persistent = False )
83
87
84
88
self .allowed_dtypes = allowed_dtypes
@@ -103,33 +107,35 @@ def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
103
107
def _scale_and_shift_inverse (self , zhat : Tensor ) -> Tensor :
104
108
half_width = self ._levels // 2
105
109
return (zhat - half_width ) / half_width
106
-
110
+
111
def _indices_to_codes(self, indices: Tensor):
    """Decode flat codebook indices into per-dimension codes.

    A trailing axis of size 1 is appended to ``indices`` so that it
    broadcasts against ``self._basis`` and ``self._levels``. Each
    per-dimension digit is ``(index // basis) % levels``; the digits are
    then passed through ``_scale_and_shift_inverse``. No output projection
    or axis reordering is applied here.
    """
    expanded = rearrange(indices, '... -> ... 1')
    digits = (expanded // self._basis) % self._levels
    return self._scale_and_shift_inverse(digits)
116
+
107
117
def codes_to_indices(self, zhat: Tensor) -> Tensor:
    """Convert codes into integer indices.

    ``zhat`` must have ``self.codebook_dim`` entries on its last axis
    (checked with a bare ``assert``). The codes are passed through
    ``_scale_and_shift``, weighted element-wise by ``self._basis``, summed
    over the last axis, and cast to ``int32``.
    """
    last_dim = zhat.shape[-1]
    assert last_dim == self.codebook_dim

    shifted = self._scale_and_shift(zhat)
    weighted = shifted * self._basis
    return weighted.sum(dim=-1).to(int32)
112
-
122
+
113
123
def indices_to_codes(
    self,
    indices: Tensor,
    project_out: bool = True
) -> Tensor:
    """Inverse of `codes_to_indices`.

    Args:
        indices: Tensor of codebook indices.
        project_out: When true (the default), the decoded codes are passed
            through ``self.project_out``. Kept as a defaulted keyword so
            callers that still pass ``project_out=False`` to obtain the
            unprojected codes keep working.

    Returns:
        The decoded codes. When the input is treated as image/video, or
        ``self.channel_first`` is set, the last axis is moved to position 1.
    """
    # An extra codebook axis on the indices raises the rank threshold by one.
    is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

    codes = self._indices_to_codes(indices)

    if self.keep_num_codebooks_dim:
        # Merge the codebook axis and the per-codebook code axis.
        codes = rearrange(codes, '... c d -> ... (c d)')

    if project_out:
        codes = self.project_out(codes)

    if is_img_or_video or self.channel_first:
        codes = rearrange(codes, 'b ... d -> b d ...')

    return codes
@@ -146,10 +152,11 @@ def forward(self, z: Tensor) -> Tensor:
146
152
147
153
orig_dtype = z .dtype
148
154
is_img_or_video = z .ndim >= 4
155
+ need_move_channel_last = is_img_or_video or self .channel_first
149
156
150
157
# standardize image or video into (batch, seq, dimension)
151
158
152
- if is_img_or_video :
159
+ if need_move_channel_last :
153
160
z = rearrange (z , 'b d ... -> b ... d' )
154
161
z , ps = pack_one (z , 'b * d' )
155
162
@@ -180,7 +187,7 @@ def forward(self, z: Tensor) -> Tensor:
180
187
181
188
# reconstitute image or video dimensions
182
189
183
- if is_img_or_video :
190
+ if need_move_channel_last :
184
191
out = unpack_one (out , ps , 'b * d' )
185
192
out = rearrange (out , 'b ... d -> b d ...' )
186
193
0 commit comments