Skip to content

Commit 998f744

Browse files
authored
assert freeze_codebook if self.freeze_codebook in init
1 parent a72d217 commit 998f744

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,8 @@ def __init__(
891891

892892
self.eps = eps
893893

894+
self.freeze_codebook = freeze_codebook
895+
894896
self.has_commitment_loss = commitment_weight > 0.
895897
self.commitment_weight = commitment_weight
896898
self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss
@@ -1048,6 +1050,9 @@ def forward(
10481050
):
10491051
orig_input, input_requires_grad = x, x.requires_grad
10501052

1053+
if self.freeze_codebook:
1054+
assert freeze_codebook
1055+
10511056
# handle masking, either passed in as `mask` or `lens`
10521057

10531058
assert not (exists(mask) and exists(lens))

0 commit comments

Comments
 (0)