@@ -17,6 +17,7 @@
 
 * Text2TextProblem: input=text, target=text.
 * Text2ClassProblem: input=text, target=class.
+* Text2RealProblem: input=text, target=float.
 * Text2SelfProblem (for language modeling): target=text.
 * QuestionAndContext2TextProblem: input=text, context=text, target=text.
 
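The new entry is a regression-style problem: each text example maps to one or
more real-valued targets. As an illustration (the sample values below are
invented, not taken from this commit), generate_samples for such a problem
yields dicts of the form:

    # ntasks == 1: a single scalar target per example.
    {"inputs": "This movie was great.", "targets": [4.5]}
    # ntasks == 3: a length-3 target vector per example.
    {"inputs": "This movie was great.", "targets": [4.5, 0.2, 1.0]}
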
@@ -605,6 +606,96 @@ def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
       yield {"inputs": inputs, "targets": [label]}
 
 
+class Text2RealProblem(Text2TextProblem):
+  """Base class for text regression problems with one or more tasks.
+
+  Suitable for text-based problems where targets are continuous, real values.
+  When ntasks = 1, each text example is mapped to a single scalar value. When
+  ntasks > 1, each text example is mapped to a 1-d vector of length ntasks.
+  """
+
+  @property
+  def ntasks(self):
+    """Set to n > 1 for multitask regression."""
+    return 1
+
+  def generate_samples(self, data_dir, tmp_dir, dataset_split):
+    """Generate samples of text and real-valued target pairs.
+
+    Each yielded dict will be a single example. The inputs should be raw text.
+    The target should be a list containing ntasks floats.
+    Args:
+      data_dir: final data directory. Typically only used in this method to copy
+        over user-supplied vocab files (for example, if vocab_type ==
+        VocabType.TOKEN).
+      tmp_dir: temporary directory that you can use for downloading and scratch.
+      dataset_split: problem.DatasetSplit, which data split to generate samples
+        for (for example, training and evaluation).
+    Yields:
+      {"inputs": text, "targets": [x1, x2, ..., xN]} where N is ntasks
+    """
+    raise NotImplementedError()
+
+  def generate_text_for_vocab(self, data_dir, tmp_dir):
+    for i, sample in enumerate(
+        self.generate_samples(data_dir, tmp_dir, problem.DatasetSplit.TRAIN)):
+      yield sample["inputs"]
+      if self.max_samples_for_vocab and (i + 1) >= self.max_samples_for_vocab:
+        break
+
+  def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
+    generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
+    encoder = self.get_or_create_vocab(data_dir, tmp_dir)
+    for sample in generator:
+      inputs = encoder.encode(sample["inputs"])
+      inputs.append(text_encoder.EOS_ID)
+      yield {"inputs": inputs, "targets": sample["targets"]}
+
+  def feature_encoders(self, data_dir):
+    encoder = self.get_or_create_vocab(data_dir, None, force_get=True)
+
+    return {
+        "inputs": encoder,
+        "targets": text_encoder.RealEncoder(),
+    }
+
+  def hparams(self, defaults, unused_model_hparams):
+    p = defaults
+    p.modality = {
+        "inputs": modalities.ModalityType.SYMBOL,
+        "targets": modalities.ModalityType.REAL_L2_LOSS,
+    }
+    p.vocab_size = {
+        "inputs": self._encoders["inputs"].vocab_size,
+        "targets": self.ntasks
+    }
+    p.target_space_id = problem.SpaceID.REAL
+    p.add_hparam("regression_targets", True)
+
+  def max_length(self, model_hparams):
+    return model_hparams.batch_size * self.ntasks
+
+  def preprocess_example(self, example, unused_mode, unused_hparams):
+    example = problem.preprocess_example_common(example, unused_mode,
+                                                unused_hparams)
+    example["targets"] = tf.reshape(example["targets"], [1, 1, self.ntasks])
+    return example
+
+  def example_reading_spec(self):
+    data_fields = {
+        "inputs": tf.VarLenFeature(tf.int64),
+        "targets": tf.FixedLenFeature([self.ntasks], tf.float32),
+    }
+    data_items_to_decoders = None
+    return (data_fields, data_items_to_decoders)
+
+  def eval_metrics(self):
+    metrics_list = [metrics.Metrics.RMSE]
+    if self.ntasks == 1:
+      metrics_list.append(metrics.Metrics.PEARSON)
+    return metrics_list
+
+
 def txt_line_iterator(txt_path):
   """Iterate through lines of file."""
   with tf.gfile.Open(txt_path) as f:
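To see how the pieces of Text2RealProblem fit together, here is a minimal
sketch of a concrete subclass. It is illustrative and not part of this commit:
the class name and hard-coded samples are invented, while
registry.register_problem, approx_vocab_size, and is_generate_per_split are
existing tensor2tensor APIs.

    from tensor2tensor.data_generators import text_problems
    from tensor2tensor.utils import registry


    @registry.register_problem
    class ToyTextRegression(text_problems.Text2RealProblem):
      """Toy single-task problem: maps a sentence to one score."""

      @property
      def approx_vocab_size(self):
        return 2**13  # build a ~8k-token subword vocab from the samples

      @property
      def is_generate_per_split(self):
        return False  # one generator; train/dev splits are carved out for us

      def generate_samples(self, data_dir, tmp_dir, dataset_split):
        del data_dir, tmp_dir, dataset_split  # all samples are in-memory
        # Each dict pairs raw text with a list of ntasks floats, exactly the
        # contract Text2RealProblem.generate_samples documents.
        yield {"inputs": "This movie was great.", "targets": [4.5]}
        yield {"inputs": "This movie was terrible.", "targets": [0.5]}
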
@@ -692,6 +783,22 @@ def text2class_txt_iterator(source_txt_path, label_txt_path, class_strs=None):
     yield {"inputs": inputs, "label": label}
 
 
+def text2real_txt_iterator(source_txt_path, target_txt_path):
+  """Yield dicts for Text2RealProblem.generate_samples from lines of files.
+
+  Args:
+    source_txt_path: txt file with one record per line.
+    target_txt_path: txt file with one float (or space-separated float list
+      for multitask) per line.
+  Yields:
+    {"inputs": inputs, "targets": targets}
+  """
+  for inputs, targets in zip(
+      txt_line_iterator(source_txt_path), txt_line_iterator(target_txt_path)):
+    targets = [float(x) for x in targets.split(" ")]
+    yield {"inputs": inputs, "targets": targets}
+
+
 def text2text_txt_tab_iterator(txt_path):
   """Yield dicts for Text2TextProblem.generate_samples from lines of txt_path.
 
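Typical use of the new helper is from a subclass's generate_samples. Below is
a sketch under stated assumptions: the reviews.txt and scores.txt file names
are hypothetical, and the files are assumed to already sit in tmp_dir. For a
multitask problem, each line of the target file would carry ntasks
space-separated floats, e.g. "3.5 0.2 1.0".

    import os

    from tensor2tensor.data_generators import text_problems


    class FileBackedRegression(text_problems.Text2RealProblem):
      """Hypothetical subclass reading text/target pairs from files."""

      def generate_samples(self, data_dir, tmp_dir, dataset_split):
        del data_dir, dataset_split  # same pair of files for every split
        return text_problems.text2real_txt_iterator(
            os.path.join(tmp_dir, "reviews.txt"),
            os.path.join(tmp_dir, "scores.txt"))
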