
Commit be92f79

address #144
1 parent 3505761 commit be92f79

2 files changed, 86 insertions(+), 71 deletions(-)

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.46"
+version = "1.15.0"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 85 additions & 70 deletions
@@ -9,6 +9,7 @@
 from math import log2, ceil
 from functools import partial, cache
 from collections import namedtuple
+from contextlib import nullcontext
 import torch.distributed as dist

 import torch
@@ -112,7 +113,8 @@ def __init__(
         channel_first = None,
         experimental_softplus_entropy_loss = False,
         entropy_loss_offset = 5.,                   # how much to shift the loss before softplus
-        spherical = False                           # from https://arxiv.org/abs/2406.07548
+        spherical = False,                          # from https://arxiv.org/abs/2406.07548
+        force_quantization_f32 = True               # will force the quantization step to be full precision
     ):
         super().__init__()

@@ -192,6 +194,10 @@ def __init__(
         self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
         self.register_buffer('zero', torch.tensor(0.), persistent = False)

+        # whether to force quantization step to be f32
+
+        self.force_quantization_f32 = force_quantization_f32
+
         # codes

         all_codes = torch.arange(codebook_size)
@@ -241,7 +247,6 @@ def indices_to_codes(

         return codes

-    @autocast(enabled = False)
     def forward(
         self,
         x,
@@ -257,9 +262,6 @@ def forward(
         c - number of codebook dim
         """

-        orig_dtype = x.dtype
-        x = x.float()
-
         is_img_or_video = x.ndim >= 4
         should_transpose = default(self.channel_first, is_img_or_video)

@@ -271,8 +273,7 @@ def forward(

         assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

-        with autocast():
-            x = self.project_in(x)
+        x = self.project_in(x)

         # maybe soft clamp

@@ -288,104 +289,122 @@ def forward(

         x = self.maybe_l2norm(x)

-        # quantize by eq 3.
+        # whether to force quantization step to be full precision or not

-        original_input = x
+        force_f32 = self.force_quantization_f32

-        codebook_value = torch.ones_like(x) * self.codebook_scale
-        quantized = torch.where(x > 0, codebook_value, -codebook_value)
+        quantization_context = partial(autocast, enabled = False) if force_f32 else nullcontext

-        # calculate indices
+        with quantization_context():

-        indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
+            if force_f32:
+                orig_dtype = x.dtype
+                x = x.float()

-        # maybe l2norm
+            # quantize by eq 3.

-        quantized = self.maybe_l2norm(quantized)
+            original_input = x

-        # use straight-through gradients (optionally with custom activation fn) if training
+            codebook_value = torch.ones_like(x) * self.codebook_scale
+            quantized = torch.where(x > 0, codebook_value, -codebook_value)

-        if self.training:
-            x = self.activation(x)
-            x = x + (quantized - x).detach()
-        else:
-            x = quantized
+            # calculate indices

-        # entropy aux loss
+            indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

-        if self.training:
-            codebook = self.codebook.float()
+            # maybe l2norm

-            codebook = self.maybe_l2norm(codebook)
+            quantized = self.maybe_l2norm(quantized)

-            # the same as euclidean distance up to a constant
-            distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
+            # use straight-through gradients (optionally with custom activation fn) if training

-            prob = (-distance * inv_temperature).softmax(dim = -1)
+            if self.training:
+                x = self.activation(x)
+                x = x + (quantized - x).detach()
+            else:
+                x = quantized

-            # account for mask
+            # entropy aux loss

-            if exists(mask):
-                prob = prob[mask]
-            else:
-                prob = rearrange(prob, 'b n ... -> (b n) ...')
+            if self.training:

-            # whether to only use a fraction of probs, for reducing memory
+                if force_f32:
+                    codebook = self.codebook.float()

-            if self.frac_per_sample_entropy < 1.:
-                num_tokens = prob.shape[0]
-                num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
-                rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
-                per_sample_probs = prob[rand_mask]
-            else:
-                per_sample_probs = prob
+                codebook = self.maybe_l2norm(codebook)

-            # calculate per sample entropy
+                # the same as euclidean distance up to a constant
+                distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)

-            per_sample_entropy = entropy(per_sample_probs).mean()
+                prob = (-distance * inv_temperature).softmax(dim = -1)

-            # distribution over all available tokens in the batch
+                # account for mask

-            avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')
+                if exists(mask):
+                    prob = prob[mask]
+                else:
+                    prob = rearrange(prob, 'b n ... -> (b n) ...')

-            avg_prob = maybe_distributed_mean(avg_prob)
+                # whether to only use a fraction of probs, for reducing memory

-            codebook_entropy = entropy(avg_prob).mean()
+                if self.frac_per_sample_entropy < 1.:
+                    num_tokens = prob.shape[0]
+                    num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
+                    rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
+                    per_sample_probs = prob[rand_mask]
+                else:
+                    per_sample_probs = prob

-            # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
-            # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
+                # calculate per sample entropy

-            entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
-        else:
-            # if not training, just return dummy 0
-            entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
+                per_sample_entropy = entropy(per_sample_probs).mean()

-        # whether to make the entropy loss positive or not through a (shifted) softplus
+                # distribution over all available tokens in the batch

-        if self.training and self.experimental_softplus_entropy_loss:
-            entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)
+                avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')

-        # commit loss
+                avg_prob = maybe_distributed_mean(avg_prob)

-        if self.training and self.commitment_loss_weight > 0.:
+                codebook_entropy = entropy(avg_prob).mean()

-            commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
+                # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
+                # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch

-            if exists(mask):
-                commit_loss = commit_loss[mask]
+                entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
+            else:
+                # if not training, just return dummy 0
+                entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero

-            commit_loss = commit_loss.mean()
-        else:
-            commit_loss = self.zero
+            # whether to make the entropy loss positive or not through a (shifted) softplus
+
+            if self.training and self.experimental_softplus_entropy_loss:
+                entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)
+
+            # commit loss
+
+            if self.training and self.commitment_loss_weight > 0.:
+
+                commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
+
+                if exists(mask):
+                    commit_loss = commit_loss[mask]
+
+                commit_loss = commit_loss.mean()
+            else:
+                commit_loss = self.zero
+
+            # input back to original dtype if needed
+
+            if force_f32:
+                x = x.type(orig_dtype)

         # merge back codebook dim

         x = rearrange(x, 'b n c d -> b n (c d)')

         # project out to feature dimension if needed

-        with autocast():
-            x = self.project_out(x)
+        x = self.project_out(x)

         # reconstitute image or video dimensions

@@ -404,10 +423,6 @@ def forward(

         aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight

-        # restore original dtype
-
-        x = x.type(orig_dtype)
-
         # returns

         ret = Return(x, indices, aux_loss)
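
For context: the net effect of the diff is that, instead of decorating the whole forward pass with @autocast(enabled = False), only the quantization step is wrapped in a context that is either autocast-disabled (when force_quantization_f32 = True, the default) or a plain nullcontext, with the tensor cast to float32 on entry and back to its original dtype on exit. The sketch below is not the library's code; the names quantize_sign and force_f32 are illustrative only, and it uses torch.autocast with an explicit device_type so it runs on CPU.

# minimal sketch of the conditional full-precision context pattern (assumed names)

from contextlib import nullcontext
from functools import partial

import torch
from torch import autocast

def quantize_sign(x, force_f32 = True, device_type = 'cpu'):
    # pick the context: autocast disabled when forcing f32, otherwise a no-op
    context = partial(autocast, device_type, enabled = False) if force_f32 else nullcontext

    with context():
        if force_f32:
            orig_dtype = x.dtype
            x = x.float()

        # sign quantization with a straight-through estimator, in the spirit of LFQ's eq 3
        quantized = torch.where(x > 0, torch.ones_like(x), -torch.ones_like(x))
        x = x + (quantized - x).detach()

        if force_f32:
            # cast back to the caller's dtype, mirroring the change in this commit
            x = x.type(orig_dtype)

    return x

# usage: even under bf16 autocast, the quantization math runs in float32
with autocast('cpu', dtype = torch.bfloat16):
    out = quantize_sign(torch.randn(2, 8, dtype = torch.bfloat16))

print(out.dtype)  # torch.bfloat16, restored after the float32 quantization step

In the LFQ module itself, passing force_quantization_f32 = False would let the quantization step follow the ambient autocast precision rather than forcing float32; the default True keeps the quantization numerics in full precision.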
