@@ -64,9 +64,11 @@ def test_vq_mask():
 
 @pytest.mark.parametrize('implicit_neural_codebook', (True, False))
 @pytest.mark.parametrize('use_cosine_sim', (True, False))
+@pytest.mark.parametrize('train', (True, False))
 def test_residual_vq(
     implicit_neural_codebook,
-    use_cosine_sim
+    use_cosine_sim,
+    train
 ):
     from vector_quantize_pytorch import ResidualVQ
 
@@ -80,14 +82,9 @@ def test_residual_vq(
 
     x = torch.randn(1, 256, 32)
 
-    quantized, indices, commit_loss = residual_vq(x)
-    quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)
-
-    # test eval mode and `get_output_from_indices`
-
-    residual_vq.eval()
-    quantized, indices, commit_loss = residual_vq(x)
+    residual_vq.train(train)
 
+    quantized, indices, commit_loss = residual_vq(x, freeze_codebook = train and not implicit_neural_codebook)
     quantized_out = residual_vq.get_output_from_indices(indices)
     assert torch.allclose(quantized, quantized_out, atol = 1e-6)
 
0 commit comments