
Commit a4a0cf0

gabegrand authored and afrozenator committed
Merge of PR #1748
PiperOrigin-RevId: 281802220
1 parent 1f7cbd1 commit a4a0cf0

File tree

tensor2tensor/data_generators/text_problems.py
tensor2tensor/data_generators/text_problems_test.py
tensor2tensor/models/transformer.py
tensor2tensor/utils/t2t_model.py

4 files changed: +135 −6 lines changed

tensor2tensor/data_generators/text_problems.py

Lines changed: 107 additions & 0 deletions
@@ -17,6 +17,7 @@
 * Text2TextProblem: input=text, target=text.
 * Text2ClassProblem: input=text, target=class.
+* Text2RealProblem: input=text, target=float.
 * Text2SelfProblem (for language modeling): target=text
 * QuestionAndContext2TextProblem: input=text, context=text, target=text.
@@ -605,6 +606,96 @@ def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
       yield {"inputs": inputs, "targets": [label]}


+class Text2RealProblem(Text2TextProblem):
+  """Base class for text regression problems with one or more tasks.
+
+  Suitable for text-based problems where targets are continuous, real values.
+  When ntasks = 1, each text example is mapped to a single scalar value. When
+  ntasks > 1, each text example is mapped to a 1-d vector of length ntasks.
+  """
+
+  @property
+  def ntasks(self):
+    """Set to n > 1 for multitask regression."""
+    return 1
+
+  def generate_samples(self, data_dir, tmp_dir, dataset_split):
+    """Generate samples of text and real-valued target pairs.
+
+    Each yielded dict will be a single example. The inputs should be raw text.
+    The target should be a list containing ntasks floats.
+
+    Args:
+      data_dir: final data directory. Typically only used in this method to
+        copy over user-supplied vocab files (for example, if vocab_type ==
+        VocabType.TOKEN).
+      tmp_dir: temporary directory that you can use for downloading and
+        scratch.
+      dataset_split: problem.DatasetSplit, which data split to generate
+        samples for (for example, training and evaluation).
+
+    Yields:
+      {"inputs": text, "targets": [x1, x2, ..., xN]} where N is ntasks.
+    """
+    raise NotImplementedError()
+
+  def generate_text_for_vocab(self, data_dir, tmp_dir):
+    for i, sample in enumerate(
+        self.generate_samples(data_dir, tmp_dir, problem.DatasetSplit.TRAIN)):
+      yield sample["inputs"]
+      if self.max_samples_for_vocab and (i + 1) >= self.max_samples_for_vocab:
+        break
+
+  def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
+    generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
+    encoder = self.get_or_create_vocab(data_dir, tmp_dir)
+    for sample in generator:
+      inputs = encoder.encode(sample["inputs"])
+      inputs.append(text_encoder.EOS_ID)
+      yield {"inputs": inputs, "targets": sample["targets"]}
+
+  def feature_encoders(self, data_dir):
+    encoder = self.get_or_create_vocab(data_dir, None, force_get=True)
+    return {
+        "inputs": encoder,
+        "targets": text_encoder.RealEncoder(),
+    }
+
+  def hparams(self, defaults, unused_model_hparams):
+    p = defaults
+    p.modality = {
+        "inputs": modalities.ModalityType.SYMBOL,
+        "targets": modalities.ModalityType.REAL_L2_LOSS,
+    }
+    p.vocab_size = {
+        "inputs": self._encoders["inputs"].vocab_size,
+        "targets": self.ntasks,
+    }
+    p.target_space_id = problem.SpaceID.REAL
+    p.add_hparam("regression_targets", True)
+
+  def max_length(self, model_hparams):
+    return model_hparams.batch_size * self.ntasks
+
+  def preprocess_example(self, example, unused_mode, unused_hparams):
+    example = problem.preprocess_example_common(example, unused_mode,
+                                                unused_hparams)
+    example["targets"] = tf.reshape(example["targets"], [1, 1, self.ntasks])
+    return example
+
+  def example_reading_spec(self):
+    data_fields = {
+        "inputs": tf.VarLenFeature(tf.int64),
+        "targets": tf.FixedLenFeature([self.ntasks], tf.float32),
+    }
+    data_items_to_decoders = None
+    return (data_fields, data_items_to_decoders)
+
+  def eval_metrics(self):
+    metrics_list = [metrics.Metrics.RMSE]
+    if self.ntasks == 1:
+      metrics_list.append(metrics.Metrics.PEARSON)
+    return metrics_list
+
+
 def txt_line_iterator(txt_path):
   """Iterate through lines of file."""
   with tf.gfile.Open(txt_path) as f:
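
For orientation (not part of this commit): a minimal concrete subclass might look like the sketch below. The problem name and the toy samples are hypothetical; since ntasks defaults to 1, each target list holds a single float.

# Sketch only: a hypothetical single-task subclass of Text2RealProblem.
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry


@registry.register_problem
class SentimentScoreRegression(text_problems.Text2RealProblem):
  """Maps review text to a single real-valued sentiment score."""

  @property
  def approx_vocab_size(self):
    return 2**13  # ~8k subwords.

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    del data_dir, tmp_dir, dataset_split  # Toy data; real code reads files.
    yield {"inputs": "a delightful, well-paced film", "targets": [0.9]}
    yield {"inputs": "tedious and overlong", "targets": [-0.7]}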
@@ -692,6 +783,22 @@ def text2class_txt_iterator(source_txt_path, label_txt_path, class_strs=None):
     yield {"inputs": inputs, "label": label}


+def text2real_txt_iterator(source_txt_path, target_txt_path):
+  """Yield dicts for Text2RealProblem.generate_samples from lines of files.
+
+  Args:
+    source_txt_path: txt file with record per line.
+    target_txt_path: txt file with float (or space-separated float list for
+      multitask) per line.
+
+  Yields:
+    {"inputs": inputs, "targets": targets}
+  """
+  for inputs, targets in zip(
+      txt_line_iterator(source_txt_path), txt_line_iterator(target_txt_path)):
+    targets = [float(x) for x in targets.split(" ")]
+    yield {"inputs": inputs, "targets": targets}
+
+
 def text2text_txt_tab_iterator(txt_path):
   """Yield dicts for Text2TextProblem.generate_samples from lines of txt_path.
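
Given this iterator, a file-backed problem's generate_samples can simply delegate to it. A sketch with hypothetical file names; each line of the targets file holds a single float, or ntasks space-separated floats for multitask regression:

# Sketch only: wiring text2real_txt_iterator into a problem subclass.
# File names below are hypothetical placeholders.
import os

from tensor2tensor.data_generators import text_problems


class FileBackedRegression(text_problems.Text2RealProblem):

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    del data_dir, dataset_split  # A real problem would pick files per split.
    return text_problems.text2real_txt_iterator(
        os.path.join(tmp_dir, "inputs.train.txt"),
        os.path.join(tmp_dir, "targets.train.txt"))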

tensor2tensor/data_generators/text_problems_test.py

Lines changed: 16 additions & 0 deletions
@@ -94,6 +94,12 @@ def setUpClass(cls):
     tf.gfile.Copy(cls.targets_file, os.path.join(cls.tmp_dir,
                                                  "targets.eval.txt"))

+    cls.targets_regr = [[1.23, 2.34], [4.56, 5.67]]
+    cls.targets_regr_file = os.path.join(cls.tmp_dir, "targets_regr.train.txt")
+    with tf.gfile.Open(cls.targets_regr_file, "w") as f:
+      for targets in cls.targets_regr:
+        f.write(" ".join([str(x) for x in targets]) + "\n")
+
   def testTxtLineIterator(self):
     lines = [line for line in text_problems.txt_line_iterator(self.inputs_file)]
     self.assertEqual(lines, self.inputs)

@@ -136,6 +142,16 @@ def testText2ClassTxtIteratorWithStrs(self):
     self.assertEqual(inputs, self.inputs)
     self.assertEqual(labels, self.labels)

+  def testText2RealTxtIterator(self):
+    inputs = []
+    targets = []
+    for entry in text_problems.text2real_txt_iterator(self.inputs_file,
+                                                      self.targets_regr_file):
+      inputs.append(entry["inputs"])
+      targets.append(entry["targets"])
+    self.assertEqual(inputs, self.inputs)
+    self.assertEqual(targets, self.targets_regr)
+
   def testText2TextTxtTabIterator(self):
     inputs = []
     targets = []
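
The round trip this test exercises is easy to see in isolation: setUpClass writes space-separated floats, and text2real_txt_iterator parses each line back into a list of floats.

# Illustration of the round trip (values taken from the test above).
line = " ".join(str(x) for x in [1.23, 2.34])   # written by setUpClass
assert line == "1.23 2.34"
assert [float(x) for x in line.split(" ")] == [1.23, 2.34]  # iterator's parse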

tensor2tensor/models/transformer.py

Lines changed: 4 additions & 2 deletions
@@ -462,7 +462,8 @@ def _fast_decode_tpu(self,

     if self.has_input:
       inputs_shape = common_layers.shape_list(features["inputs"])
-      if target_modality == modalities.ModalityType.CLASS_LABEL:
+      if (target_modality == modalities.ModalityType.CLASS_LABEL or
+          self._problem_hparams.get("regression_targets")):
         decode_length = 1
       else:
         decode_length = (

@@ -704,7 +705,8 @@ def _fast_decode(self,
           " of the dataset when decoding.")
     if self.has_input:
       inputs_shape = common_layers.shape_list(features["inputs"])
-      if target_modality == modalities.ModalityType.CLASS_LABEL:
+      if (target_modality == modalities.ModalityType.CLASS_LABEL or
+          self._problem_hparams.get("regression_targets")):
         decode_length = 1
       else:
         decode_length = (
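
The .get() lookup is what keeps this change backward compatible: only Text2RealProblem.hparams calls add_hparam("regression_targets", True), and HParams.get returns None for keys that were never added, so existing problems take the old code path. A small illustration, assuming the TF1-era HParams API these files use:

# Illustration of the guard above (assumes tf.contrib.training.HParams).
from tensorflow.contrib.training import HParams

p = HParams()
assert p.get("regression_targets") is None   # ordinary problems: unchanged
p.add_hparam("regression_targets", True)     # done in Text2RealProblem.hparams
assert p.get("regression_targets") is True   # regression: decode_length = 1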

tensor2tensor/utils/t2t_model.py

Lines changed: 8 additions & 4 deletions
@@ -806,8 +806,10 @@ def infer(self,
806806

807807
if self._problem_hparams:
808808
target_modality = self._problem_hparams.modality["targets"]
809-
if target_modality == modalities.ModalityType.CLASS_LABEL:
810-
beam_size = 1 # No use to run beam-search for a single class.
809+
if (target_modality == modalities.ModalityType.CLASS_LABEL or
810+
self._problem_hparams.get("regression_targets")):
811+
# No use to run beam-search for classification or regression.
812+
beam_size = 1
811813
if beam_size == 1:
812814
log_info("Greedy Decoding")
813815
results = self._greedy_infer(features, decode_length, use_tpu)
@@ -1064,7 +1066,8 @@ def infer_step(i, recent_output, recent_logits, unused_loss):
10641066
initial_output = tf.slice(initial_output, [0, 0, 0, 0],
10651067
common_layers.shape_list(initial_output))
10661068
target_modality = self._problem_hparams.modality["targets"]
1067-
if target_modality == modalities.ModalityType.CLASS_LABEL:
1069+
if (target_modality == modalities.ModalityType.CLASS_LABEL or
1070+
self._problem_hparams.get("regression_targets")):
10681071
decode_length = 1
10691072
else:
10701073
if "partial_targets" in features:
@@ -1243,7 +1246,8 @@ def infer_step(recent_output, recent_logits, unused_loss):
12431246
initial_output = tf.slice(initial_output, [0, 0, 0, 0],
12441247
common_layers.shape_list(initial_output))
12451248
target_modality = self._problem_hparams.modality["targets"]
1246-
if target_modality == modalities.ModalityType.CLASS_LABEL:
1249+
if (target_modality == modalities.ModalityType.CLASS_LABEL or
1250+
self._problem_hparams.get("regression_targets")):
12471251
decode_length = 1
12481252
else:
12491253
if "partial_targets" in features:
