@@ -65,7 +65,8 @@ def __init__(
         codebook_scale = 1.,                         # for residual LFQ, codebook scaled down by 2x at each layer
         frac_per_sample_entropy = 1.,                # make less than 1. to only use a random fraction of the probs for per sample entropy
         use_code_agnostic_commit_loss = False,
-        projection_has_bias = True
+        projection_has_bias = True,
+        soft_clamp_input_value = None
     ):
         super().__init__()
 
@@ -114,6 +115,11 @@ def __init__(
         self.commitment_loss_weight = commitment_loss_weight
         self.use_code_agnostic_commit_loss = use_code_agnostic_commit_loss
 
+        # whether to soft clamp the input value from -value to value
+
+        self.soft_clamp_input_value = soft_clamp_input_value
+        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= 1.
+
         # for no auxiliary loss, during inference
 
         self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
@@ -195,6 +201,12 @@ def forward(
 
         x = self.project_in(x)
 
+        # maybe soft clamp
+
+        if exists(self.soft_clamp_input_value):
+            clamp_value = self.soft_clamp_input_value
+            x = (x / clamp_value).tanh() * clamp_value
+
         # split out number of codebooks
 
         x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
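Not part of the change itself, just a minimal standalone sketch of the scaled-tanh clamp that `forward` now applies when `soft_clamp_input_value` is set (the `soft_clamp` helper name and example values below are hypothetical, chosen only for illustration):

```python
import torch

def soft_clamp(x, clamp_value):
    # scaled tanh: near-identity for |x| much smaller than clamp_value,
    # smoothly saturating toward +/- clamp_value for large |x|
    return (x / clamp_value).tanh() * clamp_value

x = torch.tensor([-100., -1., 0., 1., 100.])
print(soft_clamp(x, 5.))  # approx tensor([-5.0000, -0.9869, 0.0000, 0.9869, 5.0000])
```

The assert in `__init__` requires the clamp value to be at least 1, presumably so the unit code values stay reachable after the projected input is bounded.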