Skip to content

Commit 4e7852c

Browse files
author
Ricardo
committed
Use RandomStream in truncate
1 parent 50f15b4 commit 4e7852c

File tree

2 files changed

+41
-31
lines changed

2 files changed

+41
-31
lines changed

aeppl/truncation.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import aesara.tensor as at
55
import aesara.tensor.random.basic as arb
66
import numpy as np
7-
from aesara import scan, shared
7+
from aesara import scan
88
from aesara.compile.builders import OpFromGraph
99
from aesara.graph.basic import Node
1010
from aesara.graph.fg import FunctionGraph
@@ -15,6 +15,7 @@
1515
from aesara.scalar.basic import clip as scalar_clip
1616
from aesara.scan import until
1717
from aesara.tensor.elemwise import Elemwise
18+
from aesara.tensor.random import RandomStream
1819
from aesara.tensor.random.op import RandomVariable
1920
from aesara.tensor.var import TensorConstant, TensorVariable
2021

@@ -188,7 +189,11 @@ def __str__(self):
188189

189190

190191
def truncate(
191-
rv: TensorVariable, lower=None, upper=None, max_n_steps: int = 10_000, rng=None
192+
rv: TensorVariable,
193+
lower=None,
194+
upper=None,
195+
max_n_steps: int = 10_000,
196+
srng: Optional[RandomStream] = None,
192197
) -> Tuple[TensorVariable, Tuple[TensorVariable, TensorVariable]]:
193198
"""Truncate a univariate `RandomVariable` between lower and upper.
194199
@@ -218,13 +223,13 @@ def truncate(
218223
lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
219224
upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)
220225

221-
if rng is None:
222-
rng = shared(np.random.RandomState(), borrow=True)
226+
if srng is None:
227+
srng = RandomStream()
223228

224229
# Try to use specialized Op
225230
try:
226231
truncated_rv, updates = _truncated(
227-
rv.owner.op, lower, upper, rng, *rv.owner.inputs[1:]
232+
rv.owner.op, lower, upper, srng, *rv.owner.inputs[1:]
228233
)
229234
return truncated_rv, updates
230235
except NotImplementedError:
@@ -235,8 +240,8 @@ def truncate(
235240
# though it would not be necessary for the icdf OpFromGraph
236241
graph_inputs = [*rv.owner.inputs[1:], lower, upper]
237242
graph_inputs_ = [inp.type() for inp in graph_inputs]
238-
*rv_inputs_, lower_, upper_ = graph_inputs_
239-
rv_ = rv.owner.op.make_node(rng, *rv_inputs_).default_output()
243+
size_, dtype_, *rv_inputs_, lower_, upper_ = graph_inputs_
244+
rv_ = srng.gen(rv.owner.op, *rv_inputs_, size=size_, dtype=dtype_)
240245

241246
# Try to use inverted cdf sampling
242247
try:
@@ -245,11 +250,10 @@ def truncate(
245250
lower_value = lower_ - 1 if rv.owner.op.dtype.startswith("int") else lower_
246251
cdf_lower_ = at.exp(logcdf(rv_, lower_value))
247252
cdf_upper_ = at.exp(logcdf(rv_, upper_))
248-
uniform_ = at.random.uniform(
253+
uniform_ = srng.uniform(
249254
cdf_lower_,
250255
cdf_upper_,
251-
rng=rng,
252-
size=rv_inputs_[0],
256+
size=size_,
253257
)
254258
truncated_rv_ = icdf(rv_, uniform_)
255259
truncated_rv = TruncatedRV(
@@ -265,27 +269,23 @@ def truncate(
265269

266270
# Fallback to rejection sampling
267271
# TODO: Handle potential broadcast by lower / upper
268-
def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
269-
next_rng, new_truncated_rv = rv.owner.op.make_node(rng, *rv_inputs).outputs
272+
def loop_fn(truncated_rv, reject_draws, lower, upper, size, dtype, *rv_inputs):
273+
new_truncated_rv = srng.gen(rv.owner.op, *rv_inputs, size=size, dtype=dtype) # type: ignore
270274
truncated_rv = at.set_subtensor(
271275
truncated_rv[reject_draws],
272276
new_truncated_rv[reject_draws],
273277
)
274278
reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))
275279

276-
return (
277-
(truncated_rv, reject_draws),
278-
[(rng, next_rng)],
279-
until(~at.any(reject_draws)),
280-
)
280+
return (truncated_rv, reject_draws), until(~at.any(reject_draws))
281281

282282
(truncated_rv_, reject_draws_), updates = scan(
283283
loop_fn,
284284
outputs_info=[
285285
at.zeros_like(rv_),
286286
at.ones_like(rv_, dtype=bool),
287287
],
288-
non_sequences=[lower_, upper_, rng, *rv_inputs_],
288+
non_sequences=[lower_, upper_, size_, dtype_, *rv_inputs_],
289289
n_steps=max_n_steps,
290290
strict=True,
291291
)
@@ -299,18 +299,28 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
299299
truncated_rv = TruncatedRV(
300300
base_rv_op=rv.owner.op,
301301
inputs=graph_inputs_,
302-
outputs=[truncated_rv_, tuple(updates.values())[0]],
302+
# This will fail with `n_steps==1`, because in that case `Scan` won't return any updates
303+
outputs=[truncated_rv_, rv_.owner.outputs[0], tuple(updates.values())[0]],
303304
inline=True,
304305
)(*graph_inputs)
305-
updates = {truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1]}
306+
# TODO: Is the order of multiple shared variables determnistic?
307+
assert truncated_rv.owner.inputs[-2] is rv_.owner.inputs[0]
308+
updates = {
309+
truncated_rv.owner.inputs[-2]: truncated_rv.owner.outputs[-2],
310+
truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1],
311+
}
306312
return truncated_rv, updates
307313

308314

309315
@_logprob.register(TruncatedRV)
310316
def truncated_logprob(op, values, *inputs, **kwargs):
311317
(value,) = values
312318

313-
*rv_inputs, lower_bound, upper_bound, rng = inputs
319+
# Rejection sample graph has two rngs
320+
if len(op.shared_inputs) == 2:
321+
*rv_inputs, lower_bound, upper_bound, _, rng = inputs
322+
else:
323+
*rv_inputs, lower_bound, upper_bound, rng = inputs
314324
rv_inputs = [rng, *rv_inputs]
315325

316326
base_rv_op = op.base_rv_op
@@ -361,11 +371,11 @@ def truncated_logprob(op, values, *inputs, **kwargs):
361371

362372

363373
@_truncated.register(arb.UniformRV)
364-
def uniform_truncated(op, lower, upper, rng, size, dtype, lower_orig, upper_orig):
365-
truncated_uniform = at.random.uniform(
374+
def uniform_truncated(op, lower, upper, srng, size, dtype, lower_orig, upper_orig):
375+
truncated_uniform = srng.gen(
376+
op,
366377
at.max((lower_orig, lower)),
367378
at.min((upper_orig, upper)),
368-
rng=rng,
369379
size=size,
370380
dtype=dtype,
371381
)

tests/test_truncation.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,10 @@ def _icdf_not_implemented(*args, **kwargs):
233233
def test_truncation_specialized_op():
234234
x = at.random.uniform(0, 10, name="x", size=100)
235235

236-
rng = aesara.shared(np.random.RandomState())
237-
xt, _ = truncate(x, lower=5, upper=15, rng=rng)
236+
srng = at.random.RandomStream()
237+
xt, _ = truncate(x, lower=5, upper=15, srng=srng)
238238
assert isinstance(xt.owner.op, UniformRV)
239-
assert xt.owner.inputs[0] is rng
239+
assert xt.owner.inputs[0] is srng.updates()[0][0]
240240

241241
lower_upper = at.stack(xt.owner.inputs[3:])
242242
assert np.all(lower_upper.eval() == [5, 10])
@@ -250,10 +250,10 @@ def test_truncation_continuous_random(op_type, lower, upper):
250250
normal_op = icdf_normal if op_type == "icdf" else rejection_normal
251251
x = normal_op(loc, scale, name="x", size=100)
252252

253-
rng = aesara.shared(np.random.RandomState())
254-
xt, xt_update = truncate(x, lower=lower, upper=upper, rng=rng)
253+
srng = at.random.RandomStream()
254+
xt, xt_update = truncate(x, lower=lower, upper=upper, srng=srng)
255255
assert isinstance(xt.owner.op, TruncatedRV)
256-
assert xt.owner.inputs[-1] is rng
256+
assert xt.owner.inputs[-1] is srng.updates()[1 if op_type == "icdf" else 2][0]
257257
assert xt.type == x.type
258258

259259
# Check that original op can be used on its own
@@ -275,7 +275,7 @@ def test_truncation_continuous_random(op_type, lower, upper):
275275
assert scipy.stats.cramervonmises(xt_draws.ravel(), ref_xt.cdf).pvalue > 0.001
276276

277277
# Test max_n_steps
278-
xt, xt_update = truncate(x, lower=lower, upper=upper, max_n_steps=1)
278+
xt, xt_update = truncate(x, lower=lower, upper=upper, max_n_steps=2)
279279
xt_fn = aesara.function([], xt, updates=xt_update)
280280
if op_type == "icdf":
281281
xt_draws = xt_fn()

0 commit comments

Comments
 (0)