
Commit 877287a

able to use only a fraction of the probs for per sample entropy regularization in LFQ, to resolve some memory issues in meshgpt
1 parent af576e5 commit 877287a
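
A brief usage sketch of the new option. The constructor arguments and the return signature below are assumed from the project README rather than shown in this commit; only frac_per_sample_entropy is what this change introduces.

import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    dim = 16,                        # illustrative values, following the README example
    codebook_size = 65536,           # 2 ** 16, so the codebook dimension matches dim = 16
    entropy_loss_weight = 0.1,
    frac_per_sample_entropy = 0.25   # new in 1.12.7: use a random 25% of the probs for the per-sample entropy term
)

seq = torch.randn(1, 1024, 16)

# the entropy aux loss (and hence the subsampling) only applies in training mode
quantized, indices, entropy_aux_loss = quantizer(seq)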

File tree

2 files changed (+21 -3 lines changed):

setup.py
vector_quantize_pytorch/lookup_free_quantization.py

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.12.6',
+  version = '1.12.7',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 20 additions & 2 deletions
@@ -61,7 +61,8 @@ def __init__(
         straight_through_activation = nn.Identity(),
         num_codebooks = 1,
         keep_num_codebooks_dim = None,
-        codebook_scale = 1. # for residual LFQ, codebook scaled down by 2x at each layer
+        codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
+        frac_per_sample_entropy = 1. # make less than 1. to only use a random fraction of the probs for per sample entropy
     ):
         super().__init__()
 
@@ -95,6 +96,9 @@ def __init__(
 
         # entropy aux loss related weights
 
+        assert 0 < frac_per_sample_entropy <= 1.
+        self.frac_per_sample_entropy = frac_per_sample_entropy
+
         self.diversity_gamma = diversity_gamma
         self.entropy_loss_weight = entropy_loss_weight
 
@@ -219,8 +223,22 @@ def forward(
 
             if exists(mask):
                 prob = prob[mask]
+            else:
+                prob = rearrange(prob, 'b n ... -> (b n) ...')
+
+            # whether to only use a fraction of probs, for reducing memory
+
+            if self.frac_per_sample_entropy < 1.:
+                num_tokens = prob.shape[0]
+                num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
+                rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
+                per_sample_probs = prob[rand_mask]
+            else:
+                per_sample_probs = prob
+
+            # calculate per sample entropy
 
-            per_sample_entropy = entropy(prob).mean()
+            per_sample_entropy = entropy(per_sample_probs).mean()
 
             # distribution over all available tokens in the batch
 
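The memory savings come from evaluating the per-sample entropy on a random subset of the rows of prob. The standalone sketch below uses hypothetical sizes and a simplified entropy helper standing in for the library's own; it shows how torch.randn(...).argsort(...) < k keeps exactly k rows chosen uniformly at random, so the subsampled mean entropy matches the full one in expectation while the intermediate tensors shrink by the chosen fraction.

import torch

def entropy(prob, eps = 1e-20):
    # simplified stand-in for the library's helper: -sum(p * log(p)) over the codebook dimension
    return (-prob * prob.clamp(min = eps).log()).sum(dim = -1)

num_tokens, codebook_size = 8192, 512        # hypothetical sizes
prob = torch.softmax(torch.randn(num_tokens, codebook_size), dim = -1)

frac_per_sample_entropy = 0.25
num_sampled_tokens = int(num_tokens * frac_per_sample_entropy)

# argsort of i.i.d. gaussian noise is a uniformly random permutation, so this boolean
# mask keeps exactly `num_sampled_tokens` rows, with every subset equally likely
rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens

per_sample_entropy_full = entropy(prob).mean()
per_sample_entropy_sub = entropy(prob[rand_mask]).mean()   # same expectation, roughly a quarter of the memory for this term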

Comments (0)