Skip to content

Commit 0024008

Browse files
committed
address #114
1 parent 0f970c4 commit 0024008

File tree

2 files changed: +19 -3 lines changed

2 files changed: +19 -3 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.14.2',
6+
version = '1.14.4',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
Code adapted from Jax version in Appendix A.1
44
"""
55

6-
from typing import List, Optional
6+
from typing import List, Tuple, Optional
77

88
import torch
99
import torch.nn as nn
@@ -46,7 +46,8 @@ def __init__(
4646
dim: Optional[int] = None,
4747
num_codebooks = 1,
4848
keep_num_codebooks_dim: Optional[bool] = None,
49-
scale: Optional[float] = None
49+
scale: Optional[float] = None,
50+
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64)
5051
):
5152
super().__init__()
5253
_levels = torch.tensor(levels, dtype=int32)
@@ -80,6 +81,8 @@ def __init__(
8081
implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out = False)
8182
self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
8283

84+
self.allowed_dtypes = allowed_dtypes
85+
8386
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
8487
"""Bound `z`, an array of shape (..., d)."""
8588
half_l = (self._levels - 1) * (1 + eps) / 2
@@ -141,8 +144,14 @@ def forward(self, z: Tensor) -> Tensor:
141144
c - number of codebook dim
142145
"""
143146

147+
orig_dtype = z.dtype
144148
is_img_or_video = z.ndim >= 4
145149

150+
# make sure allowed dtype
151+
152+
if z.dtype not in self.allowed_dtypes:
153+
z = z.float()
154+
146155
# standardize image or video into (batch, seq, dimension)
147156

148157
if is_img_or_video:
@@ -173,4 +182,11 @@ def forward(self, z: Tensor) -> Tensor:
173182
if not self.keep_num_codebooks_dim:
174183
indices = rearrange(indices, '... 1 -> ...')
175184

185+
# cast back to original dtype
186+
187+
if out.dtype != orig_dtype:
188+
out = out.type(orig_dtype)
189+
190+
# return quantized output and indices
191+
176192
return out, indices

0 commit comments

Comments
 (0)