 #!/usr/bin/env python
-# Copyright 2017 Google Inc.
+# Copyright 2017 The Tensor2Tensor Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,6 +24,9 @@ takes 2 arguments - input_directory and mode (one of "train" or "dev") - and
 yields for each training example a dictionary mapping string feature names to
 lists of {string, int, float}. The generator will be run once for each mode.
 """
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
 
 import random
 import tempfile
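
The module docstring above pins down the generator contract: each generator yields, per training example, a dict from string feature names to lists of {string, int, float}. A minimal sketch of that contract (the toy problem below is illustrative, not part of this change; it mirrors the shape of `algorithmic.identity_generator`):

```python
import random


# Toy generator obeying the documented contract: yields one dict per case,
# mapping string feature names to lists of ints.
def toy_identity_generator(nbr_symbols, max_length, nbr_cases):
  for _ in range(nbr_cases):
    length = random.randint(1, max_length)
    inputs = [random.randint(0, nbr_symbols - 1) for _ in range(length)]
    yield {"inputs": inputs, "targets": inputs}
```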
@@ -34,6 +37,7 @@ import numpy as np
 
 from tensor2tensor.data_generators import algorithmic
 from tensor2tensor.data_generators import algorithmic_math
+from tensor2tensor.data_generators import all_problems  # pylint: disable=unused-import
 from tensor2tensor.data_generators import audio
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import image
@@ -43,6 +47,7 @@ from tensor2tensor.data_generators import snli
 from tensor2tensor.data_generators import wiki
 from tensor2tensor.data_generators import wmt
 from tensor2tensor.data_generators import wsj_parsing
+from tensor2tensor.utils import registry
 
 import tensorflow as tf
 
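
The new `registry` import (together with the `all_problems` import above, which exists only for its registration side effects) is what lets this binary see problems defined outside `_SUPPORTED_PROBLEM_GENERATORS`. A hedged sketch of the pattern, assuming the `@registry.register_problem` decorator and a `Problem` base class exposing the `generate_data(data_dir, tmp_dir)` method this diff calls; `MySumProblem` and the base-class module path are illustrative, not part of the change:

```python
from tensor2tensor.data_generators import problem  # assumed module path
from tensor2tensor.utils import registry


@registry.register_problem  # makes the name visible to registry.list_problems()
class MySumProblem(problem.Problem):

  def generate_data(self, data_dir, tmp_dir):
    # Write train/dev files under data_dir; invoked via
    # generate_data_for_registered_problem() below.
    raise NotImplementedError()
```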
@@ -62,12 +67,6 @@ flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
 # Mapping from problems that we can generate data for to their generators.
 # pylint: disable=g-long-lambda
 _SUPPORTED_PROBLEM_GENERATORS = {
-    "algorithmic_identity_binary40": (
-        lambda: algorithmic.identity_generator(2, 40, 100000),
-        lambda: algorithmic.identity_generator(2, 400, 10000)),
-    "algorithmic_identity_decimal40": (
-        lambda: algorithmic.identity_generator(10, 40, 100000),
-        lambda: algorithmic.identity_generator(10, 400, 10000)),
     "algorithmic_shift_decimal40": (
         lambda: algorithmic.shift_generator(20, 10, 40, 100000),
         lambda: algorithmic.shift_generator(20, 10, 80, 10000)),
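
Each remaining entry maps a problem name to a pair of zero-argument callables, one per mode; wrapping the calls in `lambda` defers reading `FLAGS` until generation actually runs. Illustrative shape of an entry, reusing the toy generator sketched earlier:

```python
# Hypothetical entry: name -> (train generator factory, dev generator factory).
# The dev slot may instead be an int (number of training shards), which
# generate_data_for_problem() below handles as a special case.
_EXAMPLE_GENERATORS = {
    "toy_identity_decimal40": (
        lambda: toy_identity_generator(10, 40, 100000),
        lambda: toy_identity_generator(10, 400, 10000)),
}
```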
@@ -104,9 +103,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
     "ice_parsing_tokens": (
         lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
-                                                   True, "ice", 2**13, 2**8),
+                                                   True, "ice", 2**13, 2**8),
         lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
-                                                   False, "ice", 2**13, 2**8)),
+                                                   False, "ice", 2**13, 2**8)),
     "ice_parsing_characters": (
         lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, True),
         lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, False)),
@@ -118,11 +117,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
                                                     2**14, 2**9),
         lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
                                                     2**14, 2**9)),
-    "wsj_parsing_tokens_32k": (
-        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, True,
-                                                    2**15, 2**9),
-        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
-                                                    2**15, 2**9)),
     "wmt_enfr_characters": (
         lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, True),
         lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, False)),
@@ -140,14 +134,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "wmt_ende_bpe32k": (
         lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True),
         lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)),
-    "wmt_ende_tokens_8k": (
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**13),
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**13)
-    ),
-    "wmt_ende_tokens_32k": (
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
-    ),
     "wmt_zhen_tokens_32k": (
         lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, True,
                                                    2**15, 2**15),
@@ -174,26 +160,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "image_cifar10_test": (
         lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 50000),
         lambda: image.cifar10_generator(FLAGS.tmp_dir, False, 10000)),
-    "image_mscoco_characters_tune": (
-        lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 70000),
-        lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 10000, 70000)),
     "image_mscoco_characters_test": (
         lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 80000),
         lambda: image.mscoco_generator(FLAGS.tmp_dir, False, 40000)),
-    "image_mscoco_tokens_8k_tune": (
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**13,
-            vocab_size=2**13),
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            10000,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**13,
-            vocab_size=2**13)),
     "image_mscoco_tokens_8k_test": (
         lambda: image.mscoco_generator(
             FLAGS.tmp_dir,
@@ -207,20 +176,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
             40000,
             vocab_filename="tokens.vocab.%d" % 2**13,
             vocab_size=2**13)),
-    "image_mscoco_tokens_32k_tune": (
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15),
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            10000,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15)),
     "image_mscoco_tokens_32k_test": (
         lambda: image.mscoco_generator(
             FLAGS.tmp_dir,
@@ -308,8 +263,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
 
 # pylint: enable=g-long-lambda
 
-UNSHUFFLED_SUFFIX = "-unshuffled"
-
 
 def set_random_seed():
   """Set the random seed from flag everywhere."""
@@ -322,13 +275,15 @@ def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
 
   # Calculate the list of problems to generate.
-  problems = list(sorted(_SUPPORTED_PROBLEM_GENERATORS))
+  problems = sorted(
+      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
   if FLAGS.problem and FLAGS.problem[-1] == "*":
     problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
   elif FLAGS.problem:
     problems = [p for p in problems if p == FLAGS.problem]
   else:
     problems = []
+
   # Remove TIMIT if paths are not given.
   if not FLAGS.timit_paths:
     problems = [p for p in problems if "timit" not in p]
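
With this change the candidate list merges the hard-coded dict keys with everything in the registry, and a trailing `*` in `--problem` acts as a prefix wildcard. A worked example of the selection semantics (problem names illustrative):

```python
# Mirrors the filtering logic in main() above.
problems = ["algorithmic_shift_decimal40", "wmt_enfr_characters",
            "wmt_ende_bpe32k"]
flag_problem = "wmt_*"
if flag_problem and flag_problem[-1] == "*":
  selected = [p for p in problems if p.startswith(flag_problem[:-1])]
else:
  selected = [p for p in problems if p == flag_problem]
assert selected == ["wmt_enfr_characters", "wmt_ende_bpe32k"]
```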
@@ -340,7 +295,8 @@ def main(_):
     problems = [p for p in problems if "ende_bpe" not in p]
 
   if not problems:
-    problems_str = "\n  * ".join(sorted(_SUPPORTED_PROBLEM_GENERATORS))
+    problems_str = "\n  * ".join(
+        sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()))
     error_msg = ("You must specify one of the supported problems to "
                  "generate data for:\n  * " + problems_str + "\n")
     error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
@@ -357,40 +313,50 @@ def main(_):
   for problem in problems:
     set_random_seed()
 
-    training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
-
-    if isinstance(dev_gen, int):
-      # The dev set and test sets are generated as extra shards using the
-      # training generator. The integer specifies the number of training
-      # shards. FLAGS.num_shards is ignored.
-      num_training_shards = dev_gen
-      tf.logging.info("Generating data for %s.", problem)
-      all_output_files = generator_utils.combined_data_filenames(
-          problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_training_shards)
-      generator_utils.generate_files(
-          training_gen(), all_output_files, FLAGS.max_cases)
+    if problem in _SUPPORTED_PROBLEM_GENERATORS:
+      generate_data_for_problem(problem)
     else:
-      # usual case - train data and dev data are generated using separate
-      # generators.
-      tf.logging.info("Generating training data for %s.", problem)
-      train_output_files = generator_utils.train_data_filenames(
-          problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, FLAGS.num_shards)
-      generator_utils.generate_files(
-          training_gen(), train_output_files, FLAGS.max_cases)
-      tf.logging.info("Generating development data for %s.", problem)
-      dev_shards = 10 if "coco" in problem else 1
-      dev_output_files = generator_utils.dev_data_filenames(
-          problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards)
-      generator_utils.generate_files(dev_gen(), dev_output_files)
-      all_output_files = train_output_files + dev_output_files
+      generate_data_for_registered_problem(problem)
+
+
+def generate_data_for_problem(problem):
+  """Generate data for a problem in _SUPPORTED_PROBLEM_GENERATORS."""
+  training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
+
+  if isinstance(dev_gen, int):
+    # The dev set and test sets are generated as extra shards using the
+    # training generator. The integer specifies the number of training
+    # shards. FLAGS.num_shards is ignored.
+    num_training_shards = dev_gen
+    tf.logging.info("Generating data for %s.", problem)
+    all_output_files = generator_utils.combined_data_filenames(
+        problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
+        num_training_shards)
+    generator_utils.generate_files(training_gen(), all_output_files,
+                                   FLAGS.max_cases)
+  else:
+    # usual case - train data and dev data are generated using separate
+    # generators.
+    tf.logging.info("Generating training data for %s.", problem)
+    train_output_files = generator_utils.train_data_filenames(
+        problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
+        FLAGS.num_shards)
+    generator_utils.generate_files(training_gen(), train_output_files,
+                                   FLAGS.max_cases)
+    tf.logging.info("Generating development data for %s.", problem)
+    dev_shards = 10 if "coco" in problem else 1
+    dev_output_files = generator_utils.dev_data_filenames(
+        problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards)
+    generator_utils.generate_files(dev_gen(), dev_output_files)
+    all_output_files = train_output_files + dev_output_files
+
+  tf.logging.info("Shuffling data...")
+  generator_utils.shuffle_dataset(all_output_files)
+
 
-    tf.logging.info("Shuffling data...")
-    for fname in all_output_files:
-      records = generator_utils.read_records(fname)
-      random.shuffle(records)
-      out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
-      generator_utils.write_records(records, out_fname)
-      tf.gfile.Remove(fname)
+def generate_data_for_registered_problem(problem_name):
+  problem = registry.problem(problem_name)
+  problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)
 
 
 if __name__ == "__main__":
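
The inline shuffle loop deleted above presumably moved into `generator_utils.shuffle_dataset`, alongside the `UNSHUFFLED_SUFFIX` constant (call sites now reference `generator_utils.UNSHUFFLED_SUFFIX`). A sketch consistent with the deleted code, assuming `read_records` and `write_records` keep their existing signatures and live in the same module:

```python
import random

import tensorflow as tf

UNSHUFFLED_SUFFIX = "-unshuffled"


def shuffle_dataset(filenames):
  """Shuffle each unshuffled shard and remove the original file."""
  for fname in filenames:
    records = read_records(fname)      # load all serialized examples
    random.shuffle(records)
    out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
    write_records(records, out_fname)  # write the shuffled shard
    tf.gfile.Remove(fname)             # drop the unshuffled original
```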