Skip to content

Commit 3505761

Browse files
committed
address #145 again
1 parent 1bce1c3 commit 3505761

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.14.45"
3+
version = "1.14.46"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -198,7 +198,7 @@ def __init__(
198198
bits = ((all_codes[..., None].int() & self.mask) != 0).float()
199199
codebook = self.bits_to_codes(bits)
200200

201-
self.register_buffer('codebook', codebook, persistent = False)
201+
self.register_buffer('codebook', codebook.float(), persistent = False)
202202

203203
def bits_to_codes(self, bits):
204204
return bits * self.codebook_scale * 2 - self.codebook_scale
@@ -257,6 +257,7 @@ def forward(
257257
c - number of codebook dim
258258
"""
259259

260+
orig_dtype = x.dtype
260261
x = x.float()
261262

262263
is_img_or_video = x.ndim >= 4
@@ -313,7 +314,7 @@ def forward(
313314
# entropy aux loss
314315

315316
if self.training:
316-
codebook = self.codebook
317+
codebook = self.codebook.float()
317318

318319
codebook = self.maybe_l2norm(codebook)
319320

@@ -403,6 +404,12 @@ def forward(
403404

404405
aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
405406

407+
# restore original dtype
408+
409+
x = x.type(orig_dtype)
410+
411+
# returns
412+
406413
ret = Return(x, indices, aux_loss)
407414

408415
if not return_loss_breakdown:

0 commit comments

Comments (0)