Skip to content

Commit 564569f

Browse files
Ricardobrandonwillard
authored andcommitted
Use RandomStream in truncate
1 parent 2c3d461 commit 564569f

File tree

2 files changed

+42
-32
lines changed

2 files changed

+42
-32
lines changed

aeppl/truncation.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from functools import singledispatch
2-
from typing import Tuple
2+
from typing import Optional, Tuple
33

44
import aesara.tensor as at
55
import aesara.tensor.random.basic as arb
66
import numpy as np
7-
from aesara import scan, shared
7+
from aesara import scan
88
from aesara.compile.builders import OpFromGraph
99
from aesara.graph.op import Op
1010
from aesara.raise_op import CheckAndRaise
1111
from aesara.scan import until
12+
from aesara.tensor.random import RandomStream
1213
from aesara.tensor.random.op import RandomVariable
1314
from aesara.tensor.var import TensorConstant, TensorVariable
1415

@@ -68,7 +69,11 @@ def __str__(self):
6869

6970

7071
def truncate(
71-
rv: TensorVariable, lower=None, upper=None, max_n_steps: int = 10_000, rng=None
72+
rv: TensorVariable,
73+
lower=None,
74+
upper=None,
75+
max_n_steps: int = 10_000,
76+
srng: Optional[RandomStream] = None,
7277
) -> Tuple[TensorVariable, Tuple[TensorVariable, TensorVariable]]:
7378
"""Truncate a univariate `RandomVariable` between `lower` and `upper`.
7479
@@ -99,13 +104,13 @@ def truncate(
99104
lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
100105
upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)
101106

102-
if rng is None:
103-
rng = shared(np.random.RandomState(), borrow=True)
107+
if srng is None:
108+
srng = RandomStream()
104109

105110
# Try to use specialized Op
106111
try:
107112
truncated_rv, updates = _truncated(
108-
rv.owner.op, lower, upper, rng, *rv.owner.inputs[1:]
113+
rv.owner.op, lower, upper, srng, *rv.owner.inputs[1:]
109114
)
110115
return truncated_rv, updates
111116
except NotImplementedError:
@@ -116,8 +121,8 @@ def truncate(
116121
# though it would not be necessary for the icdf OpFromGraph
117122
graph_inputs = [*rv.owner.inputs[1:], lower, upper]
118123
graph_inputs_ = [inp.type() for inp in graph_inputs]
119-
*rv_inputs_, lower_, upper_ = graph_inputs_
120-
rv_ = rv.owner.op.make_node(rng, *rv_inputs_).default_output()
124+
size_, dtype_, *rv_inputs_, lower_, upper_ = graph_inputs_
125+
rv_ = srng.gen(rv.owner.op, *rv_inputs_, size=size_, dtype=dtype_)
121126

122127
# Try to use inverted cdf sampling
123128
try:
@@ -126,11 +131,10 @@ def truncate(
126131
lower_value = lower_ - 1 if rv.owner.op.dtype.startswith("int") else lower_
127132
cdf_lower_ = at.exp(logcdf(rv_, lower_value))
128133
cdf_upper_ = at.exp(logcdf(rv_, upper_))
129-
uniform_ = at.random.uniform(
134+
uniform_ = srng.uniform(
130135
cdf_lower_,
131136
cdf_upper_,
132-
rng=rng,
133-
size=rv_inputs_[0],
137+
size=size_,
134138
)
135139
truncated_rv_ = icdf(rv_, uniform_)
136140
truncated_rv = TruncatedRV(
@@ -146,27 +150,23 @@ def truncate(
146150

147151
# Fallback to rejection sampling
148152
# TODO: Handle potential broadcast by lower / upper
149-
def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
150-
next_rng, new_truncated_rv = rv.owner.op.make_node(rng, *rv_inputs).outputs
153+
def loop_fn(truncated_rv, reject_draws, lower, upper, size, dtype, *rv_inputs):
154+
new_truncated_rv = srng.gen(rv.owner.op, *rv_inputs, size=size, dtype=dtype) # type: ignore
151155
truncated_rv = at.set_subtensor(
152156
truncated_rv[reject_draws],
153157
new_truncated_rv[reject_draws],
154158
)
155159
reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))
156160

157-
return (
158-
(truncated_rv, reject_draws),
159-
[(rng, next_rng)],
160-
until(~at.any(reject_draws)),
161-
)
161+
return (truncated_rv, reject_draws), until(~at.any(reject_draws))
162162

163163
(truncated_rv_, reject_draws_), updates = scan(
164164
loop_fn,
165165
outputs_info=[
166166
at.zeros_like(rv_),
167167
at.ones_like(rv_, dtype=bool),
168168
],
169-
non_sequences=[lower_, upper_, rng, *rv_inputs_],
169+
non_sequences=[lower_, upper_, size_, dtype_, *rv_inputs_],
170170
n_steps=max_n_steps,
171171
strict=True,
172172
)
@@ -180,18 +180,28 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
180180
truncated_rv = TruncatedRV(
181181
base_rv_op=rv.owner.op,
182182
inputs=graph_inputs_,
183-
outputs=[truncated_rv_, tuple(updates.values())[0]],
183+
# This will fail with `n_steps==1`, because in that case `Scan` won't return any updates
184+
outputs=[truncated_rv_, rv_.owner.outputs[0], tuple(updates.values())[0]],
184185
inline=True,
185186
)(*graph_inputs)
186-
updates = {truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1]}
187+
# TODO: Is the order of multiple shared variables determnistic?
188+
assert truncated_rv.owner.inputs[-2] is rv_.owner.inputs[0]
189+
updates = {
190+
truncated_rv.owner.inputs[-2]: truncated_rv.owner.outputs[-2],
191+
truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1],
192+
}
187193
return truncated_rv, updates
188194

189195

190196
@_logprob.register(TruncatedRV)
191197
def truncated_logprob(op, values, *inputs, **kwargs):
192198
(value,) = values
193199

194-
*rv_inputs, lower_bound, upper_bound, rng = inputs
200+
# Rejection sample graph has two rngs
201+
if len(op.shared_inputs) == 2:
202+
*rv_inputs, lower_bound, upper_bound, _, rng = inputs
203+
else:
204+
*rv_inputs, lower_bound, upper_bound, rng = inputs
195205
rv_inputs = [rng, *rv_inputs]
196206

197207
base_rv_op = op.base_rv_op
@@ -242,11 +252,11 @@ def truncated_logprob(op, values, *inputs, **kwargs):
242252

243253

244254
@_truncated.register(arb.UniformRV)
245-
def uniform_truncated(op, lower, upper, rng, size, dtype, lower_orig, upper_orig):
246-
truncated_uniform = at.random.uniform(
255+
def uniform_truncated(op, lower, upper, srng, size, dtype, lower_orig, upper_orig):
256+
truncated_uniform = srng.gen(
257+
op,
247258
at.max((lower_orig, lower)),
248259
at.min((upper_orig, upper)),
249-
rng=rng,
250260
size=size,
251261
dtype=dtype,
252262
)

tests/test_truncation.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ def _icdf_not_implemented(*args, **kwargs):
5050
def test_truncation_specialized_op():
5151
x = at.random.uniform(0, 10, name="x", size=100)
5252

53-
rng = aesara.shared(np.random.RandomState())
54-
xt, _ = truncate(x, lower=5, upper=15, rng=rng)
53+
srng = at.random.RandomStream()
54+
xt, _ = truncate(x, lower=5, upper=15, srng=srng)
5555
assert isinstance(xt.owner.op, UniformRV)
56-
assert xt.owner.inputs[0] is rng
56+
assert xt.owner.inputs[0] is srng.updates()[0][0]
5757

5858
lower_upper = at.stack(xt.owner.inputs[3:])
5959
assert np.all(lower_upper.eval() == [5, 10])
@@ -68,10 +68,10 @@ def test_truncation_continuous_random(op_type, lower, upper):
6868
normal_op = icdf_normal if op_type == "icdf" else rejection_normal
6969
x = normal_op(loc, scale, name="x", size=100)
7070

71-
rng = aesara.shared(np.random.RandomState())
72-
xt, xt_update = truncate(x, lower=lower, upper=upper, rng=rng)
71+
srng = at.random.RandomStream()
72+
xt, xt_update = truncate(x, lower=lower, upper=upper, srng=srng)
7373
assert isinstance(xt.owner.op, TruncatedRV)
74-
assert xt.owner.inputs[-1] is rng
74+
assert xt.owner.inputs[-1] is srng.updates()[1 if op_type == "icdf" else 2][0]
7575
assert xt.type.dtype == x.type.dtype
7676
assert xt.type.ndim == x.type.ndim
7777

@@ -94,7 +94,7 @@ def test_truncation_continuous_random(op_type, lower, upper):
9494
assert scipy.stats.cramervonmises(xt_draws.ravel(), ref_xt.cdf).pvalue > 0.001
9595

9696
# Test max_n_steps
97-
xt, xt_update = truncate(x, lower=lower, upper=upper, max_n_steps=1)
97+
xt, xt_update = truncate(x, lower=lower, upper=upper, max_n_steps=2)
9898
xt_fn = aesara.function([], xt, updates=xt_update)
9999
if op_type == "icdf":
100100
xt_draws = xt_fn()

0 commit comments

Comments
 (0)