
Commit 1786225

add a new technique where one can do cross entropy loss on the distance matrix with the codes, if indices were to be passed in
1 parent e0e073d commit 1786225

5 files changed (+72, -12 lines)
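The gist of the change: every quantizer already computes a distance matrix between the input vectors and the codebook entries in order to pick the nearest code. When target code indices are passed into `forward`, that distance matrix can instead be treated as classification logits and scored with cross entropy against the targets. Below is a minimal, self-contained sketch of the idea only; it is not the library's internal code, and all shapes and tensor names are illustrative.

```python
import torch
import torch.nn.functional as F

batch, seq_len, dim, codebook_size = 2, 16, 64, 512

x = torch.randn(batch, seq_len, dim, requires_grad = True)    # input vectors
codebook = torch.randn(codebook_size, dim)                     # stand-in codebook
targets = torch.randint(0, codebook_size, (batch, seq_len))    # target code indices

# squared euclidean distance from every vector to every code: (batch, seq_len, codebook_size)
dist = torch.cdist(x, codebook.unsqueeze(0).expand(batch, -1, -1)) ** 2

# closer codes should get higher probability, so negate the distances to form logits,
# then move the class (codebook) dimension to dim 1 as F.cross_entropy expects
ce_loss = F.cross_entropy(-dist.transpose(1, 2), targets)
ce_loss.backward()   # gradients flow back into x (and whatever produced it)
```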

README.md

Lines changed: 8 additions & 0 deletions
````diff
@@ -367,3 +367,11 @@ if __name__ == '__main__':
     year    = {2023}
 }
 ```
+
+```bibtex
+@inproceedings{Shen2023NaturalSpeech2L,
+    title   = {NaturalSpeech 2: Latent Diffusion Models are Natural and Zero-Shot Speech and Singing Synthesizers},
+    author  = {Kai Shen and Zeqian Ju and Xu Tan and Yanqing Liu and Yichong Leng and Lei He and Tao Qin and Sheng Zhao and Jiang Bian},
+    year    = {2023}
+}
+```
````

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.1.6',
+  version = '1.2.0',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',
```

vector_quantize_pytorch/random_projection_quantizer.py

Lines changed: 15 additions & 3 deletions
```diff
@@ -5,6 +5,9 @@
 
 from einops import rearrange, repeat, pack, unpack
 
+def exists(val):
+    return val is not None
+
 class RandomProjectionQuantizer(nn.Module):
     """ https://arxiv.org/abs/2202.01855 """
 
@@ -40,15 +43,24 @@ def __init__(
             **kwargs
         )
 
-    @torch.no_grad()
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        indices = None
+    ):
+        return_loss = exists(indices)
 
         x = self.norm(x)
 
         x = einsum('b n d, h d e -> b n h e', x, self.rand_projs)
         x, ps = pack([x], 'b n *')
 
         self.vq.eval()
-        _, indices, _ = self.vq(x)
+        out = self.vq(x, indices = indices)
+
+        if return_loss:
+            _, ce_loss = out
+            return ce_loss
+
+        _, indices, _ = out
         return indices
```
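For context, a hypothetical usage sketch of the new `forward(x, indices = ...)` path above. The constructor arguments mirror the README example, but the exact values, the index tensor shape, and the idea of reusing a previous pass's indices as targets are illustrative assumptions, not part of this diff.

```python
import torch
from vector_quantize_pytorch import RandomProjectionQuantizer

quantizer = RandomProjectionQuantizer(
    dim = 512,            # input feature dimension (assumed)
    num_codebooks = 16,
    codebook_dim = 256,
    codebook_size = 1024
)

x = torch.randn(1, 1024, 512, requires_grad = True)

# previous behaviour, unchanged: returns code indices, assumed shape (1, 1024, num_codebooks)
indices = quantizer(x)

# new behaviour: pass target indices in (e.g. from a clean / unmasked pass) and
# get back only the cross entropy loss computed on the distance matrix
ce_loss = quantizer(x, indices = indices)
```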

vector_quantize_pytorch/residual_vq.py

Lines changed: 33 additions & 4 deletions
```diff
@@ -9,6 +9,11 @@
 
 from einops import rearrange, repeat, pack, unpack
 
+# helper functions
+
+def exists(val):
+    return val is not None
+
 def round_up_multiple(num, mult):
     return ceil(num / mult) * mult
 
@@ -99,16 +104,21 @@ def get_codes_from_indices(self, indices):
     def forward(
         self,
         x,
+        indices = None,
         return_all_codes = False
     ):
-        num_quant, quant_dropout_multiple_of, device = self.num_quantizers, self.quantize_dropout_multiple_of, x.device
+        num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device
+
+        assert not (self.accept_image_fmap and exists(indices))
+
         quantized_out = 0.
         residual = x
 
         all_losses = []
         all_indices = []
+        ce_losses = [] # for cross entropy losses across quantizers, if indices are passed in
 
-        should_quantize_dropout = self.training and self.quantize_dropout
+        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
 
         # sample a layer index at which to dropout further residual quantization
         # also prepare null indices and loss
@@ -132,13 +142,32 @@ def forward(
                 all_losses.append(null_loss)
                 continue
 
-            quantized, indices, loss = layer(residual)
+            layer_indices = None
+            if return_loss:
+                layer_indices = indices[..., quantizer_index]
+
+            quantized, *rest = layer(residual, indices = layer_indices)
+
             residual = residual - quantized.detach()
             quantized_out = quantized_out + quantized
 
-            all_indices.append(indices)
+            if return_loss:
+                ce_loss = rest[0]
+                ce_losses.append(ce_loss)
+                continue
+
+            embed_indices, loss = rest
+
+            all_indices.append(embed_indices)
             all_losses.append(loss)
 
+        # whether to early return the cross entropy loss
+
+        if return_loss:
+            return quantized_out, sum(ce_losses)
+
+        # stack all losses and indices
+
         all_losses, all_indices = map(partial(torch.stack, dim = -1), (all_losses, all_indices))
 
         ret = (quantized_out, all_indices, all_losses)
```
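A hypothetical usage sketch of the `ResidualVQ` change above; the constructor values and index shapes are assumptions chosen for illustration.

```python
import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 8,
    codebook_size = 1024
)

x = torch.randn(1, 1024, 256)

# standard path: quantized output, per-quantizer indices, commitment losses
quantized, indices, commit_losses = residual_vq(x)   # indices assumed (1, 1024, 8)

# new path: pass the per-quantizer target indices back in; indices[..., q] becomes the
# cross entropy target for quantizer q, and the per-quantizer losses are summed
quantized, ce_loss = residual_vq(x, indices = indices)
```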

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 15 additions & 4 deletions
```diff
@@ -301,7 +301,7 @@ def forward(self, x):
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
 
-        return quantize, embed_ind
+        return quantize, embed_ind, dist
 
 class CosineSimCodebook(nn.Module):
     def __init__(
@@ -441,7 +441,7 @@ def forward(self, x):
         if needs_codebook_dim:
             quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
 
-        return quantize, embed_ind
+        return quantize, embed_ind, dist
 
 # main class
 
@@ -541,14 +541,15 @@ def get_codes_from_indices(self, indices):
     def forward(
         self,
         x,
+        indices = None,
         mask = None
     ):
         only_one = x.ndim == 2
 
         if only_one:
             x = rearrange(x, 'b d -> b 1 d')
 
-        shape, device, heads, is_multiheaded, codebook_size = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size
+        shape, device, heads, is_multiheaded, codebook_size, return_loss = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size, exists(indices)
 
         need_transpose = not self.channel_last and not self.accept_image_fmap
 
@@ -565,11 +566,21 @@ def forward(
             ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
             x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads)
 
-        quantize, embed_ind = self._codebook(x)
+        quantize, embed_ind, distances = self._codebook(x)
 
         if self.training:
             quantize = x + (quantize - x).detach()
 
+        if return_loss:
+            if not is_multiheaded:
+                distances = rearrange(distances, '1 (b n) l -> b l n', b = shape[0])
+            elif self.separate_codebook_per_head:
+                distances = rearrange(distances, 'c (b n) l -> b l n c', b = shape[0])
+            else:
+                distances = rearrange(distances, '1 (b h n) l -> b l n h', b = shape[0], h = heads)
+
+            return quantize, F.cross_entropy(distances, indices, ignore_index = -1)
+
         loss = torch.tensor([0.], device = device, requires_grad = self.training)
 
         if self.training:
```
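A hypothetical usage sketch of the `VectorQuantize` change above; constructor values and shapes are assumptions, and only the call pattern with `indices` comes from this diff.

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512
)

x = torch.randn(1, 1024, 256)

# standard path: quantized output, code indices, commitment loss
quantized, indices, commit_loss = vq(x)

# new path: the codebook's distance matrix is used as logits and scored with
# cross entropy against the passed-in indices (-1 entries are ignored)
quantized, ce_loss = vq(x, indices = indices)
```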
