
Commit f661132

revert due to #49; for cosine sim, just normalize the input at the beginning; finally, add an option to use a cross-entropy-based commitment loss
1 parent 9821bdc commit f661132

2 files changed: +45 −26 lines

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.5.5',
+  version = '1.5.6',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 44 additions & 25 deletions

@@ -16,6 +16,9 @@ def default(val, d):
 def noop(*args, **kwargs):
     pass
 
+def identity(t):
+    return t
+
 def l2norm(t):
     return F.normalize(t, p = 2, dim = -1)
 

@@ -200,6 +203,8 @@ def __init__(
         sample_codebook_temp = 0
     ):
         super().__init__()
+        self.transform_input = identity
+
         self.decay = decay
         init_fn = uniform_init if not kmeans_init else torch.zeros
         embed = init_fn(num_codebooks, codebook_size, dim)

@@ -294,6 +299,8 @@ def forward(self, x):
         embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
         embed_ind = unpack_one(embed_ind, ps, 'h *')
 
+        quantize = batched_embedding(embed_ind, self.embed)
+
         if self.training:
             cluster_size = embed_onehot.sum(dim = 1)
 

@@ -310,11 +317,6 @@ def forward(self, x):
             self.embed.data.copy_(embed_normalized)
             self.expire_codes_(x)
 
-        quantize = batched_embedding(embed_ind, self.embed)
-
-        if self.training:
-            quantize = x + (quantize - x).detach()
-
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))

@@ -340,6 +342,8 @@ def __init__(
         sample_codebook_temp = 0.
     ):
         super().__init__()
+        self.transform_input = l2norm
+
         self.decay = decay
 
         if not kmeans_init:

@@ -427,18 +431,18 @@ def forward(self, x):
         dtype = x.dtype
 
         flatten, ps = pack_one(x, 'h * d')
-        flatten = l2norm(flatten)
 
         self.init_embed_(flatten)
 
         embed = self.embed if not self.learnable_codebook else self.embed.detach()
-        embed = l2norm(embed)
 
         dist = einsum('h n d, h c d -> h n c', flatten, embed)
         embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
         embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
         embed_ind = unpack_one(embed_ind, ps, 'h *')
 
+        quantize = batched_embedding(embed_ind, self.embed)
+
         if self.training:
             bins = embed_onehot.sum(dim = 1)
             self.all_reduce_fn(bins)

@@ -457,12 +461,6 @@ def forward(self, x):
             self.embed.data.copy_(l2norm(embed_normalized))
             self.expire_codes_(x)
 
-        quantize = batched_embedding(embed_ind, self.embed)
-
-        if self.training:
-            l2norm_x = l2norm(x)
-            quantize = l2norm_x + (quantize - l2norm_x).detach()
-
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
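With this change the input is l2-normalized once via transform_input and the codebook rows are kept l2-normalized by the EMA update shown above, so the plain einsum dot product already equals cosine similarity. A small self-contained check of that identity (shapes are illustrative only):

import torch
import torch.nn.functional as F

x = torch.randn(8, 64)           # illustrative: 8 input vectors of dim 64
codebook = torch.randn(512, 64)  # illustrative: 512 codes

x_n = F.normalize(x, p = 2, dim = -1)
c_n = F.normalize(codebook, p = 2, dim = -1)

dot = x_n @ c_n.t()              # dot product of unit vectors
cos = F.cosine_similarity(x.unsqueeze(1), codebook.unsqueeze(0), dim = -1)

assert torch.allclose(dot, cos, atol = 1e-6)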

@@ -489,6 +487,7 @@ def __init__(
         channel_last = True,
         accept_image_fmap = False,
         commitment_weight = 1.,
+        commitment_use_cross_entropy_loss = False,
         orthogonal_reg_weight = 0.,
         orthogonal_reg_active_codes_only = False,
         orthogonal_reg_max_codes = None,

@@ -509,6 +508,7 @@ def __init__(
 
         self.eps = eps
         self.commitment_weight = commitment_weight
+        self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss
 
         has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
         self.orthogonal_reg_weight = orthogonal_reg_weight
@@ -588,39 +588,58 @@ def forward(
 
         x = self.project_in(x)
 
+        x = self._codebook.transform_input(x)
+
         if is_multiheaded:
             ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
             x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads)
 
         quantize, embed_ind, distances = self._codebook(x)
 
-        if return_loss:
+        if self.training:
+            quantize = x + (quantize - x).detach()
+
+        # function for calculating cross entropy loss to distance matrix
+        # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
+
+        def calculate_ce_loss(codes):
             if not is_multiheaded:
                 dist_einops_eq = '1 b n l -> b l n'
             elif self.separate_codebook_per_head:
                 dist_einops_eq = 'c b n l -> b l n c'
             else:
                 dist_einops_eq = '1 (b h) n l -> b l n h'
 
-            distances = rearrange(distances, dist_einops_eq, b = shape[0])
-            return quantize, F.cross_entropy(distances, indices, ignore_index = -1)
+            ce_loss = F.cross_entropy(
+                rearrange(distances, dist_einops_eq, b = shape[0]),
+                codes,
+                ignore_index = -1
+            )
+
+            return ce_loss
+
+        if return_loss:
+            return quantize, calculate_ce_loss(indices)
 
         loss = torch.tensor([0.], device = device, requires_grad = self.training)
 
         if self.training:
             if self.commitment_weight > 0:
-                detached_quantize = quantize.detach()
+                if self.commitment_use_cross_entropy_loss:
+                    commit_loss = calculate_ce_loss(embed_ind)
+                else:
+                    detached_quantize = quantize.detach()
 
-                if exists(mask):
-                    # with variable lengthed sequences
-                    commit_loss = F.mse_loss(detached_quantize, x, reduction = 'none')
+                    if exists(mask):
+                        # with variable lengthed sequences
+                        commit_loss = F.mse_loss(detached_quantize, x, reduction = 'none')
 
-                    if is_multiheaded:
-                        mask = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
+                        if is_multiheaded:
+                            mask = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
 
-                    commit_loss = commit_loss[mask].mean()
-                else:
-                    commit_loss = F.mse_loss(detached_quantize, x)
+                        commit_loss = commit_loss[mask].mean()
+                    else:
+                        commit_loss = F.mse_loss(detached_quantize, x)
 
                 loss = loss + commit_loss * self.commitment_weight
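A hedged usage sketch of the option added in this commit; the flag name is taken from the diff above, while the other settings and shapes are illustrative:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    commitment_weight = 1.,
    commitment_use_cross_entropy_loss = True  # commit via cross entropy to the distance matrix instead of MSE
)

x = torch.randn(1, 1024, 256)
quantized, indices, loss = vq(x)  # in training mode, `loss` includes the cross-entropy commitment term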
