Skip to content

Commit 4c514db

Browse files
committed
Irie et al. notices that the original Oord implementation of VQ sets cluster sizes of 0 initially, leading to worse convergence. not an issue if kmeans init is turned on
1 parent cb3dd32 commit 4c514db

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,3 +679,12 @@ assert loss.item() >= 0
679679
primaryClass = {cs.LG}
680680
}
681681
```
682+
683+
```bibtex
684+
@inproceedings{Irie2023SelfOrganisingND,
685+
title = {Self-Organising Neural Discrete Representation Learning \`a la Kohonen},
686+
author = {Kazuki Irie and R'obert Csord'as and J{\"u}rgen Schmidhuber},
687+
year = {2023},
688+
url = {https://api.semanticscholar.org/CorpusID:256901024}
689+
}
690+
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.15.2"
3+
version = "1.15.3"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def __init__(
312312
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
313313

314314
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
315-
self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
315+
self.register_buffer('cluster_size', torch.ones(num_codebooks, codebook_size))
316316
self.register_buffer('embed_avg', embed.clone())
317317

318318
self.learnable_codebook = learnable_codebook
@@ -582,7 +582,7 @@ def __init__(
582582
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
583583

584584
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
585-
self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
585+
self.register_buffer('cluster_size', torch.ones(num_codebooks, codebook_size))
586586
self.register_buffer('embed_avg', embed.clone())
587587

588588
self.learnable_codebook = learnable_codebook

0 commit comments

Comments
 (0)