25
25
26
26
import numpy as np
27
27
28
- from tensor2tensor .trax import inputs
28
+ from tensor2tensor .trax import inputs as inputs_lib
29
29
from tensor2tensor .trax import models
30
30
from tensor2tensor .trax import trax
31
31
@@ -43,7 +43,7 @@ def input_stream():
43
43
yield (np .random .rand (* ([batch_size ] + list (input_shape ))),
44
44
np .random .randint (num_classes , size = batch_size ))
45
45
46
- return inputs .Inputs (
46
+ return inputs_lib .Inputs (
47
47
train_stream = input_stream ,
48
48
eval_stream = input_stream ,
49
49
input_shape = input_shape )
@@ -57,32 +57,37 @@ def tmp_dir(self):
57
57
yield tmp
58
58
gfile .rmtree (tmp )
59
59
60
- @property
61
- def train_args (self ):
62
- num_classes = 4
63
- return dict (
64
- model = functools .partial (models .MLP ,
60
+ def test_train_eval_predict (self ):
61
+ with self .tmp_dir () as output_dir :
62
+ # Prepare model and inputs
63
+ num_classes = 4
64
+ train_steps = 2
65
+ eval_steps = 2
66
+ model = functools .partial (models .MLP ,
65
67
hidden_size = 16 ,
66
- num_output_classes = num_classes ),
67
- inputs = lambda : test_inputs (num_classes ),
68
- train_steps = 3 ,
69
- eval_steps = 2 )
68
+ num_output_classes = num_classes )
69
+ inputs = lambda : test_inputs (num_classes )
70
70
71
- def _test_train (self , train_args ):
72
- with self .tmp_dir () as output_dir :
73
- state = trax .train (output_dir , ** train_args )
71
+ # Train and evaluate
72
+ state = trax .train (output_dir ,
73
+ model = model ,
74
+ inputs = inputs ,
75
+ train_steps = train_steps ,
76
+ eval_steps = eval_steps )
74
77
75
78
# Assert total train steps
76
- self .assertEqual (train_args [ " train_steps" ] , state .step )
79
+ self .assertEqual (train_steps , state .step )
77
80
78
- # Assert 2 epochs ran
81
+ # Assert 2 evaluations ran
79
82
train_acc = state .history .get ("train" , "metrics/accuracy" )
80
83
eval_acc = state .history .get ("eval" , "metrics/accuracy" )
81
84
self .assertEqual (len (train_acc ), len (eval_acc ))
82
85
self .assertEqual (2 , len (eval_acc ))
83
86
84
- def test_train (self ):
85
- self ._test_train (self .train_args )
87
+ # Predict with final params
88
+ _ , predict_fun = model ()
89
+ inputs = inputs ().train_stream ()
90
+ predict_fun (state .params , next (inputs )[0 ])
86
91
87
92
88
93
if __name__ == "__main__" :
0 commit comments