|
3 | 3 | Code adapted from Jax version in Appendix A.1
|
4 | 4 | """
|
5 | 5 |
|
6 |
| -from typing import List, Optional |
| 6 | +from typing import List, Tuple, Optional |
7 | 7 |
|
8 | 8 | import torch
|
9 | 9 | import torch.nn as nn
|
@@ -46,7 +46,8 @@ def __init__(
|
46 | 46 | dim: Optional[int] = None,
|
47 | 47 | num_codebooks = 1,
|
48 | 48 | keep_num_codebooks_dim: Optional[bool] = None,
|
49 |
| - scale: Optional[float] = None |
| 49 | + scale: Optional[float] = None, |
| 50 | + allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64) |
50 | 51 | ):
|
51 | 52 | super().__init__()
|
52 | 53 | _levels = torch.tensor(levels, dtype=int32)
|
@@ -80,6 +81,8 @@ def __init__(
|
80 | 81 | implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out = False)
|
81 | 82 | self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
|
82 | 83 |
|
| 84 | + self.allowed_dtypes = allowed_dtypes |
| 85 | + |
83 | 86 | def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
|
84 | 87 | """Bound `z`, an array of shape (..., d)."""
|
85 | 88 | half_l = (self._levels - 1) * (1 + eps) / 2
|
@@ -141,8 +144,14 @@ def forward(self, z: Tensor) -> Tensor:
|
141 | 144 | c - number of codebook dim
|
142 | 145 | """
|
143 | 146 |
|
| 147 | + orig_dtype = z.dtype |
144 | 148 | is_img_or_video = z.ndim >= 4
|
145 | 149 |
|
| 150 | + # cast the input to float32 when its dtype is not one of the allowed dtypes |
| 151 | + |
| 152 | + if z.dtype not in self.allowed_dtypes: |
| 153 | + z = z.float() |
| 154 | + |
146 | 155 | # standardize image or video into (batch, seq, dimension)
|
147 | 156 |
|
148 | 157 | if is_img_or_video:
|
@@ -173,4 +182,11 @@ def forward(self, z: Tensor) -> Tensor:
|
173 | 182 | if not self.keep_num_codebooks_dim:
|
174 | 183 | indices = rearrange(indices, '... 1 -> ...')
|
175 | 184 |
|
| 185 | + # cast back to original dtype |
| 186 | + |
| 187 | + if out.dtype != orig_dtype: |
| 188 | + out = out.type(orig_dtype) |
| 189 | + |
| 190 | + # return quantized output and indices |
| 191 | + |
176 | 192 | return out, indices
|
0 commit comments