
Commit 0c66117

Merge pull request #156 from rsepassi/push

v1.0.14

2 parents: 43bfb9f + c8b7000


82 files changed: +1272 additions, -745 deletions

README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -242,7 +242,7 @@ def transformer_my_very_own_hparams_set():
 
 ```python
 # In ~/usr/t2t_usr/__init__.py
-import my_registrations
+from . import my_registrations
 ```
 
 ```
````
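
Worth spelling out why this one-line change is needed: T2T imports `t2t_usr` as a package, so a bare `import my_registrations` only resolves if the directory happens to be on `sys.path`, while the relative form always does. A minimal sketch of such a user directory follows; `registry.register_hparams` and `transformer.transformer_base` are real T2T names, but the file contents and the override are hypothetical.

```python
# ~/usr/t2t_usr/__init__.py
from . import my_registrations  # relative import: resolves inside the package

# ~/usr/t2t_usr/my_registrations.py (hypothetical contents)
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def transformer_my_very_own_hparams_set():
  """Transformer base hparams with one illustrative override."""
  hparams = transformer.transformer_base()
  hparams.hidden_size = 1024  # assumed override; any hparam works here
  return hparams
```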

setup.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.0.13',
+    version='1.0.14',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='no-reply@google.com',
```

tensor2tensor/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,4 +1,4 @@
-# Copyright 2017 Google Inc.
+# Copyright 2017 The Tensor2Tensor Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```

tensor2tensor/bin/t2t-datagen

Mode changed: 100755 → 100644

Lines changed: 55 additions & 89 deletions

```diff
@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-# Copyright 2017 Google Inc.
+# Copyright 2017 The Tensor2Tensor Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,6 +24,9 @@ takes 2 arguments - input_directory and mode (one of "train" or "dev") - and
 yields for each training example a dictionary mapping string feature names to
 lists of {string, int, float}. The generator will be run once for each mode.
 """
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 
 import random
 import tempfile
@@ -34,6 +37,7 @@ import numpy as np
 
 from tensor2tensor.data_generators import algorithmic
 from tensor2tensor.data_generators import algorithmic_math
+from tensor2tensor.data_generators import all_problems  # pylint: disable=unused-import
 from tensor2tensor.data_generators import audio
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import image
@@ -43,6 +47,7 @@ from tensor2tensor.data_generators import snli
 from tensor2tensor.data_generators import wiki
 from tensor2tensor.data_generators import wmt
 from tensor2tensor.data_generators import wsj_parsing
+from tensor2tensor.utils import registry
 
 import tensorflow as tf
 
@@ -62,12 +67,6 @@ flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
 # Mapping from problems that we can generate data for to their generators.
 # pylint: disable=g-long-lambda
 _SUPPORTED_PROBLEM_GENERATORS = {
-    "algorithmic_identity_binary40": (
-        lambda: algorithmic.identity_generator(2, 40, 100000),
-        lambda: algorithmic.identity_generator(2, 400, 10000)),
-    "algorithmic_identity_decimal40": (
-        lambda: algorithmic.identity_generator(10, 40, 100000),
-        lambda: algorithmic.identity_generator(10, 400, 10000)),
     "algorithmic_shift_decimal40": (
         lambda: algorithmic.shift_generator(20, 10, 40, 100000),
         lambda: algorithmic.shift_generator(20, 10, 80, 10000)),
@@ -104,9 +103,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
     "ice_parsing_tokens": (
         lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
-            True, "ice", 2**13, 2**8),
+                                                   True, "ice", 2**13, 2**8),
         lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
-            False, "ice", 2**13, 2**8)),
+                                                   False, "ice", 2**13, 2**8)),
     "ice_parsing_characters": (
         lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, True),
         lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, False)),
@@ -118,11 +117,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
                                                     2**14, 2**9),
         lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
                                                     2**14, 2**9)),
-    "wsj_parsing_tokens_32k": (
-        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, True,
-                                                    2**15, 2**9),
-        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
-                                                    2**15, 2**9)),
     "wmt_enfr_characters": (
         lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, True),
         lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, False)),
@@ -140,14 +134,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "wmt_ende_bpe32k": (
         lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True),
         lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)),
-    "wmt_ende_tokens_8k": (
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**13),
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**13)
-    ),
-    "wmt_ende_tokens_32k": (
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
-    ),
     "wmt_zhen_tokens_32k": (
         lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, True,
                                                    2**15, 2**15),
@@ -174,26 +160,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "image_cifar10_test": (
         lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 50000),
         lambda: image.cifar10_generator(FLAGS.tmp_dir, False, 10000)),
-    "image_mscoco_characters_tune": (
-        lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 70000),
-        lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 10000, 70000)),
     "image_mscoco_characters_test": (
         lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 80000),
         lambda: image.mscoco_generator(FLAGS.tmp_dir, False, 40000)),
-    "image_mscoco_tokens_8k_tune": (
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**13,
-            vocab_size=2**13),
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            10000,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**13,
-            vocab_size=2**13)),
     "image_mscoco_tokens_8k_test": (
         lambda: image.mscoco_generator(
             FLAGS.tmp_dir,
@@ -207,20 +176,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
             40000,
             vocab_filename="tokens.vocab.%d" % 2**13,
             vocab_size=2**13)),
-    "image_mscoco_tokens_32k_tune": (
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15),
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            10000,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15)),
     "image_mscoco_tokens_32k_test": (
         lambda: image.mscoco_generator(
             FLAGS.tmp_dir,
@@ -308,8 +263,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
 
 # pylint: enable=g-long-lambda
 
-UNSHUFFLED_SUFFIX = "-unshuffled"
-
 
 def set_random_seed():
   """Set the random seed from flag everywhere."""
@@ -322,13 +275,15 @@ def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
 
   # Calculate the list of problems to generate.
-  problems = list(sorted(_SUPPORTED_PROBLEM_GENERATORS))
+  problems = sorted(
+      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
   if FLAGS.problem and FLAGS.problem[-1] == "*":
     problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
   elif FLAGS.problem:
     problems = [p for p in problems if p == FLAGS.problem]
   else:
     problems = []
+
   # Remove TIMIT if paths are not given.
   if not FLAGS.timit_paths:
     problems = [p for p in problems if "timit" not in p]
@@ -340,7 +295,8 @@
     problems = [p for p in problems if "ende_bpe" not in p]
 
   if not problems:
-    problems_str = "\n * ".join(sorted(_SUPPORTED_PROBLEM_GENERATORS))
+    problems_str = "\n * ".join(
+        sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()))
     error_msg = ("You must specify one of the supported problems to "
                  "generate data for:\n * " + problems_str + "\n")
     error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
@@ -357,40 +313,50 @@
   for problem in problems:
     set_random_seed()
 
-    training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
-
-    if isinstance(dev_gen, int):
-      # The dev set and test sets are generated as extra shards using the
-      # training generator. The integer specifies the number of training
-      # shards. FLAGS.num_shards is ignored.
-      num_training_shards = dev_gen
-      tf.logging.info("Generating data for %s.", problem)
-      all_output_files = generator_utils.combined_data_filenames(
-          problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_training_shards)
-      generator_utils.generate_files(
-          training_gen(), all_output_files, FLAGS.max_cases)
+    if problem in _SUPPORTED_PROBLEM_GENERATORS:
+      generate_data_for_problem(problem)
     else:
-      # usual case - train data and dev data are generated using separate
-      # generators.
-      tf.logging.info("Generating training data for %s.", problem)
-      train_output_files = generator_utils.train_data_filenames(
-          problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, FLAGS.num_shards)
-      generator_utils.generate_files(
-          training_gen(), train_output_files, FLAGS.max_cases)
-      tf.logging.info("Generating development data for %s.", problem)
-      dev_shards = 10 if "coco" in problem else 1
-      dev_output_files = generator_utils.dev_data_filenames(
-          problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards)
-      generator_utils.generate_files(dev_gen(), dev_output_files)
-      all_output_files = train_output_files + dev_output_files
+      generate_data_for_registered_problem(problem)
+
+
+def generate_data_for_problem(problem):
+  """Generate data for a problem in _SUPPORTED_PROBLEM_GENERATORS."""
+  training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
+
+  if isinstance(dev_gen, int):
+    # The dev set and test sets are generated as extra shards using the
+    # training generator. The integer specifies the number of training
+    # shards. FLAGS.num_shards is ignored.
+    num_training_shards = dev_gen
+    tf.logging.info("Generating data for %s.", problem)
+    all_output_files = generator_utils.combined_data_filenames(
+        problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
+        num_training_shards)
+    generator_utils.generate_files(training_gen(), all_output_files,
+                                   FLAGS.max_cases)
+  else:
+    # usual case - train data and dev data are generated using separate
+    # generators.
+    tf.logging.info("Generating training data for %s.", problem)
+    train_output_files = generator_utils.train_data_filenames(
+        problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
+        FLAGS.num_shards)
+    generator_utils.generate_files(training_gen(), train_output_files,
+                                   FLAGS.max_cases)
+    tf.logging.info("Generating development data for %s.", problem)
+    dev_shards = 10 if "coco" in problem else 1
+    dev_output_files = generator_utils.dev_data_filenames(
+        problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards)
+    generator_utils.generate_files(dev_gen(), dev_output_files)
+    all_output_files = train_output_files + dev_output_files
+
+  tf.logging.info("Shuffling data...")
+  generator_utils.shuffle_dataset(all_output_files)
+
 
-    tf.logging.info("Shuffling data...")
-    for fname in all_output_files:
-      records = generator_utils.read_records(fname)
-      random.shuffle(records)
-      out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
-      generator_utils.write_records(records, out_fname)
-      tf.gfile.Remove(fname)
+def generate_data_for_registered_problem(problem_name):
+  problem = registry.problem(problem_name)
+  problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)
 
 
 if __name__ == "__main__":
```
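
Two notes on the refactor above. First, `main()` now dispatches: problems in the hard-coded dict keep the old flow (moved into `generate_data_for_problem`), and anything else is resolved through the registry and asked to generate its own data. A minimal sketch of what that second path expects follows; only the `generate_data(data_dir, tmp_dir)` hook and the registry lookup are visible in this diff, so the base-class module and the decorator name are assumptions.

```python
from tensor2tensor.data_generators import problem  # assumed home of the base class
from tensor2tensor.utils import registry


@registry.register_problem  # assumed decorator, by analogy with list_problems()
class MyToyProblem(problem.Problem):
  """Hypothetical problem; t2t-datagen only calls generate_data on it."""

  def generate_data(self, data_dir, tmp_dir):
    # Expected to write train/dev files into data_dir, using tmp_dir
    # for downloads and scratch space.
    raise NotImplementedError()
```

Second, the shuffle loop deleted from the end of `main()`, along with the module-level `UNSHUFFLED_SUFFIX` constant, evidently moved into `generator_utils` as `shuffle_dataset`. Reassembled from the deleted lines, the helper plausibly looks like this inside `generator_utils`, where `read_records` and `write_records` already live; the body is inferred from the removed code, not copied from that module.

```python
import random

import tensorflow as tf

UNSHUFFLED_SUFFIX = "-unshuffled"  # now referenced as generator_utils.UNSHUFFLED_SUFFIX


def shuffle_dataset(filenames):
  """Shuffle each unshuffled file's records and rewrite it without the suffix."""
  for fname in filenames:
    records = read_records(fname)  # same helper the old inline loop called
    random.shuffle(records)
    out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
    write_records(records, out_fname)
    tf.gfile.Remove(fname)  # delete the unshuffled original
```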

tensor2tensor/bin/t2t-make-tf-configs

Lines changed: 14 additions & 13 deletions

```diff
@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-# Copyright 2017 Google Inc.
+# Copyright 2017 The Tensor2Tensor Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,13 +17,13 @@
 
 Usage:
 
-`t2t-make-tf-configs --workers="server1:1234" --ps="server3:2134,server4:2334"`
+`t2t-make-tf-configs --masters="server1:1234" --ps="server3:2134,server4:2334"`
 
-Outputs 1 line per job to stdout, first the workers, then the parameter servers.
+Outputs 1 line per job to stdout, first the masters, then the parameter servers.
 Each line has the TF_CONFIG, then a tab, then the command line flags for that
 job.
 
-If there is a single worker, workers will have the `--sync` flag.
+If there is a single master, it will have the `--sync` flag.
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -38,31 +38,32 @@ import tensorflow as tf
 flags = tf.flags
 FLAGS = flags.FLAGS
 
-flags.DEFINE_string("workers", "", "Comma-separated list of worker addresses")
+flags.DEFINE_string("masters", "", "Comma-separated list of master addresses")
 flags.DEFINE_string("ps", "", "Comma-separated list of ps addresses")
 
 
 def main(_):
-  if not (FLAGS.workers and FLAGS.ps):
-    raise ValueError("Must provide --workers and --ps")
+  if not (FLAGS.masters and FLAGS.ps):
+    raise ValueError("Must provide --masters and --ps")
 
-  workers = FLAGS.workers.split(",")
+  masters = FLAGS.masters.split(",")
   ps = FLAGS.ps.split(",")
 
-  cluster = {"ps": ps, "worker": workers}
+  cluster = {"ps": ps, "master": masters}
 
-  for task_type, jobs in (("worker", workers), ("ps", ps)):
+  for task_type, jobs in (("master", masters), ("ps", ps)):
     for idx, job in enumerate(jobs):
-      if task_type == "worker":
+      if task_type == "master":
         cmd_line_flags = " ".join([
             "--master=grpc://%s" % job,
             "--ps_replicas=%d" % len(ps),
-            "--worker_replicas=%d" % len(workers),
+            "--worker_replicas=%d" % len(masters),
             "--worker_gpu=1",
             "--worker_id=%d" % idx,
+            "--worker_job='/job:master'",
            "--ps_gpu=1",
            "--schedule=train",
-            "--sync" if len(workers) == 1 else "",
+            "--sync" if len(masters) == 1 else "",
         ])
       else:
         cmd_line_flags = " ".join([
```
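
The diff stops short of the print statement, but the docstring fixes the output contract: one line per job, the TF_CONFIG, a tab, then that job's flags. A sketch of what each emitted line would contain, assuming TensorFlow's standard TF_CONFIG layout (a cluster spec plus the job's type and index; the JSON shape is an assumption, not read from this diff):

```python
import json

# Values taken from the docstring's usage example; the JSON layout is assumed.
cluster = {"ps": ["server3:2134", "server4:2334"], "master": ["server1:1234"]}
task_type, idx = "master", 0
tf_config = json.dumps({"cluster": cluster,
                        "task": {"type": task_type, "index": idx}})
cmd_line_flags = "--master=grpc://server1:1234 --ps_replicas=2 ..."  # abbreviated
print("%s\t%s" % (tf_config, cmd_line_flags))  # TF_CONFIG, tab, flags
```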

tensor2tensor/bin/t2t-trainer

Mode changed: 100755 → 100644

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-# Copyright 2017 Google Inc.
+# Copyright 2017 The Tensor2Tensor Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```

tensor2tensor/data_generators/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,4 +1,4 @@
-# Copyright 2017 Google Inc.
+# Copyright 2017 The Tensor2Tensor Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```
