@@ -206,7 +206,7 @@ def evaluate(self, session, dataset):
         self.will_resume_training(session)
         return total_loss / step_count

-    def sample_ids(self, session, prime_ids, num_steps=100, temperature=0.5):
+    def sample_ids(self, session, prime_ids, num_steps=100, temperature=0):
         """Let the model generate a sequence based on a preceding string.

         This method primes the model with the given sequence of token ids. Then, it feeds the model
@@ -219,7 +219,8 @@ def sample_ids(self, session, prime_ids, num_steps=100, temperature=0.5):
             num_steps (int): The number of tokens generated by the model.
             temperature (float): Degree of randomness during sampling. The logits returned by the
                 model will be divided by the temperature value before calculating the softmax.
-                The temperature will be clipped to 0.01 if it is below this bound.
+                If the temperature is below 0.01, the model will choose the token with the highest
+                predicted probability at each step instead of sampling from the distribution.

         Returns:
             list[int]: The generated sequence ids.
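For readers skimming the diff, here is a minimal standalone sketch of the sampling rule the new docstring describes. It is not part of the commit; the toy logits and the 4-token vocabulary are made up for illustration.

import numpy as np

def pick_token(logits, temperature):
    """Return one token id: greedy below the 0.01 threshold, sampled otherwise."""
    if temperature < 0.01:
        return int(np.argmax(logits))        # greedy: highest predicted probability
    probs = np.exp(logits / temperature)     # temperature-scaled softmax
    probs /= np.sum(probs)
    return int(np.random.choice(len(logits), p=probs))

logits = np.array([2.0, 1.0, 0.5, 0.1])      # fake model output over 4 tokens
print(pick_token(logits, temperature=0))     # always 0 (argmax)
print(pick_token(logits, temperature=0.5))   # usually 0, occasionally others
print(pick_token(logits, temperature=5.0))   # close to uniform

Dividing the logits by a small temperature sharpens the softmax toward the argmax anyway, which is why the 0.01 cutoff can simply switch to greedy decoding.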
@@ -249,13 +250,13 @@ def sample_ids(self, session, prime_ids, num_steps=100, temperature=0.5):
         self.will_resume_training(session)
         return outputs

-    def sample_text(self, session, vocabulary, prime, num_steps=100, temperature=0.5):
+    def sample_text(self, session, vocabulary, prime, num_steps=100, temperature=0):
         """Let the model generate a sequence based on a preceding string.

         This method tokenizes the prime string and feeds the tokens to the model. Then, it feeds the
-        model its own output (disgusting, I know) token by token and thus lets it generate /
-        complete the text. For char level, this will result in 100 generated characters, for word
-        level 100 generated tokens (words / punctuation / whitespace).
+        model its own output token by token and thus lets it generate / complete the text. For char
+        level, this will result in 100 generated characters, for word level 100 generated tokens
+        (words / punctuation / whitespace).

         Args:
             session (tf.Session): The TF session to run the operations in.
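The prime-then-feed-back loop that sample_text describes has the shape sketched below. This is an illustrative standalone version with a stub step function standing in for the TF model; all names are hypothetical, not the project's actual API.

import numpy as np

def step(token_id, state):
    """Stub model step: deterministic fake logits over a 4-token vocabulary."""
    rng = np.random.default_rng(token_id + 7 * state)
    return rng.normal(size=4), state + 1

def generate(prime_ids, num_steps, temperature=0.0):
    state = 0
    for t in prime_ids:                      # prime: feed the given ids,
        logits, state = step(t, state)       # keeping only the last output
    out = []
    for _ in range(num_steps):
        if temperature < 0.01:
            next_id = int(np.argmax(logits))
        else:
            probs = np.exp(logits / temperature)
            probs /= probs.sum()
            next_id = int(np.random.default_rng().choice(4, p=probs))
        out.append(next_id)
        logits, state = step(next_id, state) # feed the model its own output
    return out

print(generate([0, 1, 2], num_steps=5))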
@@ -265,7 +266,8 @@ def sample_text(self, session, vocabulary, prime, num_steps=100, temperature=0.5):
             num_steps (int): The number of tokens generated by the model.
             temperature (float): Degree of randomness during sampling. The logits returned by the
                 model will be divided by the temperature value before calculating the softmax.
-                The temperature will be clipped to 0.01 if it is below this bound.
+                If the temperature is below 0.01, the model will choose the token with the highest
+                predicted probability at each step instead of sampling from the distribution.

         Returns:
             str: The generated text.
@@ -428,7 +430,7 @@ def will_resume_training(self, session):
         # Re-enable dropout and return to the previous training state
         session.run([self._output_keep_var.assign(self.output_keep_prob), self._unfreeze_state_op])

-    def _sample_step(self, session, inputs, update_state=True, temperature=0.5):
+    def _sample_step(self, session, inputs, update_state=True, temperature=0):
         """Feeds batch inputs to the model and returns the batch output ids.

         Args:
@@ -443,7 +445,8 @@ def _sample_step(self, session, inputs, update_state=True, temperature=0.5):
                 be frozen before and unfrozen after this function call.
             temperature (float): Degree of randomness during sampling. The logits returned by the
                 model will be divided by the temperature value before calculating the softmax.
-                The temperature will be clipped to 0.01 if it is below this bound.
+                If the temperature is below 0.01, the model will choose the token with the highest
+                predicted probability at each step instead of sampling from the distribution.

         Returns:
             np.ndarray: A batch of outputs with the same shape and data type as the inputs
@@ -455,14 +458,16 @@ def _sample_step(self, session, inputs, update_state=True, temperature=0.5):

         # Get the output
         logits, _ = session.run(runs, feed_dict=feed_dict)
-        temperature = max(temperature, 0.01)
-
-        # Sample from the output using the probability distribution in logits
-        ids = range(logits.shape[2])
-        result = np.zeros(logits.shape[0:2], dtype=np.uint8)
-        for batch in range(logits.shape[0]):
-            for step in range(logits.shape[1]):
-                probs = np.exp(logits[batch, step] / temperature)
-                probs /= np.sum(probs)
-                result[batch, step] = np.random.choice(ids, p=probs)
+
+        if temperature < 0.01:
+            result = np.argmax(logits, axis=2)
+        else:
+            result = np.zeros(logits.shape[0:2], dtype=np.uint8)
+            # Sample from the output using the probability distribution in logits
+            ids = range(logits.shape[2])
+            for batch in range(logits.shape[0]):
+                for step in range(logits.shape[1]):
+                    probs = np.exp(logits[batch, step] / temperature)
+                    probs /= np.sum(probs)
+                    result[batch, step] = np.random.choice(ids, p=probs)
         return result
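One hedged aside on the new sampling branch, not part of the commit: np.exp on raw logits can overflow for large values, the uint8 result buffer would wrap around for vocabularies larger than 256 tokens, and the argmax branch returns int64 rather than uint8. A numerically stable, dtype-consistent variant might look like this sketch, assuming logits of shape [batch, steps, vocab].

import numpy as np

def sample_from_logits(logits, temperature):
    """logits: [batch, steps, vocab]; returns int64 ids of shape [batch, steps]."""
    if temperature < 0.01:
        return np.argmax(logits, axis=2)
    scaled = logits / temperature
    scaled -= scaled.max(axis=2, keepdims=True)   # stabilize the softmax
    probs = np.exp(scaled)
    probs /= probs.sum(axis=2, keepdims=True)
    result = np.zeros(logits.shape[:2], dtype=np.int64)
    for b in range(logits.shape[0]):
        for s in range(logits.shape[1]):
            result[b, s] = np.random.choice(logits.shape[2], p=probs[b, s])
    return result

Subtracting the per-step maximum before exponentiating leaves the softmax result unchanged while keeping np.exp in a safe range.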