@@ -61,7 +61,8 @@ def __init__(
         straight_through_activation = nn.Identity(),
         num_codebooks = 1,
         keep_num_codebooks_dim = None,
-        codebook_scale = 1.             # for residual LFQ, codebook scaled down by 2x at each layer
+        codebook_scale = 1.,            # for residual LFQ, codebook scaled down by 2x at each layer
+        frac_per_sample_entropy = 1.    # make less than 1. to only use a random fraction of the probs for per sample entropy
     ):
         super().__init__()
@@ -95,6 +96,9 @@ def __init__(
 
         # entropy aux loss related weights
 
+        assert 0 < frac_per_sample_entropy <= 1.
+        self.frac_per_sample_entropy = frac_per_sample_entropy
+
         self.diversity_gamma = diversity_gamma
         self.entropy_loss_weight = entropy_loss_weight
 
@@ -219,8 +223,22 @@ def forward(
 
         if exists(mask):
             prob = prob[mask]
+        else:
+            prob = rearrange(prob, 'b n ... -> (b n) ...')
+
+        # whether to only use a fraction of probs, for reducing memory
+
+        if self.frac_per_sample_entropy < 1.:
+            num_tokens = prob.shape[0]
+            num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
+            rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
+            per_sample_probs = prob[rand_mask]
+        else:
+            per_sample_probs = prob
+
+        # calculate per sample entropy
 
-        per_sample_entropy = entropy(prob).mean()
+        per_sample_entropy = entropy(per_sample_probs).mean()
 
         # distribution over all available tokens in the batch
 
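For context on what the new flag buys: the per sample entropy term normally runs over every token's probability distribution, so on long sequences it dominates memory. With frac_per_sample_entropy < 1., only a random subset of rows is kept, selected by an argsort-based boolean mask. Below is a minimal standalone sketch of just that subsampling step, assuming a toy entropy helper and made-up tensor shapes; it is not the module itself.

import torch
import torch.nn.functional as F

def entropy(prob, eps = 1e-20):
    # toy stand-in for the repo's entropy helper: -sum(p * log p) over the last dim
    return -(prob * prob.clamp(min = eps).log()).sum(dim = -1)

# pretend probabilities over a codebook, already flattened to ((batch * seq), codebook_size)
prob = F.softmax(torch.randn(8 * 32, 16), dim = -1)

frac_per_sample_entropy = 0.25  # assumed setting; only this fraction of rows is used

if frac_per_sample_entropy < 1.:
    num_tokens = prob.shape[0]
    num_sampled_tokens = int(num_tokens * frac_per_sample_entropy)

    # argsort of gaussian noise is a random permutation of ranks, so this boolean mask
    # picks exactly num_sampled_tokens rows uniformly at random
    rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
    per_sample_probs = prob[rand_mask]
else:
    per_sample_probs = prob

per_sample_entropy = entropy(per_sample_probs).mean()

The argsort-of-noise trick selects the subset without materializing an index permutation of the probabilities, which keeps the memory saving intact.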