This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 6671139

T2T Team authored and Copybara-Service committed
Update datagen to work with environment problems. This will run the environment with randomly sampled actions and record the environment state + actions.

Flags:
- env_problem_batch_size: controls how many examples we generate.
- env_problem_max_env_steps: controls how many steps to run in each example.

PiperOrigin-RevId: 237324759
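A hypothetical invocation after this commit might look as follows; the two env_problem_* flags are the ones added here, while the problem name (presumably the snake_case registration of the newly imported TicTacToeEnvProblem) and the directories are illustrative:

  t2t-datagen \
    --problem=tic_tac_toe_env_problem \
    --data_dir=~/t2t_data \
    --tmp_dir=/tmp/t2t_tmp \
    --env_problem_batch_size=8 \
    --env_problem_max_env_steps=100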
1 parent 19c8871 commit 6671139

File tree

3 files changed: +106 -64 lines changed

tensor2tensor/bin/t2t_datagen.py

Lines changed: 100 additions & 59 deletions
@@ -38,6 +38,7 @@
 
 from tensor2tensor import problems as problems_lib  # pylint: disable=unused-import
 from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.envs import env_problem_utils
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import usr_dir
 
@@ -54,7 +55,6 @@
 # Improrting here to prevent pylint from ungrouped-imports warning.
 import tensorflow as tf  # pylint: disable=g-import-not-at-top
 
-
 flags = tf.flags
 FLAGS = flags.FLAGS
 
@@ -65,10 +65,18 @@
                     "The name of the problem to generate data for.")
 flags.DEFINE_string("exclude_problems", "",
                     "Comma-separates list of problems to exclude.")
-flags.DEFINE_integer("num_shards", 0, "How many shards to use. Ignored for "
-                     "registered Problems.")
+flags.DEFINE_integer(
+    "num_shards", 0, "How many shards to use. Ignored for "
+    "registered Problems.")
 flags.DEFINE_integer("max_cases", 0,
                      "Maximum number of cases to generate (unbounded if 0).")
+flags.DEFINE_integer(
+    "env_problem_max_env_steps", 0,
+    "Maximum number of steps to take for environment-based problems. "
+    "Actions are chosen randomly")
+flags.DEFINE_integer(
+    "env_problem_batch_size", 0,
+    "Number of environments to simulate for environment-based problems.")
 flags.DEFINE_bool("only_list", False,
                   "If true, we only list the problems that will be generated.")
 flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
@@ -78,58 +86,66 @@
 flags.DEFINE_integer(
     "num_concurrent_processes", None,
     "Applies only to problems for which multiprocess_generate=True.")
-flags.DEFINE_string("t2t_usr_dir", "",
-                    "Path to a Python module that will be imported. The "
-                    "__init__.py file should include the necessary imports. "
-                    "The imported files should contain registrations, "
-                    "e.g. @registry.register_problem calls, that will then be "
-                    "available to t2t-datagen.")
+flags.DEFINE_string(
+    "t2t_usr_dir", "", "Path to a Python module that will be imported. The "
+    "__init__.py file should include the necessary imports. "
+    "The imported files should contain registrations, "
+    "e.g. @registry.register_problem calls, that will then be "
+    "available to t2t-datagen.")
 
 # Mapping from problems that we can generate data for to their generators.
 # pylint: disable=g-long-lambda
 _SUPPORTED_PROBLEM_GENERATORS = {
-    "algorithmic_algebra_inverse": (
-        lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
-        lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000),
-        lambda: None),  # test set
-    "parsing_english_ptb8k": (
-        lambda: wsj_parsing.parsing_token_generator(
+    "algorithmic_algebra_inverse":
+        (lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
+         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000),
+         lambda: None),  # test set
+    "parsing_english_ptb8k":
+        (lambda: wsj_parsing.parsing_token_generator(
             FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13, 2**9),
-        lambda: wsj_parsing.parsing_token_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13, 2**9),
-        lambda: None),  # test set
-    "parsing_english_ptb16k": (
-        lambda: wsj_parsing.parsing_token_generator(
+         lambda: wsj_parsing.parsing_token_generator(
+             FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13, 2**9),
+         lambda: None),  # test set
+    "parsing_english_ptb16k":
+        (lambda: wsj_parsing.parsing_token_generator(
            FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9),
-        lambda: wsj_parsing.parsing_token_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9),
-        lambda: None),  # test set
-    "inference_snli32k": (
-        lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
-        lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
-        lambda: None),  # test set
-    "audio_timit_characters_test": (
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 1718),
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 626),
-        lambda: None),  # test set
-    "audio_timit_tokens_8k_test": (
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 1718,
-            vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13),
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
-            vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13),
-        lambda: None),  # test set
-    "audio_timit_tokens_32k_test": (
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 1718,
-            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
-            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
-        lambda: None),  # test set
+         lambda: wsj_parsing.parsing_token_generator(
+             FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9),
+         lambda: None),  # test set
+    "inference_snli32k":
+        (lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
+         lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
+         lambda: None),  # test set
+    "audio_timit_characters_test": (lambda: audio.timit_generator(
+        FLAGS.data_dir, FLAGS.tmp_dir, True, 1718
+    ), lambda: audio.timit_generator(FLAGS.data_dir, FLAGS.tmp_dir, False, 626),
+                                    lambda: None),  # test set
+    "audio_timit_tokens_8k_test": (lambda: audio.timit_generator(
+        FLAGS.data_dir,
+        FLAGS.tmp_dir,
+        True,
+        1718,
+        vocab_filename="vocab.endefr.%d" % 2**13,
+        vocab_size=2**13), lambda: audio.timit_generator(
+            FLAGS.data_dir,
+            FLAGS.tmp_dir,
+            False,
+            626,
+            vocab_filename="vocab.endefr.%d" % 2**13,
+            vocab_size=2**13), lambda: None),  # test set
+    "audio_timit_tokens_32k_test": (lambda: audio.timit_generator(
+        FLAGS.data_dir,
+        FLAGS.tmp_dir,
+        True,
+        1718,
+        vocab_filename="vocab.endefr.%d" % 2**15,
+        vocab_size=2**15), lambda: audio.timit_generator(
+            FLAGS.data_dir,
+            FLAGS.tmp_dir,
+            False,
+            626,
+            vocab_filename="vocab.endefr.%d" % 2**15,
+            vocab_size=2**15), lambda: None),  # test set
 }
 
 # pylint: enable=g-long-lambda
@@ -147,7 +163,8 @@ def main(_):
 
   # Calculate the list of problems to generate.
   problems = sorted(
-      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_base_problems())
+      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_base_problems() +
+      registry.list_env_problems())
   for exclude in FLAGS.exclude_problems.split(","):
     if exclude:
       problems = [p for p in problems if exclude not in p]
@@ -169,8 +186,9 @@ def main(_):
 
   if not problems:
     problems_str = "\n * ".join(
-        sorted(list(_SUPPORTED_PROBLEM_GENERATORS) +
-               registry.list_base_problems()))
+        sorted(
+            list(_SUPPORTED_PROBLEM_GENERATORS) +
+            registry.list_base_problems() + registry.list_env_problems()))
     error_msg = ("You must specify one of the supported problems to "
                  "generate data for:\n * " + problems_str + "\n")
     error_msg += ("TIMIT and parsing need data_sets specified with "
@@ -179,24 +197,28 @@ def main(_):
 
   if not FLAGS.data_dir:
     FLAGS.data_dir = tempfile.gettempdir()
-    tf.logging.warning("It is strongly recommended to specify --data_dir. "
-                       "Data will be written to default data_dir=%s.",
-                       FLAGS.data_dir)
+    tf.logging.warning(
+        "It is strongly recommended to specify --data_dir. "
+        "Data will be written to default data_dir=%s.", FLAGS.data_dir)
   FLAGS.data_dir = os.path.expanduser(FLAGS.data_dir)
   tf.gfile.MakeDirs(FLAGS.data_dir)
 
-  tf.logging.info("Generating problems:\n%s"
-                  % registry.display_list_by_prefix(problems,
-                                                    starting_spaces=4))
+  tf.logging.info("Generating problems:\n%s" %
+                  registry.display_list_by_prefix(problems, starting_spaces=4))
   if FLAGS.only_list:
     return
   for problem in problems:
     set_random_seed()
 
     if problem in _SUPPORTED_PROBLEM_GENERATORS:
       generate_data_for_problem(problem)
-    else:
+    elif problem in registry.list_base_problems():
       generate_data_for_registered_problem(problem)
+    elif problem in registry.list_env_problems():
+      generate_data_for_env_problem(problem)
+    else:
+      tf.logging.error("Problem %s is not a supported problem for datagen.",
+                       problem)
 
 
 def generate_data_for_problem(problem):
@@ -235,6 +257,24 @@ def generate_data_in_process(arg):
   problem.generate_data(data_dir, tmp_dir, task_id)
 
 
+def generate_data_for_env_problem(problem_name):
+  """Generate data for `EnvProblem`s."""
+  assert FLAGS.env_problem_max_env_steps > 0, ("--env_problem_max_env_steps "
+                                               "should be greater than zero")
+  assert FLAGS.env_problem_batch_size > 0, ("--env_problem_batch_size should be"
+                                            " greather than zero")
+  problem = registry.env_problem(problem_name)
+  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
+  data_dir = os.path.expanduser(FLAGS.data_dir)
+  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
+  # TODO(msaffar): Handle large values for env_problem_batch_size where we
+  # cannot create that many environments within the same process.
+  problem.initialize(batch_size=FLAGS.env_problem_batch_size)
+  env_problem_utils.play_env_problem_randomly(
+      problem, num_steps=FLAGS.env_problem_max_env_steps)
+  problem.generate_data(data_dir=data_dir, tmp_dir=tmp_dir, task_id=task_id)
+
+
 def generate_data_for_registered_problem(problem_name):
   """Generate data for a registered problem."""
   tf.logging.info("Generating data for %s.", problem_name)
@@ -260,6 +300,7 @@ def generate_data_for_registered_problem(problem_name):
   else:
     problem.generate_data(data_dir, tmp_dir, task_id)
 
+
 if __name__ == "__main__":
   tf.logging.set_verbosity(tf.logging.INFO)
   tf.app.run()
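The new generate_data_for_env_problem above delegates the actual rollouts to env_problem_utils.play_env_problem_randomly, whose implementation is not part of this diff. A minimal sketch of what such random play could look like, assuming the initialized EnvProblem exposes a Gym-style reset()/step(actions) API, a batch_size attribute set by initialize(), and a shared action_space:

def play_env_problem_randomly_sketch(env_problem, num_steps):
  """Sketch only; the real helper is env_problem_utils.play_env_problem_randomly."""
  env_problem.reset()
  for _ in range(num_steps):
    # One independently sampled action per environment in the batch
    # (batch_size and step() semantics are assumptions of this sketch).
    actions = [env_problem.action_space.sample()
               for _ in range(env_problem.batch_size)]
    env_problem.step(actions)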

tensor2tensor/data_generators/all_problems.py

Lines changed: 2 additions & 0 deletions
@@ -91,6 +91,8 @@
     "tensor2tensor.data_generators.wikitext103",
     "tensor2tensor.data_generators.wsj_parsing",
     "tensor2tensor.data_generators.wnli",
+    "tensor2tensor.envs.mujoco_problems",
+    "tensor2tensor.envs.tic_tac_toe_env_problem",
 ]
 ALL_MODULES = list(MODULES)
 
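Importing these two modules is what makes the env problems visible to the registry when all_problems is loaded. As a hypothetical sketch of the other half, here is how such a module would expose a problem to registry.list_env_problems() / registry.env_problem(); the register_env_problem decorator is assumed by analogy with @registry.register_problem, since this commit only shows the accessor side:

from tensor2tensor.envs import env_problem
from tensor2tensor.utils import registry


@registry.register_env_problem  # assumed decorator, by analogy
class MyToyEnvProblem(env_problem.EnvProblem):
  """A toy env problem; the body is elided in this sketch."""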
tensor2tensor/envs/env_problem.py

Lines changed: 4 additions & 5 deletions
@@ -671,14 +671,13 @@ def _generate_time_steps(self, trajectory_list):
       if not processed_reward:
         processed_reward = 0
 
-      if time_step.action:
-        action = gym_spaces_utils.gym_space_encode(self.action_space,
-                                                   time_step.action)
-      else:
+      action = time_step.action
+      if action is None:
         # The last time-step doesn't have action, and this action shouldn't be
         # used, gym's spaces have a `sample` function, so let's just sample an
         # action and use that.
-        action = [self.action_space.sample()]
+        action = self.action_space.sample()
+      action = gym_spaces_utils.gym_space_encode(self.action_space, action)
 
       if six.PY3:
         # py3 complains that, to_example cannot handle np.int64 !
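The reason for replacing the truthiness test with an is None check: in a discrete action space, 0 is a valid action but is falsy in Python, so the old `if time_step.action:` would treat it as missing and never encode it. (Note also that the old fallback wrapped the sampled action in a list while the encode branch did not; the single gym_space_encode call at the end removes that asymmetry.) A minimal illustration of the falsy-zero pitfall:

action = 0  # a valid action in a gym.spaces.Discrete space

if action:  # old-style check: 0 is falsy, so this branch is skipped -- the bug
  print("encode the agent's action")

if action is not None:  # the check used after this commit
  print("encode the agent's action")  # runs for action 0, as intended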
