Commit 3b2bd1a

address #166
1 parent a72e251 commit 3b2bd1a

File tree

2 files changed: +26 -3 lines changed

pyproject.toml
vector_quantize_pytorch/residual_fsq.py

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.18.2"
+version = "1.18.4"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/residual_fsq.py

Lines changed: 25 additions & 2 deletions

@@ -38,22 +38,25 @@ class ResidualFSQ(Module):
     def __init__(
         self,
         *,
-        dim,
         levels: List[int],
         num_quantizers,
+        dim = None,
+        is_channel_first = False,
         quantize_dropout = False,
         quantize_dropout_cutoff_index = 0,
         quantize_dropout_multiple_of = 1,
         **kwargs
     ):
         super().__init__()
         codebook_dim = len(levels)
+        dim = default(dim, codebook_dim)

         requires_projection = codebook_dim != dim
         self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
         self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
         self.has_projections = requires_projection

+        self.is_channel_first = is_channel_first
         self.num_quantizers = num_quantizers

         self.levels = levels
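
With this hunk, dim becomes optional and falls back to the codebook dimension len(levels) via the default helper, so the input/output projections collapse to nn.Identity() when they are not needed; the new is_channel_first flag is simply stored for use in forward. A minimal usage sketch of the updated constructor (the level values, num_quantizers, and tensor shapes below are illustrative, not taken from the commit):

    import torch
    from vector_quantize_pytorch import ResidualFSQ

    residual_fsq = ResidualFSQ(
        levels = [8, 5, 5, 5],       # codebook_dim = len(levels) = 4
        num_quantizers = 2,
        # dim omitted: it now defaults to codebook_dim, so project_in / project_out
        # are nn.Identity() and has_projections is False
    )

    x = torch.randn(1, 1024, 4)      # (batch, seq, dim) with dim == len(levels)
    quantized, indices = residual_fsq(x)
    # quantized: (1, 1024, 4), indices: (1, 1024, 2)
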
@@ -143,10 +146,18 @@ def forward(
     ):
         num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device

+        # handle channel first
+
+        if self.is_channel_first:
+            x = rearrange(x, 'b d ... -> b ... d')
+            x, ps = pack([x], 'b * d')
+
+        # maybe project in
+
         x = self.project_in(x)

         quantized_out = 0.
-        residual = first(self.layers).bound(x)
+        residual = x

         all_indices = []

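The channel-first path leans on einops: rearrange moves the channel axis to the end, and pack flattens any number of trailing spatial axes into one, so the quantizers only ever see a (batch, n, dim) tensor, while ps records how to undo the flattening afterwards. A standalone sketch of that round trip (shapes chosen only for illustration):

    import torch
    from einops import rearrange, pack, unpack

    x = torch.randn(2, 4, 32, 32)              # channel-first, e.g. (batch, dim, height, width)

    x = rearrange(x, 'b d ... -> b ... d')     # channels last         -> (2, 32, 32, 4)
    x, ps = pack([x], 'b * d')                 # flatten spatial dims  -> (2, 1024, 4)

    # ... residual quantization operates on this flat (b, n, d) view ...

    x, = unpack(x, ps, 'b * d')                # restore spatial dims  -> (2, 32, 32, 4)
    x = rearrange(x, 'b ... d -> b d ...')     # back to channel-first -> (2, 4, 32, 32)
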
@@ -175,6 +186,7 @@ def forward(
                 continue

             quantized, indices = layer(residual / scale)
+
             quantized = quantized * scale

             residual = residual - quantized.detach()
@@ -190,6 +202,17 @@ def forward(

         all_indices = torch.stack(all_indices, dim = -1)

+        # channel first out
+
+        if self.is_channel_first:
+            quantized_out, = unpack(quantized_out, ps, 'b * d')
+            all_indices, = unpack(all_indices, ps, 'b * d')
+
+            quantized_out = rearrange(quantized_out, 'b ... d -> b d ...')
+            all_indices = rearrange(all_indices, 'b ... d -> b d ...')
+
+        # return
+
         ret = (quantized_out, all_indices)

         if not return_all_codes:
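
Putting the pieces together, a hedged end-to-end sketch of the new flag: with is_channel_first = True an image-like (batch, dim, height, width) tensor can be passed directly, and both the quantized output and the indices come back in channel-first layout (shapes and level values below are illustrative):

    import torch
    from vector_quantize_pytorch import ResidualFSQ

    residual_fsq = ResidualFSQ(
        levels = [8, 5, 5, 5],        # codebook_dim = 4; dim again defaults to it
        num_quantizers = 2,
        is_channel_first = True       # accept channel-first inputs directly
    )

    images = torch.randn(2, 4, 32, 32)         # (batch, dim, height, width)
    quantized, indices = residual_fsq(images)

    print(quantized.shape)    # torch.Size([2, 4, 32, 32]) - same layout as the input
    print(indices.shape)      # torch.Size([2, 2, 32, 32]) - quantizer axis right after batch
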

0 commit comments
