Skip to content

Commit a0d14af

Browse files
committed
cache mask in lfq
1 parent e469ae2 commit a0d14af

File tree

2 files changed

+8
-39
lines changed

2 files changed

+8
-39
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.9.4',
6+
version = '1.9.5',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -38,36 +38,11 @@ def unpack_one(t, ps, pattern):
3838

3939
# entropy
4040

41-
def binary_entropy(prob):
42-
return -prob * log(prob) - (1 - prob) * log(1 - prob)
43-
44-
# tensor helpers
45-
4641
def log(t, eps = 1e-20):
4742
return t.clamp(min = eps).log()
4843

49-
# convert to bit representations and back
50-
51-
def decimal_to_bits(x, bits):
52-
device = x.device
53-
54-
x = x.int()
55-
56-
mask = 2 ** torch.arange(bits - 1, -1, -1, device = device)
57-
x = rearrange(x, 'b n -> b n 1')
58-
59-
bits = ((x & mask) != 0).float()
60-
bits = rearrange(bits, 'b n d -> b n d')
61-
return bits * 2 - 1
62-
63-
def bits_to_decimal(x, bits):
64-
device = x.device
65-
66-
x = (x > 0).int()
67-
68-
mask = 2 ** torch.arange(bits - 1, -1, -1, device = device, dtype = torch.int32)
69-
dec = reduce(x * mask, 'b n d -> b n', 'sum')
70-
return dec
44+
def binary_entropy(prob):
45+
return -prob * log(prob) - (1 - prob) * log(1 - prob)
7146

7247
# class
7348

@@ -105,6 +80,7 @@ def __init__(
10580

10681
# for no auxiliary loss, during inference
10782

83+
self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
10884
self.register_buffer('zero', torch.zeros(1,), persistent = False)
10985

11086
def indices_to_codes(
@@ -114,14 +90,10 @@ def indices_to_codes(
11490
):
11591
is_img_or_video = indices.ndim >= 3
11692

117-
# rearrange if image or video into (batch, seq, dimension)
118-
119-
if is_img_or_video:
120-
indices, ps = pack_one(indices, 'b *')
121-
12293
# indices to codes, which are bits of either -1 or 1
12394

124-
codes = decimal_to_bits(indices, self.codebook_dim)
95+
bits = ((indices[..., None].int() & self.mask) != 0).float()
96+
codes = bits * 2 - 1
12597

12698
# whether to project codes out to original dimensions
12799
# if the input feature dimensions were not log2(codebook size)
@@ -132,7 +104,6 @@ def indices_to_codes(
132104
# rearrange codes back to original shape
133105

134106
if is_img_or_video:
135-
codes = unpack_one(codes, ps, 'b * d')
136107
codes = rearrange(codes, 'b ... d -> b d ...')
137108

138109
return codes
@@ -163,10 +134,8 @@ def forward(
163134

164135
# quantize by eq 3.
165136

166-
greater_than_zero = x > 0
167137
ones = torch.ones_like(x)
168-
169-
quantized = torch.where(greater_than_zero, ones, -ones)
138+
quantized = torch.where(x > 0, ones, -ones)
170139

171140
# use straight-through gradients with tanh if training
172141

@@ -178,7 +147,7 @@ def forward(
178147

179148
# calculate indices
180149

181-
indices = bits_to_decimal(x, self.codebook_dim)
150+
indices = reduce((x > 0).int() * self.mask.int(), 'b n d -> b n', 'sum')
182151

183152
# entropy aux loss
184153

0 commit comments

Comments (0)