From 998f744b8791250227a4104fe87853e2d7e1614d Mon Sep 17 00:00:00 2001 From: rqi3 <54645143+rqi3@users.noreply.github.com> Date: Fri, 11 Apr 2025 16:51:33 -0400 Subject: [PATCH] assert freeze_codebook if self.freeze_codebook in init --- vector_quantize_pytorch/vector_quantize_pytorch.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vector_quantize_pytorch/vector_quantize_pytorch.py b/vector_quantize_pytorch/vector_quantize_pytorch.py index 086b026..64b1051 100644 --- a/vector_quantize_pytorch/vector_quantize_pytorch.py +++ b/vector_quantize_pytorch/vector_quantize_pytorch.py @@ -891,6 +891,8 @@ def __init__( self.eps = eps + self.freeze_codebook = freeze_codebook + self.has_commitment_loss = commitment_weight > 0. self.commitment_weight = commitment_weight self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss @@ -1048,6 +1050,9 @@ def forward( ): orig_input, input_requires_grad = x, x.requires_grad + if self.freeze_codebook: + assert freeze_codebook + # handle masking, either passed in as `mask` or `lens` assert not (exists(mask) and exists(lens))