Skip to content

Commit 1447998

Browse files
committed
test train and eval pathways for residual vq
1 parent 3ad24a3 commit 1447998

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

tests/test_readme.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ def test_vq_mask():
6464

6565
@pytest.mark.parametrize('implicit_neural_codebook', (True, False))
6666
@pytest.mark.parametrize('use_cosine_sim', (True, False))
67+
@pytest.mark.parametrize('train', (True, False))
6768
def test_residual_vq(
6869
implicit_neural_codebook,
69-
use_cosine_sim
70+
use_cosine_sim,
71+
train
7072
):
7173
from vector_quantize_pytorch import ResidualVQ
7274

@@ -80,14 +82,9 @@ def test_residual_vq(
8082

8183
x = torch.randn(1, 256, 32)
8284

83-
quantized, indices, commit_loss = residual_vq(x)
84-
quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)
85-
86-
# test eval mode and `get_output_from_indices`
87-
88-
residual_vq.eval()
89-
quantized, indices, commit_loss = residual_vq(x)
85+
residual_vq.train(train)
9086

87+
quantized, indices, commit_loss = residual_vq(x, freeze_codebook = train and not implicit_neural_codebook)
9188
quantized_out = residual_vq.get_output_from_indices(indices)
9289
assert torch.allclose(quantized, quantized_out, atol = 1e-6)
9390

0 commit comments

Comments
 (0)