We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a72d217 · commit 998f744 — Copy full SHA for 998f744
vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -891,6 +891,8 @@ def __init__(
891
892
self.eps = eps
893
894
+ self.freeze_codebook = freeze_codebook
895
+
896
self.has_commitment_loss = commitment_weight > 0.
897
self.commitment_weight = commitment_weight
898
self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss
@@ -1048,6 +1050,9 @@ def forward(
1048
1050
):
1049
1051
orig_input, input_requires_grad = x, x.requires_grad
1052
1053
+ if self.freeze_codebook:
1054
+ assert freeze_codebook
1055
1056
# handle masking, either passed in as `mask` or `lens`
1057
1058
assert not (exists(mask) and exists(lens))
0 commit comments