@@ -27,6 +27,8 @@ vq = VectorQuantize(
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)
+ print(quantized.shape, indices.shape, commit_loss.shape)
+ # > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
```
## Residual VQ
@@ -46,16 +48,14 @@ residual_vq = ResidualVQ(
x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)
-
- # (1, 1024, 256), (1, 1024, 8), (1, 8)
- # (batch, seq, dim), (batch, seq, quantizer), (batch, quantizer)
+ print(quantized.shape, indices.shape, commit_loss.shape)
+ # > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8])

# if you need all the codes across the quantization layers, just pass return_all_codes = True

quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)
-
- # *_, (8, 1, 1024, 256)
- # all_codes - (quantizer, batch, seq, dim)
+ print(all_codes.shape)
+ # > torch.Size([8, 1, 1024, 256])
```
Furthermore, <a href="https://arxiv.org/abs/2203.01941">this paper</a> uses Residual-VQ to construct the RQ-VAE for generating high-resolution images with more compressed codes.
@@ -77,9 +77,8 @@ residual_vq = ResidualVQ(
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
-
- # (1, 1024, 256), (8, 1, 1024), (8, 1)
- # (batch, seq, dim), (quantizer, batch, seq), (quantizer, batch)
+ print(quantized.shape, indices.shape, commit_loss.shape)
+ # > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8])
```
<a href =" https://arxiv.org/abs/2305.02765 " >A recent paper</a > further proposes to do residual VQ on groups of the feature dimension, showing equivalent results to Encodec while using far fewer codebooks. You can use it by importing ` GroupedResidualVQ `
@@ -98,9 +97,8 @@ residual_vq = GroupedResidualVQ(
x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)
-
- # (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
- # (batch, seq, dim), (groups, batch, seq, quantizer), (groups, batch, quantizer)
+ print(quantized.shape, indices.shape, commit_loss.shape)
+ # > torch.Size([1, 1024, 256]) torch.Size([2, 1, 1024, 8]) torch.Size([2, 1, 8])
```
@@ -122,6 +120,8 @@ residual_vq = ResidualVQ(
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
+ print(quantized.shape, indices.shape, commit_loss.shape)
+ # > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 4]) torch.Size([1, 4])
```
## Increasing codebook usage
@@ -144,6 +144,8 @@ vq = VectorQuantize(
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
+ print(quantized.shape, indices.shape, commit_loss.shape)
+ # > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
```
### Cosine similarity
@@ -162,6 +164,8 @@ vq = VectorQuantize(
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
+ print(quantized.shape, indices.shape, commit_loss.shape)
+ # > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
```
### Expiring stale codes
@@ -180,6 +184,8 @@ vq = VectorQuantize(
x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
+ print(quantized.shape, indices.shape, commit_loss.shape)
+ # > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
```
### Orthogonal regularization loss
@@ -204,6 +210,8 @@ vq = VectorQuantize(
img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)
# loss now contains the orthogonal regularization loss with the weight as assigned
+ print(quantized.shape, indices.shape, loss.shape)
+ # > torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32]) torch.Size([1])
```
### Multi-headed VQ
@@ -226,10 +234,12 @@ vq = VectorQuantize(
)

img_fmap = torch.randn(1, 256, 32, 32)
- quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32, 8), (1,)
+ quantized, indices, loss = vq(img_fmap)
+ print(quantized.shape, indices.shape, loss.shape)
+ # > torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32, 8]) torch.Size([1])

- # indices shape - (batch, height, width, heads)
```
+
### Random Projection Quantizer
<a href =" https://arxiv.org/abs/2202.01855 " >This paper</a > first proposed to use a random projection quantizer for masked speech modeling, where signals are projected with a randomly initialized matrix and then matched with a random initialized codebook. One therefore does not need to learn the quantizer. This technique was used by Google's <a href =" https://ai.googleblog.com/2023/03/universal-speech-model-usm-state-of-art.html " >Universal Speech Model</a > to achieve SOTA for speech-to-text modeling.
@@ -248,7 +258,9 @@ quantizer = RandomProjectionQuantizer(
)

x = torch.randn(1, 1024, 512)
- indices = quantizer(x) # (1, 1024, 16) - (batch, seq, num_codebooks)
+ indices = quantizer(x)
+ print(indices.shape)
+ # > torch.Size([1, 1024, 16])
```
This repository will also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't doing so, please open an issue. You can override whether to synchronize the codebooks by setting `sync_codebook = True | False`.
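As a minimal sketch of that override (assuming `sync_codebook` is accepted as a constructor keyword on `VectorQuantize`; only the flag name itself is stated above):

```python
import torch
from vector_quantize_pytorch import VectorQuantize

# force codebook synchronization off (or on), regardless of the distributed setup
vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    sync_codebook = False   # set True to always synchronize across processes
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```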
@@ -279,10 +291,11 @@ quantizer = FSQ(levels)
x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
xhat, indices = quantizer(x)

- print(xhat.shape) # (1, 1024, 4) - (batch, seq, dim)
- print(indices.shape) # (1, 1024) - (batch, seq)
+ print(xhat.shape)
+ # > torch.Size([1, 1024, 4])
+ print(indices.shape)
+ # > torch.Size([1, 1024])

- assert xhat.shape == x.shape
assert torch.all(xhat == quantizer.indices_to_codes(indices))
```
@@ -305,14 +318,12 @@ x = torch.randn(1, 1024, 256)
residual_fsq.eval()

quantized, indices = residual_fsq(x)
-
- # (1, 1024, 256), (1, 1024, 8), (8)
- # (batch, seq, dim), (batch, seq, quantizers), (quantizers)
+ print(quantized.shape, indices.shape)
+ # > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8])

quantized_out = residual_fsq.get_output_from_indices(indices)
-
- # (8, 1, 1024, 8)
- # (residual layers, batch, seq, quantizers)
+ print(quantized_out.shape)
+ # > torch.Size([1, 1024, 256])

assert torch.all(quantized == quantized_out)
```
@@ -346,26 +357,34 @@ quantizer = LFQ(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature = 100.) # you may want to experiment with temperature
+ print(quantized.shape, indices.shape, entropy_aux_loss.shape)
+ # > torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])

- # (1, 16, 32, 32), (1, 32, 32), (1,)
-
- assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
```
You can also pass in video features as `(batch, feat, time, height, width)` or sequences as `(batch, seq, feat)`.
```python
+ import torch
+ from vector_quantize_pytorch import LFQ
+
+ quantizer = LFQ(
+     codebook_size = 65536,
+     dim = 16,
+     entropy_loss_weight = 0.1,
+     diversity_gamma = 1.
+ )

seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)

- assert seq.shape == quantized.shape
+ # assert seq.shape == quantized.shape

- video_feats = torch.randn(1, 16, 10, 32, 32)
- quantized, *_ = quantizer(video_feats)
+ # video_feats = torch.randn(1, 16, 10, 32, 32)
+ # quantized, *_ = quantizer(video_feats)

- assert video_feats.shape == quantized.shape
+ # assert video_feats.shape == quantized.shape
```
@@ -384,8 +403,8 @@ quantizer = LFQ(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats)
-
- # (1, 16, 32, 32), (1, 32, 32, 4), (1,)
+ print(quantized.shape, indices.shape, entropy_aux_loss.shape)
+ # > torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32, 4]) torch.Size([])

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -408,14 +427,12 @@ x = torch.randn(1, 1024, 256)
residual_lfq.eval()

quantized, indices, commit_loss = residual_lfq(x)
-
- # (1, 1024, 256), (1, 1024, 8), (8)
- # (batch, seq, dim), (batch, seq, quantizers), (quantizers)
+ print(quantized.shape, indices.shape, commit_loss.shape)
+ # > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([8])

quantized_out = residual_lfq.get_output_from_indices(indices)
-
- # (8, 1, 1024, 8)
- # (residual layers, batch, seq, quantizers)
+ print(quantized_out.shape)
+ # > torch.Size([1, 1024, 256])

assert torch.all(quantized == quantized_out)
```
@@ -443,8 +460,8 @@ quantizer = LatentQuantize(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, loss = quantizer(image_feats)
-
- # (1, 16, 32, 32), (1, 32, 32), (1,)
+ print(quantized.shape, indices.shape, loss.shape)
+ # > torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -454,15 +471,25 @@ You can also pass in video features as `(batch, feat, time, height, width)` or s
```python
+ import torch
+ from vector_quantize_pytorch import LatentQuantize
+
+ quantizer = LatentQuantize(
+     levels = [5, 5, 8],
+     dim = 16,
+     commitment_loss_weight = 0.1,
+     quantization_loss_weight = 0.1,
+ )
+

seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)
-
- assert seq.shape == quantized.shape
+ print(quantized.shape)
+ # > torch.Size([1, 32, 16])

video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)
-
- assert video_feats.shape == quantized.shape
+ print(quantized.shape)
+ # > torch.Size([1, 16, 10, 32, 32])
```
@@ -480,6 +507,8 @@ model = LatentQuantize(levels, dim, num_codebooks=num_codebooks)
input_tensor = torch.randn(2, 3, dim)
output_tensor, indices, loss = model(input_tensor)
+ print(output_tensor.shape, indices.shape, loss.shape)
+ # > torch.Size([2, 3, 9]) torch.Size([2, 3, 3]) torch.Size([])

assert output_tensor.shape == input_tensor.shape
assert indices.shape == (2, 3, num_codebooks)