Commit 16ec0a8

more flexible disabling of autocast for fsq
1 parent be92f79 commit 16ec0a8
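
This commit adds a force_quantization_f32 flag to FSQ so that the float32 cast (and the autocast disabling around it) can be turned off. A minimal usage sketch, assuming the FSQ class exported by vector-quantize-pytorch; the levels values and tensor shape below are illustrative, not taken from the commit:

import torch
from vector_quantize_pytorch import FSQ

# default behaviour: quantization runs with autocast disabled, in float32
fsq = FSQ(levels = [8, 5, 5, 5])

# new in this commit: leave whatever dtype / autocast context the caller is using untouched
fsq_amp = FSQ(levels = [8, 5, 5, 5], force_quantization_f32 = False)

x = torch.randn(1, 1024, 4)     # (batch, seq, dim), dim == len(levels)
quantized, indices = fsq(x)     # quantized has x's shape; indices holds one code id per position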

2 files changed (+19 -14 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.15.0"
+version = "1.15.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 18 additions & 13 deletions
@@ -4,7 +4,8 @@
 """
 
 from __future__ import annotations
-from functools import wraps
+from functools import wraps, partial
+from contextlib import nullcontext
 from typing import List, Tuple
 
 import torch
@@ -61,6 +62,7 @@ def __init__(
         channel_first: bool = False,
         projection_has_bias: bool = True,
         return_indices = True,
+        force_quantization_f32 = True
     ):
         super().__init__()
         _levels = torch.tensor(levels, dtype=int32)
@@ -99,6 +101,7 @@ def __init__(
         self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
 
         self.allowed_dtypes = allowed_dtypes
+        self.force_quantization_f32 = force_quantization_f32
 
     def bound(self, z, eps: float = 1e-3):
         """ Bound `z`, an array of shape (..., d). """
@@ -166,7 +169,6 @@ def forward(self, z):
         c - number of codebook dim
         """
 
-        orig_dtype = z.dtype
         is_img_or_video = z.ndim >= 4
         need_move_channel_last = is_img_or_video or self.channel_first
 
@@ -182,25 +184,28 @@
 
         z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
 
-        # make sure allowed dtype before quantizing
+        # whether to force quantization step to be full precision or not
 
-        if z.dtype not in self.allowed_dtypes:
-            z = z.float()
+        force_f32 = self.force_quantization_f32
+        quantization_context = partial(autocast, enabled = False) if force_f32 else nullcontext
 
-        codes = self.quantize(z)
+        with quantization_context():
+            orig_dtype = z.dtype
 
-        # returning indices could be optional
+            if force_f32 and orig_dtype not in self.allowed_dtypes:
+                z = z.float()
 
-        indices = None
+            codes = self.quantize(z)
 
-        if self.return_indices:
-            indices = self.codes_to_indices(codes)
+            # returning indices could be optional
 
-        codes = rearrange(codes, 'b n c d -> b n (c d)')
+            indices = None
 
-        # cast codes back to original dtype
+            if self.return_indices:
+                indices = self.codes_to_indices(codes)
+
+            codes = rearrange(codes, 'b n c d -> b n (c d)')
 
-        if codes.dtype != orig_dtype:
             codes = codes.type(orig_dtype)
 
         # project out
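
The heart of the change is the context chosen around the quantization step: when force_quantization_f32 is set, autocast is disabled and non-allowed dtypes are upcast to float32; otherwise a nullcontext leaves mixed precision alone. Below is a minimal sketch of that pattern in isolation, assuming torch.cuda.amp.autocast (the autocast import is not shown in this diff) and a hypothetical quantize_fn standing in for FSQ.quantize:

from contextlib import nullcontext
from functools import partial

import torch
from torch.cuda.amp import autocast

# hypothetical helper, not part of the library
def quantize_maybe_f32(z, quantize_fn, force_f32 = True, allowed_dtypes = (torch.float32, torch.float64)):
    # disable autocast only when full-precision quantization is forced; otherwise use a no-op context
    context = partial(autocast, enabled = False) if force_f32 else nullcontext

    with context():
        orig_dtype = z.dtype

        # upcast half-precision inputs only on the forced-f32 path
        if force_f32 and orig_dtype not in allowed_dtypes:
            z = z.float()

        out = quantize_fn(z)

        # cast back so callers see their original dtype
        out = out.type(orig_dtype)

    return out

# example: quantize_maybe_f32(torch.randn(2, 16, 4), torch.round)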
