Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8de0584

Browse files
Ryan SepassiCopybara-Service
authored andcommitted
Update trax_test to test_train_eval_predict
PiperOrigin-RevId: 237274100
1 parent b448041 commit 8de0584

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

tensor2tensor/trax/trax_test.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import numpy as np
2727

28-
from tensor2tensor.trax import inputs
28+
from tensor2tensor.trax import inputs as inputs_lib
2929
from tensor2tensor.trax import models
3030
from tensor2tensor.trax import trax
3131

@@ -43,7 +43,7 @@ def input_stream():
4343
yield (np.random.rand(*([batch_size] + list(input_shape))),
4444
np.random.randint(num_classes, size=batch_size))
4545

46-
return inputs.Inputs(
46+
return inputs_lib.Inputs(
4747
train_stream=input_stream,
4848
eval_stream=input_stream,
4949
input_shape=input_shape)
@@ -57,32 +57,37 @@ def tmp_dir(self):
5757
yield tmp
5858
gfile.rmtree(tmp)
5959

60-
@property
61-
def train_args(self):
62-
num_classes = 4
63-
return dict(
64-
model=functools.partial(models.MLP,
60+
def test_train_eval_predict(self):
61+
with self.tmp_dir() as output_dir:
62+
# Prepare model and inputs
63+
num_classes = 4
64+
train_steps = 2
65+
eval_steps = 2
66+
model = functools.partial(models.MLP,
6567
hidden_size=16,
66-
num_output_classes=num_classes),
67-
inputs=lambda: test_inputs(num_classes),
68-
train_steps=3,
69-
eval_steps=2)
68+
num_output_classes=num_classes)
69+
inputs = lambda: test_inputs(num_classes)
7070

71-
def _test_train(self, train_args):
72-
with self.tmp_dir() as output_dir:
73-
state = trax.train(output_dir, **train_args)
71+
# Train and evaluate
72+
state = trax.train(output_dir,
73+
model=model,
74+
inputs=inputs,
75+
train_steps=train_steps,
76+
eval_steps=eval_steps)
7477

7578
# Assert total train steps
76-
self.assertEqual(train_args["train_steps"], state.step)
79+
self.assertEqual(train_steps, state.step)
7780

78-
# Assert 2 epochs ran
81+
# Assert 2 evaluations ran
7982
train_acc = state.history.get("train", "metrics/accuracy")
8083
eval_acc = state.history.get("eval", "metrics/accuracy")
8184
self.assertEqual(len(train_acc), len(eval_acc))
8285
self.assertEqual(2, len(eval_acc))
8386

84-
def test_train(self):
85-
self._test_train(self.train_args)
87+
# Predict with final params
88+
_, predict_fun = model()
89+
inputs = inputs().train_stream()
90+
predict_fun(state.params, next(inputs)[0])
8691

8792

8893
if __name__ == "__main__":

0 commit comments

Comments
 (0)