Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2efd8aa

Browse files
Seppo Enarvicopybara-github
authored andcommitted
Merge of PR #1726
PiperOrigin-RevId: 281806525
1 parent a4a0cf0 commit 2efd8aa

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

tensor2tensor/models/research/universal_transformer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,6 @@ def universal_transformer_base():
458458
@registry.register_hparams
459459
def universal_transformer_base_tpu():
460460
hparams = universal_transformer_base()
461-
hparams = update_hparams_for_universal_transformer(hparams)
462461
transformer.update_hparams_for_tpu(hparams)
463462
hparams.add_step_timing_signal = False
464463
return hparams
@@ -467,7 +466,6 @@ def universal_transformer_base_tpu():
467466
@registry.register_hparams
468467
def universal_transformer_big():
469468
hparams = universal_transformer_base()
470-
hparams = update_hparams_for_universal_transformer(hparams)
471469
hparams.hidden_size = 2048
472470
hparams.filter_size = 8192
473471
return hparams

tensor2tensor/models/transformer.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,9 +1170,6 @@ def fast_decode(encoder_output,
11701170
"scores": decoding log probs from the beam search,
11711171
None if using greedy decoding (beam_size=1)
11721172
}
1173-
1174-
Raises:
1175-
NotImplementedError: If beam size > 1 with partial targets.
11761173
"""
11771174
if encoder_output is not None:
11781175
batch_size = common_layers.shape_list(encoder_output)[0]

tensor2tensor/utils/decoding.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,13 @@ def _interactive_input_tensor_to_features_dict(feature_map, hparams):
927927
features["decode_length"] = (
928928
IMAGE_DECODE_LENGTH if input_is_image else inputs[1])
929929
features["inputs"] = x
930+
# Save inputs to "partial_targets" when prepending inputs to targets. Also
931+
# keep "inputs" as some models crash if they don't exist.
932+
if getattr(hparams, "prepend_mode", "none") != "none":
933+
shape = tf.shape(x)
934+
partial_targets = tf.reshape(x, [shape[0], shape[1]])
935+
partial_targets = tf.pad(partial_targets, [[0, 0], [0, 1]])
936+
features["partial_targets"] = partial_targets
930937
return features
931938

932939

@@ -957,6 +964,13 @@ def _decode_input_tensor_to_features_dict(feature_map, hparams):
957964
features["decode_length"] = (
958965
IMAGE_DECODE_LENGTH if input_is_image else tf.shape(x)[1] + 50)
959966
features["inputs"] = x
967+
# Save inputs to "partial_targets" when prepending inputs to targets. Also
968+
# keep "inputs" as some models crash if they don't exist.
969+
if getattr(hparams, "prepend_mode", "none") != "none":
970+
shape = tf.shape(x)
971+
partial_targets = tf.reshape(x, [shape[0], shape[1]])
972+
partial_targets = tf.pad(partial_targets, [[0, 0], [0, 1]])
973+
features["partial_targets"] = partial_targets
960974
return features
961975

962976

0 commit comments

Comments
 (0)