File tree Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -974,6 +974,9 @@ def __init__(
974974 # for variable lengthed sequences, whether to take care of masking out the padding to 0 (or return the original input)
975975 self .return_zeros_for_masked_padding = return_zeros_for_masked_padding
976976
977+ # whether to freeze the codebook, can be overridden on forward
978+ self .freeze_codebook = freeze_codebook
979+
977980 @property
978981 def codebook (self ):
979982 codebook = self ._codebook .embed
@@ -1042,12 +1045,16 @@ def forward(
10421045 mask = None ,
10431046 lens = None ,
10441047 sample_codebook_temp = None ,
1045- freeze_codebook = False ,
1048+ freeze_codebook = None ,
10461049 return_loss_breakdown = False ,
10471050 codebook_transform_fn : Callable | None = None
10481051 ):
10491052 orig_input , input_requires_grad = x , x .requires_grad
10501053
1054+ # freezing codebook
1055+
1056+ freeze_codebook = default (freeze_codebook , self .freeze_codebook )
1057+
10511058 # handle masking, either passed in as `mask` or `lens`
10521059
10531060 assert not (exists (mask ) and exists (lens ))
You can’t perform that action at this time.
0 commit comments