Skip to content

Commit 133f738

Browse files
committed
account for indices being optional in fsq for images
1 parent fc55a8c commit 133f738

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.14.29"
3+
version = "1.14.30"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/finite_scalar_quantization.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
from __future__ import annotations
7+
from functools import wraps
78
from typing import List, Tuple
89

910
import torch
@@ -25,6 +26,14 @@ def default(*args):
2526
return arg
2627
return None
2728

29+
def maybe(fn):
30+
@wraps(fn)
31+
def inner(x, *args, **kwargs):
32+
if not exists(x):
33+
return x
34+
return fn(x, *args, **kwargs)
35+
return inner
36+
2837
def pack_one(t, pattern):
2938
return pack([t], pattern)
3039

@@ -179,7 +188,11 @@ def forward(self, z):
179188
z = z.float()
180189

181190
codes = self.quantize(z)
191+
192+
# returning indices could be optional
193+
182194
indices = None
195+
183196
if self.return_indices:
184197
indices = self.codes_to_indices(codes)
185198

@@ -200,10 +213,10 @@ def forward(self, z):
200213
out = unpack_one(out, ps, 'b * d')
201214
out = rearrange(out, 'b ... d -> b d ...')
202215

203-
indices = unpack_one(indices, ps, 'b * c')
216+
indices = maybe(unpack_one)(indices, ps, 'b * c')
204217

205218
if not self.keep_num_codebooks_dim and self.return_indices:
206-
indices = rearrange(indices, '... 1 -> ...')
219+
indices = maybe(rearrange)(indices, '... 1 -> ...')
207220

208221
# return quantized output and indices
209222

0 commit comments

Comments
 (0)