
Commit bec307d

use threshold of ema cluster sizes to determine which codes to replace, as in soundstream paper, thanks to @wesbz
1 parent 369240f commit bec307d

2 files changed (+3 -3 lines changed)


README.md

Lines changed: 2 additions & 2 deletions
@@ -110,7 +110,7 @@ quantized, indices, commit_loss = vq(x)
 
 ### Expiring stale codes
 
-Finally, the SoundStream paper has a scheme where they replace codes that have not been used in a certain number of consecutive batches with a randomly selected vector from the current batch. You can set this threshold for consecutive misses before replacement with `max_codebook_misses_before_expiry` keyword. (I know it is a bit long, but I couldn't think of a better name)
+Finally, the SoundStream paper has a scheme where they replace codes that have hits below a certain threshold with randomly selected vector from the current batch. You can set this threshold with `threshold_ema_dead_code` keyword.
 
 ```python
 import torch
@@ -119,7 +119,7 @@ from vector_quantize_pytorch import VectorQuantize
 vq = VectorQuantize(
     dim = 256,
     codebook_size = 512,
-    max_codebook_misses_before_expiry = 5 # should actively replace any codes that were missed 5 times in a row during training
+    threshold_ema_dead_code = 2 # should actively replace any codes that have an exponential moving average cluster size less than 2
 )
 
 x = torch.randn(1, 1024, 256)
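
The README text changed above describes the new rule: a code is treated as dead when the exponential moving average (EMA) of its cluster size falls below a threshold, and it is then replaced with a randomly chosen vector from the current batch. The following is a minimal, framework-free sketch of that idea, not the library's actual implementation. The function names `update_ema_cluster_size` and `expire_dead_codes`, the `decay` value of 0.8, and the toy data are all assumptions made for illustration.

```python
import random


def update_ema_cluster_size(ema_cluster_size, hits, decay=0.8):
    """Blend the per-code hit counts of one batch into the running EMA.

    `decay` is a hypothetical value chosen for this sketch.
    """
    return [decay * old + (1 - decay) * new
            for old, new in zip(ema_cluster_size, hits)]


def expire_dead_codes(codebook, ema_cluster_size, batch, threshold, rng=random):
    """Replace every code whose EMA cluster size is below `threshold`
    with a randomly selected vector from `batch`.

    Mutates `codebook` in place and returns the replaced indices.
    """
    expired = [i for i, size in enumerate(ema_cluster_size) if size < threshold]
    for i in expired:
        codebook[i] = list(rng.choice(batch))
    return expired


# Toy setup: three 2-d codes, the third of which is never selected.
codebook = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
ema = [4.0, 4.0, 4.0]
hits_per_batch = [6, 10, 0]
batch = [[2.0, 3.0], [4.0, 1.0]]

# After five batches the unused code's EMA has decayed below 2,
# while the two used codes stay well above it.
for _ in range(5):
    ema = update_ema_cluster_size(ema, hits_per_batch)

expired = expire_dead_codes(codebook, ema, batch, threshold=2)
print(expired)  # [2]
```

Compared with counting consecutive misses, a threshold on the EMA reacts to codes that are used rarely as well as codes that are never used, since a few occasional hits may not keep the average above the threshold.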

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.3.3',
+  version = '0.3.4',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',

0 commit comments
