4
4
"""
5
5
6
6
from __future__ import annotations
7
- from functools import wraps
7
+ from functools import wraps , partial
8
+ from contextlib import nullcontext
8
9
from typing import List , Tuple
9
10
10
11
import torch
@@ -61,6 +62,7 @@ def __init__(
61
62
channel_first : bool = False ,
62
63
projection_has_bias : bool = True ,
63
64
return_indices = True ,
65
+ force_quantization_f32 = True
64
66
):
65
67
super ().__init__ ()
66
68
_levels = torch .tensor (levels , dtype = int32 )
@@ -99,6 +101,7 @@ def __init__(
99
101
self .register_buffer ("implicit_codebook" , implicit_codebook , persistent = False )
100
102
101
103
self .allowed_dtypes = allowed_dtypes
104
+ self .force_quantization_f32 = force_quantization_f32
102
105
103
106
def bound (self , z , eps : float = 1e-3 ):
104
107
""" Bound `z`, an array of shape (..., d). """
@@ -166,7 +169,6 @@ def forward(self, z):
166
169
c - number of codebook dim
167
170
"""
168
171
169
- orig_dtype = z .dtype
170
172
is_img_or_video = z .ndim >= 4
171
173
need_move_channel_last = is_img_or_video or self .channel_first
172
174
@@ -182,25 +184,28 @@ def forward(self, z):
182
184
183
185
z = rearrange (z , 'b n (c d) -> b n c d' , c = self .num_codebooks )
184
186
185
- # make sure allowed dtype before quantizing
187
+ # whether to force quantization step to be full precision or not
186
188
187
- if z . dtype not in self .allowed_dtypes :
188
- z = z . float ()
189
+ force_f32 = self .force_quantization_f32
190
+ quantization_context = partial ( autocast , enabled = False ) if force_f32 else nullcontext
189
191
190
- codes = self .quantize (z )
192
+ with quantization_context ():
193
+ orig_dtype = z .dtype
191
194
192
- # returning indices could be optional
195
+ if force_f32 and orig_dtype not in self .allowed_dtypes :
196
+ z = z .float ()
193
197
194
- indices = None
198
+ codes = self . quantize ( z )
195
199
196
- if self .return_indices :
197
- indices = self .codes_to_indices (codes )
200
+ # returning indices could be optional
198
201
199
- codes = rearrange ( codes , 'b n c d -> b n (c d)' )
202
+ indices = None
200
203
201
- # cast codes back to original dtype
204
+ if self .return_indices :
205
+ indices = self .codes_to_indices (codes )
206
+
207
+ codes = rearrange (codes , 'b n c d -> b n (c d)' )
202
208
203
- if codes .dtype != orig_dtype :
204
209
codes = codes .type (orig_dtype )
205
210
206
211
# project out
0 commit comments