1
1
from functools import singledispatch
2
- from typing import Tuple
2
+ from typing import Optional , Tuple
3
3
4
4
import aesara .tensor as at
5
5
import aesara .tensor .random .basic as arb
6
6
import numpy as np
7
- from aesara import scan , shared
7
+ from aesara import scan
8
8
from aesara .compile .builders import OpFromGraph
9
9
from aesara .graph .op import Op
10
10
from aesara .raise_op import CheckAndRaise
11
11
from aesara .scan import until
12
+ from aesara .tensor .random import RandomStream
12
13
from aesara .tensor .random .op import RandomVariable
13
14
from aesara .tensor .var import TensorConstant , TensorVariable
14
15
@@ -68,7 +69,11 @@ def __str__(self):
68
69
69
70
70
71
def truncate (
71
- rv : TensorVariable , lower = None , upper = None , max_n_steps : int = 10_000 , rng = None
72
+ rv : TensorVariable ,
73
+ lower = None ,
74
+ upper = None ,
75
+ max_n_steps : int = 10_000 ,
76
+ srng : Optional [RandomStream ] = None ,
72
77
) -> Tuple [TensorVariable , Tuple [TensorVariable , TensorVariable ]]:
73
78
"""Truncate a univariate `RandomVariable` between `lower` and `upper`.
74
79
@@ -99,13 +104,13 @@ def truncate(
99
104
lower = at .as_tensor_variable (lower ) if lower is not None else at .constant (- np .inf )
100
105
upper = at .as_tensor_variable (upper ) if upper is not None else at .constant (np .inf )
101
106
102
- if rng is None :
103
- rng = shared ( np . random . RandomState (), borrow = True )
107
+ if srng is None :
108
+ srng = RandomStream ( )
104
109
105
110
# Try to use specialized Op
106
111
try :
107
112
truncated_rv , updates = _truncated (
108
- rv .owner .op , lower , upper , rng , * rv .owner .inputs [1 :]
113
+ rv .owner .op , lower , upper , srng , * rv .owner .inputs [1 :]
109
114
)
110
115
return truncated_rv , updates
111
116
except NotImplementedError :
@@ -116,8 +121,8 @@ def truncate(
116
121
# though it would not be necessary for the icdf OpFromGraph
117
122
graph_inputs = [* rv .owner .inputs [1 :], lower , upper ]
118
123
graph_inputs_ = [inp .type () for inp in graph_inputs ]
119
- * rv_inputs_ , lower_ , upper_ = graph_inputs_
120
- rv_ = rv .owner .op . make_node ( rng , * rv_inputs_ ). default_output ( )
124
+ size_ , dtype_ , * rv_inputs_ , lower_ , upper_ = graph_inputs_
125
+ rv_ = srng . gen ( rv .owner .op , * rv_inputs_ , size = size_ , dtype = dtype_ )
121
126
122
127
# Try to use inverted cdf sampling
123
128
try :
@@ -126,11 +131,10 @@ def truncate(
126
131
lower_value = lower_ - 1 if rv .owner .op .dtype .startswith ("int" ) else lower_
127
132
cdf_lower_ = at .exp (logcdf (rv_ , lower_value ))
128
133
cdf_upper_ = at .exp (logcdf (rv_ , upper_ ))
129
- uniform_ = at . random .uniform (
134
+ uniform_ = srng .uniform (
130
135
cdf_lower_ ,
131
136
cdf_upper_ ,
132
- rng = rng ,
133
- size = rv_inputs_ [0 ],
137
+ size = size_ ,
134
138
)
135
139
truncated_rv_ = icdf (rv_ , uniform_ )
136
140
truncated_rv = TruncatedRV (
@@ -146,27 +150,23 @@ def truncate(
146
150
147
151
# Fallback to rejection sampling
148
152
# TODO: Handle potential broadcast by lower / upper
149
- def loop_fn (truncated_rv , reject_draws , lower , upper , rng , * rv_inputs ):
150
- next_rng , new_truncated_rv = rv .owner .op . make_node ( rng , * rv_inputs ). outputs
153
+ def loop_fn (truncated_rv , reject_draws , lower , upper , size , dtype , * rv_inputs ):
154
+ new_truncated_rv = srng . gen ( rv .owner .op , * rv_inputs , size = size , dtype = dtype ) # type: ignore
151
155
truncated_rv = at .set_subtensor (
152
156
truncated_rv [reject_draws ],
153
157
new_truncated_rv [reject_draws ],
154
158
)
155
159
reject_draws = at .or_ ((truncated_rv < lower ), (truncated_rv > upper ))
156
160
157
- return (
158
- (truncated_rv , reject_draws ),
159
- [(rng , next_rng )],
160
- until (~ at .any (reject_draws )),
161
- )
161
+ return (truncated_rv , reject_draws ), until (~ at .any (reject_draws ))
162
162
163
163
(truncated_rv_ , reject_draws_ ), updates = scan (
164
164
loop_fn ,
165
165
outputs_info = [
166
166
at .zeros_like (rv_ ),
167
167
at .ones_like (rv_ , dtype = bool ),
168
168
],
169
- non_sequences = [lower_ , upper_ , rng , * rv_inputs_ ],
169
+ non_sequences = [lower_ , upper_ , size_ , dtype_ , * rv_inputs_ ],
170
170
n_steps = max_n_steps ,
171
171
strict = True ,
172
172
)
@@ -180,18 +180,28 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
180
180
truncated_rv = TruncatedRV (
181
181
base_rv_op = rv .owner .op ,
182
182
inputs = graph_inputs_ ,
183
- outputs = [truncated_rv_ , tuple (updates .values ())[0 ]],
183
+ # This will fail with `n_steps==1`, because in that case `Scan` won't return any updates
184
+ outputs = [truncated_rv_ , rv_ .owner .outputs [0 ], tuple (updates .values ())[0 ]],
184
185
inline = True ,
185
186
)(* graph_inputs )
186
- updates = {truncated_rv .owner .inputs [- 1 ]: truncated_rv .owner .outputs [- 1 ]}
187
+ # TODO: Is the order of multiple shared variables determnistic?
188
+ assert truncated_rv .owner .inputs [- 2 ] is rv_ .owner .inputs [0 ]
189
+ updates = {
190
+ truncated_rv .owner .inputs [- 2 ]: truncated_rv .owner .outputs [- 2 ],
191
+ truncated_rv .owner .inputs [- 1 ]: truncated_rv .owner .outputs [- 1 ],
192
+ }
187
193
return truncated_rv , updates
188
194
189
195
190
196
@_logprob .register (TruncatedRV )
191
197
def truncated_logprob (op , values , * inputs , ** kwargs ):
192
198
(value ,) = values
193
199
194
- * rv_inputs , lower_bound , upper_bound , rng = inputs
200
+ # Rejection sample graph has two rngs
201
+ if len (op .shared_inputs ) == 2 :
202
+ * rv_inputs , lower_bound , upper_bound , _ , rng = inputs
203
+ else :
204
+ * rv_inputs , lower_bound , upper_bound , rng = inputs
195
205
rv_inputs = [rng , * rv_inputs ]
196
206
197
207
base_rv_op = op .base_rv_op
@@ -242,11 +252,11 @@ def truncated_logprob(op, values, *inputs, **kwargs):
242
252
243
253
244
254
@_truncated .register (arb .UniformRV )
245
- def uniform_truncated (op , lower , upper , rng , size , dtype , lower_orig , upper_orig ):
246
- truncated_uniform = at .random .uniform (
255
+ def uniform_truncated (op , lower , upper , srng , size , dtype , lower_orig , upper_orig ):
256
+ truncated_uniform = srng .gen (
257
+ op ,
247
258
at .max ((lower_orig , lower )),
248
259
at .min ((upper_orig , upper )),
249
- rng = rng ,
250
260
size = size ,
251
261
dtype = dtype ,
252
262
)
0 commit comments