Skip to content

Commit 48be859

Browse files
committed
address #209
1 parent a72d217 commit 48be859

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,9 @@ def __init__(
974974
# for variable lengthed sequences, whether to take care of masking out the padding to 0 (or return the original input)
975975
self.return_zeros_for_masked_padding = return_zeros_for_masked_padding
976976

977+
# whether to freeze the codebook, can be overridden on forward
978+
self.freeze_codebook = freeze_codebook
979+
977980
@property
978981
def codebook(self):
979982
codebook = self._codebook.embed
@@ -1042,12 +1045,16 @@ def forward(
10421045
mask = None,
10431046
lens = None,
10441047
sample_codebook_temp = None,
1045-
freeze_codebook = False,
1048+
freeze_codebook = None,
10461049
return_loss_breakdown = False,
10471050
codebook_transform_fn: Callable | None = None
10481051
):
10491052
orig_input, input_requires_grad = x, x.requires_grad
10501053

1054+
# freezing codebook
1055+
1056+
freeze_codebook = default(freeze_codebook, self.freeze_codebook)
1057+
10511058
# handle masking, either passed in as `mask` or `lens`
10521059

10531060
assert not (exists(mask) and exists(lens))

0 commit comments

Comments
 (0)