
Commit e80dbb3

Allow ignoring the temperature parameter by setting it to zero
1 parent 03c9fa7 commit e80dbb3

6 files changed: 30 additions & 25 deletions


README.md

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ with tf.Session() as session:
 
 This should output something like:
 
-The Wtestath t s ien es a a ug dm ooi el e a s i n k s a u ta o e
+The ee e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e e
 
 ## Command Line Usage

examples/readme/basic_usage.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
                        neurons_per_layer=100, num_layers=3, num_timesteps=15)
 
     # Train it
-    model.train(session, max_epochs=10, max_steps=500, print_logs=True)
+    model.train(session, max_epochs=10, max_steps=500)
 
     # Let it generate a text
     generated = model.sample(session, "The ", num_steps=100)
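
Pieced together, the README example now reads roughly as follows. This is a sketch: the import, the class name CharLM, and the corpus path are assumptions, since only the constructor keyword arguments and the train / sample calls appear in this diff. Note that sample now decodes greedily by default, because temperature defaults to 0.

import tensorflow as tf
from tensorlm import CharLM  # assumed import; not shown in this diff

with tf.Session() as session:
    # Class name and corpus path are assumptions; the kwargs are from the diff
    model = CharLM(session, "datasets/sherlock/train.txt",
                   neurons_per_layer=100, num_layers=3, num_timesteps=15)

    # Train it
    model.train(session, max_epochs=10, max_steps=500)

    # Let it generate a text (temperature=0 by default, i.e. greedy decoding)
    generated = model.sample(session, "The ", num_steps=100)
    print(generated)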

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 numpy==1.13.1
-tensorflow==1.1.0
+# tensorflow==1.1.0
 nltk==3.2.4
 python-dateutil==2.6.1

setup.py

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@
     ],
     install_requires=[
         "numpy==1.13.1",
-        "tensorflow==1.1.0",
         "nltk==3.2.4",
         "python-dateutil==2.6.1",
     ],
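
Dropping the TensorFlow pin from both requirements.txt and install_requires means pip no longer installs it automatically, presumably so users can pick tensorflow or tensorflow-gpu themselves. A hypothetical guard, not part of this commit, that would surface the now-manual dependency early:

try:
    import tensorflow as tf
except ImportError:
    # Hypothetical message; the commit itself adds no such check
    raise ImportError("tensorlm requires TensorFlow. Install the tensorflow or "
                      "tensorflow-gpu package yourself; the previously pinned "
                      "version was 1.1.0.")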

tensorlm/model.py

Lines changed: 24 additions & 19 deletions
@@ -206,7 +206,7 @@ def evaluate(self, session, dataset):
         self.will_resume_training(session)
         return total_loss / step_count
 
-    def sample_ids(self, session, prime_ids, num_steps=100, temperature=0.5):
+    def sample_ids(self, session, prime_ids, num_steps=100, temperature=0):
         """Let the model generate a sequence based on a preceding string.
 
         This method primes the model with the given sequence of token ids. Then, it feeds the model
@@ -219,7 +219,8 @@ def sample_ids(self, session, prime_ids, num_steps=100, temperature=0.5):
             num_steps (int): The number of tokens generated by the model.
             temperature (float): Degree of randomness during sampling. The logits returned by the
                 model will be divided by the temperature value before calculating the softmax.
-                The temperature will be clipped to 0.01 if it is below this bound.
+                If the temperature is below 0.01, the model will choose the token with the highest
+                predicted probability at each step instead of sampling from the distribution.
 
         Returns:
             list[int]: The generated sequence ids.
@@ -249,13 +250,13 @@ def sample_ids(self, session, prime_ids, num_steps=100, temperature=0.5):
         self.will_resume_training(session)
         return outputs
 
-    def sample_text(self, session, vocabulary, prime, num_steps=100, temperature=0.5):
+    def sample_text(self, session, vocabulary, prime, num_steps=100, temperature=0):
         """Let the model generate a sequence based on a preceding string.
 
         This method tokenizes the prime string and feeds the tokens to the model. Then, it feeds the
-        model its own output (disgusting, I know) token by token and thus lets it generate /
-        complete the text. For char level, this will result in 100 generated characters, for word
-        level 100 generated tokens (words / punctuation / whitespace).
+        model its own output token by token and thus lets it generate / complete the text. For char
+        level, this will result in 100 generated characters, for word level 100 generated tokens
+        (words / punctuation / whitespace).
 
         Args:
             session (tf.Session): The TF session to run the operations in.
@@ -265,7 +266,8 @@ def sample_text(self, session, vocabulary, prime, num_steps=100, temperature=0.5):
             num_steps (int): The number of tokens generated by the model.
             temperature (float): Degree of randomness during sampling. The logits returned by the
                 model will be divided by the temperature value before calculating the softmax.
-                The temperature will be clipped to 0.01 if it is below this bound.
+                If the temperature is below 0.01, the model will choose the token with the highest
+                predicted probability at each step instead of sampling from the distribution.
 
         Returns:
             str: The generated text.
@@ -428,7 +430,7 @@ def will_resume_training(self, session):
         # Re-enable dropout and return to the previous training state
         session.run([self._output_keep_var.assign(self.output_keep_prob), self._unfreeze_state_op])
 
-    def _sample_step(self, session, inputs, update_state=True, temperature=0.5):
+    def _sample_step(self, session, inputs, update_state=True, temperature=0):
         """Feeds batch inputs to the model and returns the batch output ids.
 
         Args:
@@ -443,7 +445,8 @@ def _sample_step(self, session, inputs, update_state=True, temperature=0.5):
                 be frozen before and unfrozen after this function call.
             temperature (float): Degree of randomness during sampling. The logits returned by the
                 model will be divided by the temperature value before calculating the softmax.
-                The temperature will be clipped to 0.01 if it is below this bound.
+                If the temperature is below 0.01, the model will choose the token with the highest
+                predicted probability at each step instead of sampling from the distribution.
 
         Returns:
             np.ndarray: A batch of outputs with the same shape and data type as the inputs
@@ -455,14 +458,16 @@ def _sample_step(self, session, inputs, update_state=True, temperature=0.5):
 
         # Get the output
         logits, _ = session.run(runs, feed_dict=feed_dict)
-        temperature = max(temperature, 0.01)
-
-        # Sample from the output using the probability distribution in logits
-        ids = range(logits.shape[2])
-        result = np.zeros(logits.shape[0:2], dtype=np.uint8)
-        for batch in range(logits.shape[0]):
-            for step in range(logits.shape[1]):
-                probs = np.exp(logits[batch, step] / temperature)
-                probs /= np.sum(probs)
-                result[batch, step] = np.random.choice(ids, p=probs)
+
+        if temperature < 0.01:
+            result = np.argmax(logits, axis=2)
+        else:
+            result = np.zeros(logits.shape[0:2], dtype=np.uint8)
+            # Sample from the output using the probability distribution in logits
+            ids = range(logits.shape[2])
+            for batch in range(logits.shape[0]):
+                for step in range(logits.shape[1]):
+                    probs = np.exp(logits[batch, step] / temperature)
+                    probs /= np.sum(probs)
+                    result[batch, step] = np.random.choice(ids, p=probs)
         return result
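
This branch is the heart of the commit. Dividing logits by a temperature approaching zero sharpens the softmax toward a one-hot distribution, so for anything below 0.01 the code now takes an exact argmax instead of clipping the temperature, which is both deterministic and safe from overflow in np.exp. A self-contained NumPy sketch of the same logic; the function name and demo values are illustrative, not part of the library:

import numpy as np

def sample_output_ids(logits, temperature=0.0, rng=np.random):
    # logits: array of shape (batch_size, num_timesteps, vocab_size)
    if temperature < 0.01:
        # Greedy decoding: the most probable token id at every step
        return np.argmax(logits, axis=2)
    # Stochastic decoding: temperature-scaled softmax, then sample per step
    result = np.zeros(logits.shape[0:2], dtype=np.uint8)
    ids = range(logits.shape[2])
    for batch in range(logits.shape[0]):
        for step in range(logits.shape[1]):
            probs = np.exp(logits[batch, step] / temperature)
            probs /= np.sum(probs)
            result[batch, step] = rng.choice(ids, p=probs)
    return result

demo_logits = np.random.randn(1, 3, 5)
print(sample_output_ids(demo_logits))                   # deterministic argmax ids
print(sample_output_ids(demo_logits, temperature=0.5))  # random, biased toward high logits

This also explains the README change above: greedy decoding of a briefly trained character model keeps re-emitting its most probable characters, hence the repetitive sample output.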

tensorlm/wrappers.py

Lines changed: 3 additions & 2 deletions
@@ -174,7 +174,7 @@ def evaluate(self, tf_session, text_path):
         loss = self.tf_model.evaluate(tf_session, dataset)
         return loss
 
-    def sample(self, tf_session, prime, num_steps, temperature=0.5):
+    def sample(self, tf_session, prime, num_steps, temperature=0):
         """Let the model generate text after being primed with some text.
 
         Args:
@@ -185,7 +185,8 @@ def sample(self, tf_session, prime, num_steps, temperature=0.5):
                 in num_steps words / numbers / punctuation marks / whitespace characters
             temperature (float): Degree of randomness during sampling. The logits returned by the
                 model will be divided by the temperature value before calculating the softmax.
-                The temperature will be clipped to 0.01 if it is below this bound.
+                If the temperature is below 0.01, the model will choose the token with the highest
+                predicted probability at each step instead of sampling from the distribution.
 
         Returns:
             str: The generated sequence.
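
In practice, the wrapper's sample call from the README example now decodes greedily unless a temperature is passed explicitly. Assuming model is a trained wrapper instance and session an open tf.Session:

# New default (temperature=0): deterministic, always the most likely token
best_guess = model.sample(session, "The ", num_steps=100)

# Passing the previous default restores stochastic sampling
varied = model.sample(session, "The ", num_steps=100, temperature=0.5)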
