@@ -38,6 +38,7 @@

from tensor2tensor import problems as problems_lib  # pylint: disable=unused-import
from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.envs import env_problem_utils
from tensor2tensor.utils import registry
from tensor2tensor.utils import usr_dir

@@ -54,7 +55,6 @@
# Importing here to prevent pylint from ungrouped-imports warning.
import tensorflow as tf  # pylint: disable=g-import-not-at-top

-
flags = tf.flags
FLAGS = flags.FLAGS

@@ -65,10 +65,18 @@
                    "The name of the problem to generate data for.")
flags.DEFINE_string("exclude_problems", "",
                    "Comma-separated list of problems to exclude.")
-flags.DEFINE_integer("num_shards", 0, "How many shards to use. Ignored for "
-                     "registered Problems.")
+flags.DEFINE_integer(
+    "num_shards", 0, "How many shards to use. Ignored for "
+    "registered Problems.")
flags.DEFINE_integer("max_cases", 0,
                     "Maximum number of cases to generate (unbounded if 0).")
+flags.DEFINE_integer(
+    "env_problem_max_env_steps", 0,
+    "Maximum number of steps to take for environment-based problems. "
+    "Actions are chosen randomly.")
+flags.DEFINE_integer(
+    "env_problem_batch_size", 0,
+    "Number of environments to simulate for environment-based problems.")
flags.DEFINE_bool("only_list", False,
                  "If true, we only list the problems that will be generated.")
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
@@ -78,58 +86,66 @@
flags.DEFINE_integer(
    "num_concurrent_processes", None,
    "Applies only to problems for which multiprocess_generate=True.")
-flags.DEFINE_string("t2t_usr_dir", "",
-                    "Path to a Python module that will be imported. The "
-                    "__init__.py file should include the necessary imports. "
-                    "The imported files should contain registrations, "
-                    "e.g. @registry.register_problem calls, that will then be "
-                    "available to t2t-datagen.")
+flags.DEFINE_string(
+    "t2t_usr_dir", "", "Path to a Python module that will be imported. The "
+    "__init__.py file should include the necessary imports. "
+    "The imported files should contain registrations, "
+    "e.g. @registry.register_problem calls, that will then be "
+    "available to t2t-datagen.")

# Mapping from problems that we can generate data for to their generators.
# pylint: disable=g-long-lambda
_SUPPORTED_PROBLEM_GENERATORS = {
-    "algorithmic_algebra_inverse": (
-        lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
-        lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000),
-        lambda: None),  # test set
-    "parsing_english_ptb8k": (
-        lambda: wsj_parsing.parsing_token_generator(
+    "algorithmic_algebra_inverse":
+        (lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
+         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000),
+         lambda: None),  # test set
+    "parsing_english_ptb8k":
+        (lambda: wsj_parsing.parsing_token_generator(
            FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13, 2**9),
-        lambda: wsj_parsing.parsing_token_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13, 2**9),
-        lambda: None),  # test set
-    "parsing_english_ptb16k": (
-        lambda: wsj_parsing.parsing_token_generator(
+         lambda: wsj_parsing.parsing_token_generator(
+             FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13, 2**9),
+         lambda: None),  # test set
+    "parsing_english_ptb16k":
+        (lambda: wsj_parsing.parsing_token_generator(
            FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9),
-        lambda: wsj_parsing.parsing_token_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9),
-        lambda: None),  # test set
-    "inference_snli32k": (
-        lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
-        lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
-        lambda: None),  # test set
-    "audio_timit_characters_test": (
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 1718),
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 626),
-        lambda: None),  # test set
-    "audio_timit_tokens_8k_test": (
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 1718,
-            vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13),
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
-            vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13),
-        lambda: None),  # test set
-    "audio_timit_tokens_32k_test": (
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 1718,
-            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
-            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
-        lambda: None),  # test set
+         lambda: wsj_parsing.parsing_token_generator(
+             FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9),
+         lambda: None),  # test set
+    "inference_snli32k":
+        (lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
+         lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
+         lambda: None),  # test set
+    "audio_timit_characters_test": (lambda: audio.timit_generator(
+        FLAGS.data_dir, FLAGS.tmp_dir, True, 1718
+    ), lambda: audio.timit_generator(FLAGS.data_dir, FLAGS.tmp_dir, False, 626),
+     lambda: None),  # test set
+    "audio_timit_tokens_8k_test": (lambda: audio.timit_generator(
+        FLAGS.data_dir,
+        FLAGS.tmp_dir,
+        True,
+        1718,
+        vocab_filename="vocab.endefr.%d" % 2**13,
+        vocab_size=2**13), lambda: audio.timit_generator(
+            FLAGS.data_dir,
+            FLAGS.tmp_dir,
+            False,
+            626,
+            vocab_filename="vocab.endefr.%d" % 2**13,
+            vocab_size=2**13), lambda: None),  # test set
+    "audio_timit_tokens_32k_test": (lambda: audio.timit_generator(
+        FLAGS.data_dir,
+        FLAGS.tmp_dir,
+        True,
+        1718,
+        vocab_filename="vocab.endefr.%d" % 2**15,
+        vocab_size=2**15), lambda: audio.timit_generator(
+            FLAGS.data_dir,
+            FLAGS.tmp_dir,
+            False,
+            626,
+            vocab_filename="vocab.endefr.%d" % 2**15,
+            vocab_size=2**15), lambda: None),  # test set
}

# pylint: enable=g-long-lambda
@@ -147,7 +163,8 @@ def main(_):

  # Calculate the list of problems to generate.
  problems = sorted(
-      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_base_problems())
+      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_base_problems() +
+      registry.list_env_problems())
  for exclude in FLAGS.exclude_problems.split(","):
    if exclude:
      problems = [p for p in problems if exclude not in p]
@@ -169,8 +186,9 @@ def main(_):

  if not problems:
    problems_str = "\n  * ".join(
-        sorted(list(_SUPPORTED_PROBLEM_GENERATORS) +
-               registry.list_base_problems()))
+        sorted(
+            list(_SUPPORTED_PROBLEM_GENERATORS) +
+            registry.list_base_problems() + registry.list_env_problems()))
    error_msg = ("You must specify one of the supported problems to "
                 "generate data for:\n  * " + problems_str + "\n")
    error_msg += ("TIMIT and parsing need data_sets specified with "
@@ -179,24 +197,28 @@ def main(_):

  if not FLAGS.data_dir:
    FLAGS.data_dir = tempfile.gettempdir()
-    tf.logging.warning("It is strongly recommended to specify --data_dir. "
-                       "Data will be written to default data_dir=%s.",
-                       FLAGS.data_dir)
+    tf.logging.warning(
+        "It is strongly recommended to specify --data_dir. "
+        "Data will be written to default data_dir=%s.", FLAGS.data_dir)
  FLAGS.data_dir = os.path.expanduser(FLAGS.data_dir)
  tf.gfile.MakeDirs(FLAGS.data_dir)

-  tf.logging.info("Generating problems:\n%s"
-                  % registry.display_list_by_prefix(problems,
-                                                    starting_spaces=4))
+  tf.logging.info("Generating problems:\n%s" %
+                  registry.display_list_by_prefix(problems, starting_spaces=4))
  if FLAGS.only_list:
    return
  for problem in problems:
    set_random_seed()

    if problem in _SUPPORTED_PROBLEM_GENERATORS:
      generate_data_for_problem(problem)
-    else:
+    elif problem in registry.list_base_problems():
      generate_data_for_registered_problem(problem)
+    elif problem in registry.list_env_problems():
+      generate_data_for_env_problem(problem)
+    else:
+      tf.logging.error("Problem %s is not a supported problem for datagen.",
+                       problem)


def generate_data_for_problem(problem):
@@ -235,6 +257,24 @@ def generate_data_in_process(arg):
  problem.generate_data(data_dir, tmp_dir, task_id)


+def generate_data_for_env_problem(problem_name):
+  """Generate data for `EnvProblem`s."""
+  assert FLAGS.env_problem_max_env_steps > 0, ("--env_problem_max_env_steps "
+                                               "should be greater than zero")
+  assert FLAGS.env_problem_batch_size > 0, ("--env_problem_batch_size should be"
+                                            " greater than zero")
+  problem = registry.env_problem(problem_name)
+  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
+  data_dir = os.path.expanduser(FLAGS.data_dir)
+  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
+  # TODO(msaffar): Handle large values for env_problem_batch_size where we
+  # cannot create that many environments within the same process.
+  problem.initialize(batch_size=FLAGS.env_problem_batch_size)
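+  # Roll out the batched environments with randomly sampled actions; the
+  # trajectories collected here are what generate_data serializes below.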
+  env_problem_utils.play_env_problem_randomly(
+      problem, num_steps=FLAGS.env_problem_max_env_steps)
+  problem.generate_data(data_dir=data_dir, tmp_dir=tmp_dir, task_id=task_id)
+
+
def generate_data_for_registered_problem(problem_name):
  """Generate data for a registered problem."""
  tf.logging.info("Generating data for %s.", problem_name)
@@ -260,6 +300,7 @@ def generate_data_for_registered_problem(problem_name):
  else:
    problem.generate_data(data_dir, tmp_dir, task_id)

+
if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
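
For reference, the env-problem code path added in this diff reduces to the following minimal sketch. The problem name and directory paths here are hypothetical placeholders; the calls themselves mirror the lines introduced above.

    from tensor2tensor.envs import env_problem_utils
    from tensor2tensor.utils import registry

    # Hypothetical EnvProblem name; any name returned by
    # registry.list_env_problems() would work here.
    problem = registry.env_problem("some_env_problem")
    # Number of environment copies simulated in parallel
    # (--env_problem_batch_size).
    problem.initialize(batch_size=4)
    # Take random actions for --env_problem_max_env_steps steps, then write
    # out the collected trajectories.
    env_problem_utils.play_env_problem_randomly(problem, num_steps=100)
    problem.generate_data(data_dir="/tmp/t2t_data", tmp_dir="/tmp",
                          task_id=None)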