Skip to content

Commit 9b906d6

Browse files
committed
address #116
1 parent 4ae176d commit 9b906d6

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.14.5',
6+
version = '1.14.6',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,6 @@ def forward(self, z: Tensor) -> Tensor:
147147
orig_dtype = z.dtype
148148
is_img_or_video = z.ndim >= 4
149149

150-
# make sure allowed dtype
151-
152-
if z.dtype not in self.allowed_dtypes:
153-
z = z.float()
154-
155150
# standardize image or video into (batch, seq, dimension)
156151

157152
if is_img_or_video:
@@ -164,11 +159,23 @@ def forward(self, z: Tensor) -> Tensor:
164159

165160
z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
166161

162+
# make sure allowed dtype before quantizing
163+
164+
if z.dtype not in self.allowed_dtypes:
165+
z = z.float()
166+
167167
codes = self.quantize(z)
168168
indices = self.codes_to_indices(codes)
169169

170170
codes = rearrange(codes, 'b n c d -> b n (c d)')
171171

172+
# cast codes back to original dtype
173+
174+
if codes.dtype != orig_dtype:
175+
codes = codes.type(orig_dtype)
176+
177+
# project out
178+
172179
out = self.project_out(codes)
173180

174181
# reconstitute image or video dimensions
@@ -182,11 +189,6 @@ def forward(self, z: Tensor) -> Tensor:
182189
if not self.keep_num_codebooks_dim:
183190
indices = rearrange(indices, '... 1 -> ...')
184191

185-
# cast back to original dtype
186-
187-
if out.dtype != orig_dtype:
188-
out = out.type(orig_dtype)
189-
190192
# return quantized output and indices
191193

192194
return out, indices

0 commit comments

Comments
 (0)