
from absl.testing import absltest
from absl.testing import parameterized

+ import jax
+ import numpy as onp
+
+ from tensor2tensor.trax import backend
from tensor2tensor.trax import layers as tl
+ from tensor2tensor.trax.backend import numpy as np
from tensor2tensor.trax.models.research import reformer


+ class PoisonOnRNGMismatchAttention(tl.BaseCausalAttention):
+   """Fills gradients with NaNs if reverse rng does not match forward rng."""
+
+   # pylint: disable=protected-access
+   def forward_and_backward(self, inputs, ct, rng=None, **kwargs):
+     assert backend.get_name() == 'jax', (
+         'JAX backend is required to use forward_and_backward.')
+
+     if ct is not None and tl.Layer._STASH_OUT is not None:
+       recovered_rng = tl.Layer._STASH_OUT.pop(self)
+       is_same = (rng[0] == recovered_rng[0]) & (rng[1] == recovered_rng[1])
+       is_same = is_same.astype(np.float32)
+       # Divides by zero if rngs are not the same, which results in NaNs.
+       inputs = (inputs[0] / is_same, inputs[1] / is_same, inputs[2] / is_same)
+
+     def _do_forward(x):  # pylint: disable=invalid-name
+       res, _ = self.forward(x, rng=rng, **kwargs)
+       return res
+     output, vjpfun = jax.vjp(_do_forward, inputs)
+     return output, vjpfun(ct)[0]
+
+   def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
+     if tl.Layer._STASH_IN is not None:
+       tl.Layer._STASH_IN[self] = rng
+     return inputs[2], state
+   # pylint: enable=protected-access
+
+
class ReformerTest(parameterized.TestCase):

  def test_reformer_lm_forward_shape(self):
@@ -39,6 +72,33 @@ def test_reformer_lm_forward_shape(self):
        model, tuple(input_shape), integer_inputs=True)
    self.assertEqual(((1, 8, 16), (1, 8, 16)), final_shape)

+   def test_reformer_rng_consistency(self):
+     with backend.use_backend('jax'):
+       vocab_size = 16
+       batch_size = 1
+       input_shape = ((batch_size, 8), (batch_size, 8))
+       model = reformer.ReformerLM(
+           vocab_size, d_model=32, d_ff=64,
+           d_attention_key=16, d_attention_value=16, n_layers=1, n_heads=2,
+           max_len=16, n_chunks=2, n_attention_chunks=1, mode='train',
+           attention_type=PoisonOnRNGMismatchAttention)
+
+       rng = backend.random.get_prng(0)
+       params, state = model.initialize_once(
+           input_shape, (np.int32, np.int32), rng)
+
+       def dummy_loss_fn(params):
+         inputs = (np.zeros(input_shape[0], dtype=np.int32),) * 2
+         output = model(inputs, params=params, state=state, rng=rng)
+         dummy_loss = backend.numpy.sum(output[0])
+         return dummy_loss
+
+       grad_fn = backend.grad(dummy_loss_fn)
+       grads = grad_fn(params)
+       # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
+       for grad in jax.tree_util.tree_leaves(grads):
+         assert onp.all(onp.isfinite(grad))
+

if __name__ == '__main__':
  absltest.main()
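The gradient-poisoning trick used by PoisonOnRNGMismatchAttention can be reproduced with stock JAX alone: divide by a 0/1 float derived from comparing the two rng keys, so a mismatch turns the gradients non-finite, then check finiteness over the gradient pytree. The snippet below is an illustrative sketch, not part of the diff above; it assumes nothing beyond jax and jax.numpy, and the name poison_if_mismatch is made up for this example.

# Standalone sketch of the NaN-poisoning idea, using only stock JAX.
import jax
import jax.numpy as jnp


def poison_if_mismatch(x, rng, recovered_rng):
  # 1.0 if the two PRNG keys match element-wise, 0.0 otherwise.
  is_same = (rng[0] == recovered_rng[0]) & (rng[1] == recovered_rng[1])
  is_same = is_same.astype(jnp.float32)
  # Dividing by 0.0 yields inf/NaN, which propagates into the gradients.
  return jnp.sum(x / is_same)


rng = jax.random.PRNGKey(0)
good = jax.grad(poison_if_mismatch)(jnp.ones(3), rng, rng)
bad = jax.grad(poison_if_mismatch)(jnp.ones(3), rng, jax.random.PRNGKey(1))
assert all(jnp.all(jnp.isfinite(g)) for g in jax.tree_util.tree_leaves(good))
assert not all(jnp.all(jnp.isfinite(g)) for g in jax.tree_util.tree_leaves(bad))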