4
4
import aesara .tensor as at
5
5
import aesara .tensor .random .basic as arb
6
6
import numpy as np
7
- from aesara import scan , shared
7
+ from aesara import scan
8
8
from aesara .compile .builders import OpFromGraph
9
9
from aesara .graph .basic import Node
10
10
from aesara .graph .fg import FunctionGraph
15
15
from aesara .scalar .basic import clip as scalar_clip
16
16
from aesara .scan import until
17
17
from aesara .tensor .elemwise import Elemwise
18
+ from aesara .tensor .random import RandomStream
18
19
from aesara .tensor .random .op import RandomVariable
19
20
from aesara .tensor .var import TensorConstant , TensorVariable
20
21
@@ -188,7 +189,11 @@ def __str__(self):
188
189
189
190
190
191
def truncate (
191
- rv : TensorVariable , lower = None , upper = None , max_n_steps : int = 10_000 , rng = None
192
+ rv : TensorVariable ,
193
+ lower = None ,
194
+ upper = None ,
195
+ max_n_steps : int = 10_000 ,
196
+ srng : Optional [RandomStream ] = None ,
192
197
) -> Tuple [TensorVariable , Tuple [TensorVariable , TensorVariable ]]:
193
198
"""Truncate a univariate `RandomVariable` between lower and upper.
194
199
@@ -218,13 +223,13 @@ def truncate(
218
223
lower = at .as_tensor_variable (lower ) if lower is not None else at .constant (- np .inf )
219
224
upper = at .as_tensor_variable (upper ) if upper is not None else at .constant (np .inf )
220
225
221
- if rng is None :
222
- rng = shared ( np . random . RandomState (), borrow = True )
226
+ if srng is None :
227
+ srng = RandomStream ( )
223
228
224
229
# Try to use specialized Op
225
230
try :
226
231
truncated_rv , updates = _truncated (
227
- rv .owner .op , lower , upper , rng , * rv .owner .inputs [1 :]
232
+ rv .owner .op , lower , upper , srng , * rv .owner .inputs [1 :]
228
233
)
229
234
return truncated_rv , updates
230
235
except NotImplementedError :
@@ -235,8 +240,8 @@ def truncate(
235
240
# though it would not be necessary for the icdf OpFromGraph
236
241
graph_inputs = [* rv .owner .inputs [1 :], lower , upper ]
237
242
graph_inputs_ = [inp .type () for inp in graph_inputs ]
238
- * rv_inputs_ , lower_ , upper_ = graph_inputs_
239
- rv_ = rv .owner .op . make_node ( rng , * rv_inputs_ ). default_output ( )
243
+ size_ , dtype_ , * rv_inputs_ , lower_ , upper_ = graph_inputs_
244
+ rv_ = srng . gen ( rv .owner .op , * rv_inputs_ , size = size_ , dtype = dtype_ )
240
245
241
246
# Try to use inverted cdf sampling
242
247
try :
@@ -245,11 +250,10 @@ def truncate(
245
250
lower_value = lower_ - 1 if rv .owner .op .dtype .startswith ("int" ) else lower_
246
251
cdf_lower_ = at .exp (logcdf (rv_ , lower_value ))
247
252
cdf_upper_ = at .exp (logcdf (rv_ , upper_ ))
248
- uniform_ = at . random .uniform (
253
+ uniform_ = srng .uniform (
249
254
cdf_lower_ ,
250
255
cdf_upper_ ,
251
- rng = rng ,
252
- size = rv_inputs_ [0 ],
256
+ size = size_ ,
253
257
)
254
258
truncated_rv_ = icdf (rv_ , uniform_ )
255
259
truncated_rv = TruncatedRV (
@@ -265,27 +269,23 @@ def truncate(
265
269
266
270
# Fallback to rejection sampling
267
271
# TODO: Handle potential broadcast by lower / upper
268
- def loop_fn (truncated_rv , reject_draws , lower , upper , rng , * rv_inputs ):
269
- next_rng , new_truncated_rv = rv .owner .op . make_node ( rng , * rv_inputs ). outputs
272
+ def loop_fn (truncated_rv , reject_draws , lower , upper , size , dtype , * rv_inputs ):
273
+ new_truncated_rv = srng . gen ( rv .owner .op , * rv_inputs , size = size , dtype = dtype ) # type: ignore
270
274
truncated_rv = at .set_subtensor (
271
275
truncated_rv [reject_draws ],
272
276
new_truncated_rv [reject_draws ],
273
277
)
274
278
reject_draws = at .or_ ((truncated_rv < lower ), (truncated_rv > upper ))
275
279
276
- return (
277
- (truncated_rv , reject_draws ),
278
- [(rng , next_rng )],
279
- until (~ at .any (reject_draws )),
280
- )
280
+ return (truncated_rv , reject_draws ), until (~ at .any (reject_draws ))
281
281
282
282
(truncated_rv_ , reject_draws_ ), updates = scan (
283
283
loop_fn ,
284
284
outputs_info = [
285
285
at .zeros_like (rv_ ),
286
286
at .ones_like (rv_ , dtype = bool ),
287
287
],
288
- non_sequences = [lower_ , upper_ , rng , * rv_inputs_ ],
288
+ non_sequences = [lower_ , upper_ , size_ , dtype_ , * rv_inputs_ ],
289
289
n_steps = max_n_steps ,
290
290
strict = True ,
291
291
)
@@ -299,18 +299,28 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
299
299
truncated_rv = TruncatedRV (
300
300
base_rv_op = rv .owner .op ,
301
301
inputs = graph_inputs_ ,
302
- outputs = [truncated_rv_ , tuple (updates .values ())[0 ]],
302
+ # This will fail with `n_steps==1`, because in that case `Scan` won't return any updates
303
+ outputs = [truncated_rv_ , rv_ .owner .outputs [0 ], tuple (updates .values ())[0 ]],
303
304
inline = True ,
304
305
)(* graph_inputs )
305
- updates = {truncated_rv .owner .inputs [- 1 ]: truncated_rv .owner .outputs [- 1 ]}
306
+ # TODO: Is the order of multiple shared variables determnistic?
307
+ assert truncated_rv .owner .inputs [- 2 ] is rv_ .owner .inputs [0 ]
308
+ updates = {
309
+ truncated_rv .owner .inputs [- 2 ]: truncated_rv .owner .outputs [- 2 ],
310
+ truncated_rv .owner .inputs [- 1 ]: truncated_rv .owner .outputs [- 1 ],
311
+ }
306
312
return truncated_rv , updates
307
313
308
314
309
315
@_logprob .register (TruncatedRV )
310
316
def truncated_logprob (op , values , * inputs , ** kwargs ):
311
317
(value ,) = values
312
318
313
- * rv_inputs , lower_bound , upper_bound , rng = inputs
319
+ # Rejection sample graph has two rngs
320
+ if len (op .shared_inputs ) == 2 :
321
+ * rv_inputs , lower_bound , upper_bound , _ , rng = inputs
322
+ else :
323
+ * rv_inputs , lower_bound , upper_bound , rng = inputs
314
324
rv_inputs = [rng , * rv_inputs ]
315
325
316
326
base_rv_op = op .base_rv_op
@@ -361,11 +371,11 @@ def truncated_logprob(op, values, *inputs, **kwargs):
361
371
362
372
363
373
@_truncated .register (arb .UniformRV )
364
- def uniform_truncated (op , lower , upper , rng , size , dtype , lower_orig , upper_orig ):
365
- truncated_uniform = at .random .uniform (
374
+ def uniform_truncated (op , lower , upper , srng , size , dtype , lower_orig , upper_orig ):
375
+ truncated_uniform = srng .gen (
376
+ op ,
366
377
at .max ((lower_orig , lower )),
367
378
at .min ((upper_orig , upper )),
368
- rng = rng ,
369
379
size = size ,
370
380
dtype = dtype ,
371
381
)
0 commit comments