
Commit 82cce52

Merge pull request #204 from rsepassi/push
v1.1.5
2 parents: c35c7a3 + eee190b


42 files changed: +1659 −1515 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ _pycache__/
 
 # Python egg metadata, regenerated from source files by setuptools.
 /*.egg-info
+/*.egg
 
 # PyPI distribution artifacts.
 build/

setup.py

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.1.4',
+    version='1.1.5',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='no-reply@google.com',
@@ -20,6 +20,7 @@
     ],
     install_requires=[
         'numpy',
+        'requests',
        'sympy',
        'six',
     ],
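Two substantive changes here: the release is bumped to 1.1.5 and requests becomes a hard dependency. A minimal post-install check, offered as an illustration rather than part of the commit (pkg_resources ships with setuptools, which setup.py already uses):

# Illustrative only: verify the installed release matches the bump above.
import pkg_resources

print(pkg_resources.get_distribution("tensor2tensor").version)  # expect "1.1.5"
print(pkg_resources.get_distribution("requests").version)  # now pulled in by install_requires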

tensor2tensor/bin/t2t-datagen

Lines changed: 13 additions & 12 deletions
@@ -16,14 +16,15 @@
 
 """Produces the training and dev data for --problem into --data_dir.
 
-generator.py produces sharded and shuffled TFRecord files of tensorflow.Example
-protocol buffers for a variety of datasets registered in this file.
+Produces sharded and shuffled TFRecord files of tensorflow.Example protocol
+buffers for a variety of registered datasets.
 
-All datasets are registered in _SUPPORTED_PROBLEM_GENERATORS. Each entry maps a
-string name (selectable on the command-line with --problem) to a function that
-takes 2 arguments - input_directory and mode (one of "train" or "dev") - and
-yields for each training example a dictionary mapping string feature names to
-lists of {string, int, float}. The generator will be run once for each mode.
+All Problems are registered with @registry.register_problem or are in
+_SUPPORTED_PROBLEM_GENERATORS in this file. Each entry maps a string name
+(selectable on the command-line with --problem) to a function that takes 2
+arguments - input_directory and mode (one of "train" or "dev") - and yields for
+each training example a dictionary mapping string feature names to lists of
+{string, int, float}. The generator will be run once for each mode.
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -229,8 +230,7 @@ def generate_data_for_problem(problem):
   num_shards = FLAGS.num_shards or 10
   tf.logging.info("Generating training data for %s.", problem)
   train_output_files = generator_utils.train_data_filenames(
-      problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
-      num_shards)
+      problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards)
   generator_utils.generate_files(training_gen(), train_output_files,
                                  FLAGS.max_cases)
   tf.logging.info("Generating development data for %s.", problem)
@@ -250,9 +250,10 @@ def generate_data_for_registered_problem(problem_name):
     raise ValueError("--num_shards should not be set for registered Problem.")
   problem = registry.problem(problem_name)
   task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
-  problem.generate_data(os.path.expanduser(FLAGS.data_dir),
-                        os.path.expanduser(FLAGS.tmp_dir),
-                        task_id=task_id)
+  problem.generate_data(
+      os.path.expanduser(FLAGS.data_dir),
+      os.path.expanduser(FLAGS.tmp_dir),
+      task_id=task_id)
 
 
 if __name__ == "__main__":
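The rewritten docstring pins down the generator contract: a function taking an input directory and a mode ("train" or "dev") that yields, per example, a dict from feature names to lists of ints, floats, or strings. The sketch below is a hypothetical conforming entry, not code from this commit; toy_copy_generator and its feature names are invented for illustration.

def toy_copy_generator(input_directory, mode):
  """Yields dicts mapping feature names to lists, per the docstring above."""
  del input_directory  # a real generator would read raw data from here
  num_cases = 1000 if mode == "train" else 100
  for i in range(num_cases):
    tokens = [2 + i % 7, 2 + i % 5, 2 + i % 3]  # fake token ids
    yield {"inputs": tokens, "targets": tokens}

Problems registered with @registry.register_problem instead go through the path shown in generate_data_for_registered_problem above: registry.problem(name) followed by problem.generate_data(data_dir, tmp_dir, task_id=task_id).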

tensor2tensor/data_generators/all_problems.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@
 from tensor2tensor.data_generators import wmt
 from tensor2tensor.data_generators import wsj_parsing
 
+
 # Problem modules that require optional dependencies
 # pylint: disable=g-import-not-at-top
 try:

tensor2tensor/data_generators/image.py

Lines changed: 27 additions & 18 deletions
@@ -36,7 +36,7 @@
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import text_encoder
-from tensor2tensor.models import common_layers
+from tensor2tensor.layers import common_layers
 from tensor2tensor.utils import registry
 
 import tensorflow as tf
@@ -76,10 +76,11 @@ class ImageFSNS(ImageProblem):
   def generate_data(self, data_dir, tmp_dir, task_id=-1):
     list_url = ("https://raw.githubusercontent.com/tensorflow/models/master/"
                 "street/python/fsns_urls.txt")
-    fsns_urls = generator_utils.maybe_download(
-        tmp_dir, "fsns_urls.txt", list_url)
-    fsns_files = [f.strip() for f in open(fsns_urls, "r")
-                  if f.startswith("http://")]
+    fsns_urls = generator_utils.maybe_download(tmp_dir, "fsns_urls.txt",
+                                               list_url)
+    fsns_files = [
+        f.strip() for f in open(fsns_urls, "r") if f.startswith("http://")
+    ]
     for url in fsns_files:
       if "/train/train" in url:
         generator_utils.maybe_download(
@@ -88,8 +89,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
         generator_utils.maybe_download(
             data_dir, "image_fsns-dev" + url[-len("-00100-of-00512"):], url)
       elif "charset" in url:
-        generator_utils.maybe_download(
-            data_dir, "charset_size134.txt", url)
+        generator_utils.maybe_download(data_dir, "charset_size134.txt", url)
 
   def feature_encoders(self, data_dir):
     # This vocab file must be present within the data directory.
@@ -111,8 +111,8 @@ def hparams(self, defaults, model_hparams):
 
   def example_reading_spec(self):
     label_key = "image/unpadded_label"
-    return super(ImageFSNS, self).example_reading_spec(self,
-                                                       label_key=label_key)
+    return super(ImageFSNS, self).example_reading_spec(
+        self, label_key=label_key)
 
 
 class Image2ClassProblem(ImageProblem):
@@ -161,6 +161,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
 
 
 def imagenet_preprocess_examples(examples, mode):
   """Preprocessing used for Imagenet and similar problems."""
+
   def preprocess(img):
     img = tf.image.resize_images(img, [360, 360])
     img = common_layers.image_augmentation(tf.to_float(img) / 255.)
@@ -215,8 +216,8 @@ def is_small(self):
 
   def preprocess_examples(self, examples, mode):
     examples = imagenet_preprocess_examples(examples, mode)
-    examples["inputs"] = tf.to_int64(tf.image.resize_images(
-        examples["inputs"], [32, 32]))
+    examples["inputs"] = tf.to_int64(
+        tf.image.resize_images(examples["inputs"], [32, 32]))
 
 
 def image_generator(images, labels):
@@ -665,12 +666,20 @@ def generator(self, data_dir, tmp_dir, is_training):
     vocab_filename = "vocab.endefr.%d" % self.targeted_vocab_size
     if is_training:
       return mscoco_generator(
-          data_dir, tmp_dir, True, 80000,
-          vocab_filename=vocab_filename, vocab_size=self.targeted_vocab_size)
+          data_dir,
+          tmp_dir,
+          True,
+          80000,
+          vocab_filename=vocab_filename,
+          vocab_size=self.targeted_vocab_size)
     else:
       return mscoco_generator(
-          data_dir, tmp_dir, False, 40000,
-          vocab_filename=vocab_filename, vocab_size=self.targeted_vocab_size)
+          data_dir,
+          tmp_dir,
+          False,
+          40000,
+          vocab_filename=vocab_filename,
+          vocab_size=self.targeted_vocab_size)
 
 
 @registry.register_problem
@@ -690,8 +699,8 @@ def targeted_vocab_size(self):
 def _get_celeba(directory):
   """Download and extract CELEBA to directory unless it is there."""
   # path = os.path.join(directory, _CELEBA_NAME)
-  path = generator_utils.maybe_download_from_drive(directory,
-                                                   _CELEBA_NAME, _CELEBA_URL)
+  path = generator_utils.maybe_download_from_drive(directory, _CELEBA_NAME,
+                                                   _CELEBA_URL)
   if not tf.gfile.Exists(path):
     zipfile.ZipFile(path + ".zip", "r").extractall(directory)
 
@@ -711,7 +720,7 @@ def celeba_generator(tmp_dir, how_many, start_from=0):
   """
   _get_celeba(tmp_dir)
   image_files = tf.gfile.Glob(os.path.join(tmp_dir, _CELEBA_NAME) + "/*.jpg")
-  for filename in image_files[start_from:start_from+how_many]:
+  for filename in image_files[start_from:start_from + how_many]:
     with tf.gfile.Open(filename, "r") as f:
       encoded_image_data = f.read()
       yield {
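Most hunks above are formatting, but the import change is substantive: common_layers now lives in tensor2tensor.layers rather than tensor2tensor.models. A small sketch of the preprocessing pattern from imagenet_preprocess_examples using the new import path; the zero-filled tensor is a stand-in for a decoded image and is not from the diff.

import tensorflow as tf

from tensor2tensor.layers import common_layers  # new location after this commit

img = tf.zeros([224, 224, 3], dtype=tf.uint8)  # stand-in for a decoded JPEG
img = tf.image.resize_images(img, [360, 360])  # resize_images returns floats
img = common_layers.image_augmentation(tf.to_float(img) / 255.)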

tensor2tensor/data_generators/problem_hparams.py

Lines changed: 47 additions & 41 deletions
@@ -25,7 +25,7 @@
 # Dependency imports
 
 from tensor2tensor.data_generators import text_encoder
-from tensor2tensor.models import modalities  # pylint: disable=unused-import
+from tensor2tensor.layers import modalities  # pylint: disable=unused-import
 from tensor2tensor.utils import registry
 
 import tensorflow as tf
@@ -202,8 +202,7 @@ def default_problem_hparams():
       # the targets. For instance `problem_copy` will copy the inputs, but
       # `problem_rev_copy` will copy the targets.
       was_reversed=False,
-      was_copy=False,
-  )
+      was_copy=False,)
 
 
 def test_problem_hparams(unused_model_hparams, input_vocab_size,
@@ -327,9 +326,7 @@ def lm1b_32k(model_hparams):
   encoder = text_encoder.SubwordTextEncoder(
       os.path.join(model_hparams.data_dir, "lm1b_32k.subword_text_encoder"))
   p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
-  p.vocabulary = {
-      "targets": encoder
-  }
+  p.vocabulary = {"targets": encoder}
   p.target_space_id = 3
   return p
 
@@ -343,9 +340,7 @@ def lm1b_characters(unused_model_hparams):
   p.input_modality = {}
   encoder = text_encoder.ByteTextEncoder()
   p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
-  p.vocabulary = {
-      "targets": encoder
-  }
+  p.vocabulary = {"targets": encoder}
   p.target_space_id = 2
   return p
 
@@ -358,10 +353,7 @@ def wiki_32k(model_hparams):
   modality_spec = (registry.Modalities.SYMBOL, encoder.vocab_size)
   p.input_modality = {"inputs": modality_spec}
   p.target_modality = modality_spec
-  p.vocabulary = {
-      "inputs": encoder,
-      "targets": encoder
-  }
+  p.vocabulary = {"inputs": encoder, "targets": encoder}
   p.target_space_id = 3
   return p
 
@@ -430,9 +422,7 @@ def wmt_parsing_tokens(model_hparams, wrong_vocab_size):
   return p
 
 
-def wsj_parsing_tokens(model_hparams,
-                       prefix,
-                       wrong_source_vocab_size,
+def wsj_parsing_tokens(model_hparams, prefix, wrong_source_vocab_size,
                        wrong_target_vocab_size):
   """English to parse tree translation benchmark.
 
@@ -487,11 +477,9 @@ def ice_parsing_tokens(model_hparams, wrong_source_vocab_size):
   p = default_problem_hparams()
   # This vocab file must be present within the data directory.
   source_vocab_filename = os.path.join(
-      model_hparams.data_dir,
-      "ice_source.vocab.%d" % wrong_source_vocab_size)
-  target_vocab_filename = os.path.join(
-      model_hparams.data_dir,
-      "ice_target.vocab.256")
+      model_hparams.data_dir, "ice_source.vocab.%d" % wrong_source_vocab_size)
+  target_vocab_filename = os.path.join(model_hparams.data_dir,
+                                       "ice_target.vocab.256")
   source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
   target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
   p.input_modality = {
@@ -502,7 +490,7 @@ def ice_parsing_tokens(model_hparams, wrong_source_vocab_size):
       "inputs": source_subtokenizer,
       "targets": target_subtokenizer,
   }
-  p.input_space_id = 18 # Icelandic tokens
+  p.input_space_id = 18  # Icelandic tokens
   p.target_space_id = 19  # Icelandic parse tokens
   return p
 
@@ -534,23 +522,41 @@ def image_celeba(unused_model_hparams):
 # Dictionary of named hyperparameter settings for various problems.
 # This is only accessed through the problem_hparams function below.
 PROBLEM_HPARAMS_MAP = {
-    "audio_timit_characters_tune": audio_timit_characters,
-    "audio_timit_characters_test": audio_timit_characters,
-    "audio_timit_tokens_8k_tune": lambda p: audio_timit_tokens(p, 2**13),
-    "audio_timit_tokens_8k_test": lambda p: audio_timit_tokens(p, 2**13),
-    "audio_wsj_characters_tune": audio_wsj_characters,
-    "audio_wsj_characters_test": audio_wsj_characters,
-    "audio_wsj_tokens_8k_tune": lambda p: audio_wsj_tokens(p, 2**13),
-    "audio_wsj_tokens_8k_test": lambda p: audio_wsj_tokens(p, 2**13),
-    "lm1b_characters": lm1b_characters,
-    "lm1b_32k": lm1b_32k,
-    "wiki_32k": wiki_32k,
-    "ice_parsing_characters": wmt_parsing_characters,
-    "ice_parsing_tokens": lambda p: ice_parsing_tokens(p, 2**13),
-    "wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13),
-    "wsj_parsing_tokens_16k": lambda p: wsj_parsing_tokens(  # pylint: disable=g-long-lambda
-        p, "wsj", 2**14, 2**9),
-    "wmt_ende_bpe32k": wmt_ende_bpe32k,
-    "image_celeba_tune": image_celeba,
-    "img2img_imagenet": img2img_imagenet,
+    "audio_timit_characters_tune":
+        audio_timit_characters,
+    "audio_timit_characters_test":
+        audio_timit_characters,
+    "audio_timit_tokens_8k_tune":
+        lambda p: audio_timit_tokens(p, 2**13),
+    "audio_timit_tokens_8k_test":
+        lambda p: audio_timit_tokens(p, 2**13),
+    "audio_wsj_characters_tune":
+        audio_wsj_characters,
+    "audio_wsj_characters_test":
+        audio_wsj_characters,
+    "audio_wsj_tokens_8k_tune":
+        lambda p: audio_wsj_tokens(p, 2**13),
+    "audio_wsj_tokens_8k_test":
+        lambda p: audio_wsj_tokens(p, 2**13),
+    "lm1b_characters":
+        lm1b_characters,
+    "lm1b_32k":
+        lm1b_32k,
+    "wiki_32k":
+        wiki_32k,
+    "ice_parsing_characters":
+        wmt_parsing_characters,
+    "ice_parsing_tokens":
+        lambda p: ice_parsing_tokens(p, 2**13),
+    "wmt_parsing_tokens_8k":
+        lambda p: wmt_parsing_tokens(p, 2**13),
+    "wsj_parsing_tokens_16k":
+        lambda p: wsj_parsing_tokens(  # pylint: disable=g-long-lambda
+            p, "wsj", 2**14, 2**9),
+    "wmt_ende_bpe32k":
+        wmt_ende_bpe32k,
+    "image_celeba_tune":
+        image_celeba,
+    "img2img_imagenet":
+        img2img_imagenet,
 }
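The PROBLEM_HPARAMS_MAP rewrite only changes layout (each key and value on its own line); every value remains a callable over model hparams, either a named function or a lambda closing over vocab sizes. As the surrounding comment notes, the map is accessed through the module's problem_hparams function; the direct lookup below is a sketch for illustration only, with model_hparams assumed to be defined elsewhere.

# Illustration only: each map value is a callable taking model hparams.
hparams_fn = PROBLEM_HPARAMS_MAP["wsj_parsing_tokens_16k"]
p = hparams_fn(model_hparams)  # same as wsj_parsing_tokens(model_hparams, "wsj", 2**14, 2**9)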

tensor2tensor/layers/__init__.py

Whitespace-only changes.
