diff --git a/vector_quantize_pytorch/vector_quantize_pytorch.py b/vector_quantize_pytorch/vector_quantize_pytorch.py index 086b026..64b1051 100644 --- a/vector_quantize_pytorch/vector_quantize_pytorch.py +++ b/vector_quantize_pytorch/vector_quantize_pytorch.py @@ -891,6 +891,8 @@ def __init__( self.eps = eps + self.freeze_codebook = freeze_codebook + self.has_commitment_loss = commitment_weight > 0. self.commitment_weight = commitment_weight self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss @@ -1048,6 +1050,9 @@ def forward( ): orig_input, input_requires_grad = x, x.requires_grad + if self.freeze_codebook: + assert freeze_codebook + # handle masking, either passed in as `mask` or `lens` assert not (exists(mask) and exists(lens))