Commit 865d6d3

Ricardo authored and brandonwillard committed
Add an interface and basic implementations for truncated variables
1 parent 7890228 commit 865d6d3

File tree

2 files changed: +486 -0 lines changed

aeppl/truncation.py

Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
from functools import singledispatch
from typing import Optional, Tuple

import aesara.tensor as at
import aesara.tensor.random.basic as arb
import numpy as np
from aesara import scan
from aesara.compile.builders import OpFromGraph
from aesara.graph.op import Op
from aesara.raise_op import CheckAndRaise
from aesara.scan import until
from aesara.tensor.random import RandomStream
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant, TensorVariable

from aeppl.abstract import MeasurableVariable, _get_measurable_outputs
from aeppl.logprob import (
    CheckParameterValue,
    _logcdf,
    _logprob,
    icdf,
    logcdf,
    logdiffexp,
)


class TruncatedRV(OpFromGraph):
    """An `Op` constructed from an Aesara graph that represents a truncated univariate random variable."""

    default_output = 0
    base_rv_op = None

    def __init__(self, base_rv_op: Op, *args, **kwargs):
        self.base_rv_op = base_rv_op
        super().__init__(*args, **kwargs)


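# Registering `TruncatedRV` as a `MeasurableVariable`, and declaring its first
# output as the only measurable one below, lets aeppl's log-probability
# machinery treat truncated draws as measurable and route them to the
# `truncated_logprob` implementation further down.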
MeasurableVariable.register(TruncatedRV)


@_get_measurable_outputs.register(TruncatedRV)
def _get_measurable_outputs_TruncatedRV(op, node):
    return [node.outputs[0]]


@singledispatch
def _truncated(op: Op, lower, upper, *params):
    """Return the truncated equivalent of another `RandomVariable`."""
    raise NotImplementedError(
        f"{op} does not have an equivalent truncated version implemented"
    )


class TruncationError(Exception):
    """Exception for errors generated from truncated graphs."""


class TruncationCheck(CheckAndRaise):
    """Implements a check in truncated graphs.

    Raises `TruncationError` if the check is not True.
    """

    def __init__(self, msg=""):
        super().__init__(TruncationError, msg)

    def __str__(self):
        return f"TruncationCheck{{{self.msg}}}"


def truncate(
    rv: TensorVariable,
    lower=None,
    upper=None,
    max_n_steps: int = 10_000,
    srng: Optional[RandomStream] = None,
) -> Tuple[TensorVariable, Tuple[TensorVariable, TensorVariable]]:
    """Truncate a univariate `RandomVariable` between `lower` and `upper`.

    If `lower` or `upper` is ``None``, the variable is not truncated on that side.

    Depending on whether a dispatch implementation is available, this
    function returns either a specialized `Op`, or an equivalent graph
    representing the truncation process via inverse CDF or rejection
    sampling.

    The argument `max_n_steps` controls the maximum number of resamples that are
    attempted when performing rejection sampling. A `TruncationError` is raised if
    convergence is not reached after that many steps.

    Returns
    =======
    `TensorVariable` graph representing the truncated `RandomVariable` and respective updates
    """

    if lower is None and upper is None:
        raise ValueError("lower and upper cannot both be None")

    if not (isinstance(rv.owner.op, RandomVariable) and rv.owner.op.ndim_supp == 0):
        raise NotImplementedError(
            f"Truncation is only implemented for univariate random variables, got {rv.owner.op}"
        )

    lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
    upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)

    if srng is None:
        srng = RandomStream()

    # Try to use specialized Op
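    # A base distribution can provide an exact truncated counterpart by
    # registering an implementation of `_truncated` for its `RandomVariable`
    # type (see `uniform_truncated` at the bottom of this module).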
    try:
        truncated_rv, updates = _truncated(
            rv.owner.op, lower, upper, srng, *rv.owner.inputs[1:]
        )
        return truncated_rv, updates
    except NotImplementedError:
        pass

    # Variables with `_` suffix identify dummy inputs for the OpFromGraph.
    # We will use the shared RNG variable directly because Scan demands it, even
    # though it would not be necessary for the icdf OpFromGraph.
    graph_inputs = [*rv.owner.inputs[1:], lower, upper]
    graph_inputs_ = [inp.type() for inp in graph_inputs]
    size_, dtype_, *rv_inputs_, lower_, upper_ = graph_inputs_
    rv_ = srng.gen(rv.owner.op, *rv_inputs_, size=size_, dtype=dtype_)

    # Try to use inverted cdf sampling
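    # Inverse-CDF sampling: if u ~ Uniform(F(lower), F(upper)), with F the CDF
    # of the base distribution, then icdf(u) is a draw from the base variable
    # truncated to [lower, upper]. This path requires both `logcdf` and `icdf`
    # to be implemented for the base `RandomVariable`; if either raises
    # `NotImplementedError`, we fall through to the rejection-sampling fallback.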
    try:
        # For left truncated discrete RVs, we need to include the whole lower bound.
        # This may result in draws below the truncation range, if any uniform == 0
        lower_value = lower_ - 1 if rv.owner.op.dtype.startswith("int") else lower_
        cdf_lower_ = at.exp(logcdf(rv_, lower_value))
        cdf_upper_ = at.exp(logcdf(rv_, upper_))
        uniform_ = srng.uniform(
            cdf_lower_,
            cdf_upper_,
            size=size_,
        )
        truncated_rv_ = icdf(rv_, uniform_)
        truncated_rv = TruncatedRV(
            base_rv_op=rv.owner.op,
            inputs=graph_inputs_,
            outputs=[truncated_rv_, uniform_.owner.outputs[0]],
            inline=True,
        )(*graph_inputs)
        updates = {truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1]}
        return truncated_rv, updates
    except NotImplementedError:
        pass

    # Fallback to rejection sampling
    # TODO: Handle potential broadcast by lower / upper
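    # Each Scan iteration redraws the base variable and only overwrites the
    # entries still flagged in `reject_draws`, keeping previously accepted
    # values. The `until` condition stops the loop once every draw lies inside
    # [lower, upper], giving up after `max_n_steps` iterations.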
    def loop_fn(truncated_rv, reject_draws, lower, upper, size, dtype, *rv_inputs):
        new_truncated_rv = srng.gen(rv.owner.op, *rv_inputs, size=size, dtype=dtype)  # type: ignore
        truncated_rv = at.set_subtensor(
            truncated_rv[reject_draws],
            new_truncated_rv[reject_draws],
        )
        reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))

        return (truncated_rv, reject_draws), until(~at.any(reject_draws))

    (truncated_rv_, reject_draws_), updates = scan(
        loop_fn,
        outputs_info=[
            at.zeros_like(rv_),
            at.ones_like(rv_, dtype=bool),
        ],
        non_sequences=[lower_, upper_, size_, dtype_, *rv_inputs_],
        n_steps=max_n_steps,
        strict=True,
    )

    truncated_rv_ = truncated_rv_[-1]
    convergence_ = ~at.any(reject_draws_[-1])
    truncated_rv_ = TruncationCheck(
        f"Truncation did not converge in {max_n_steps} steps"
    )(truncated_rv_, convergence_)

    truncated_rv = TruncatedRV(
        base_rv_op=rv.owner.op,
        inputs=graph_inputs_,
        # This will fail with `n_steps==1`, because in that case `Scan` won't return any updates
        outputs=[truncated_rv_, rv_.owner.outputs[0], tuple(updates.values())[0]],
        inline=True,
    )(*graph_inputs)
    # TODO: Is the order of multiple shared variables deterministic?
    assert truncated_rv.owner.inputs[-2] is rv_.owner.inputs[0]
    updates = {
        truncated_rv.owner.inputs[-2]: truncated_rv.owner.outputs[-2],
        truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1],
    }
    return truncated_rv, updates


@_logprob.register(TruncatedRV)
def truncated_logprob(op, values, *inputs, **kwargs):
    (value,) = values

    # Rejection sample graph has two rngs
    if len(op.shared_inputs) == 2:
        *rv_inputs, lower_bound, upper_bound, _, rng = inputs
    else:
        *rv_inputs, lower_bound, upper_bound, rng = inputs
    rv_inputs = [rng, *rv_inputs]

    base_rv_op = op.base_rv_op
    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
    # For left truncated RVs, we don't want to include the lower bound in the
    # normalization term
    lower_bound_value = (
        lower_bound - 1 if base_rv_op.dtype.startswith("int") else lower_bound
    )
    lower_logcdf = _logcdf(base_rv_op, lower_bound_value, *rv_inputs, **kwargs)
    upper_logcdf = _logcdf(base_rv_op, upper_bound, *rv_inputs, **kwargs)

    if base_rv_op.name:
        logp.name = f"{base_rv_op}_logprob"
        lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
        upper_logcdf.name = f"{base_rv_op}_upper_logcdf"

    is_lower_bounded = not (
        isinstance(lower_bound, TensorConstant)
        and np.all(np.isneginf(lower_bound.value))
    )
    is_upper_bounded = not (
        isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))
    )

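    # The truncated density is the base density rescaled by the probability
    # mass inside the bounds: logp(value) - log(F(upper) - F(lower)), computed
    # in log-space for numerical stability. One-sided truncations use the
    # simpler expressions below.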
    lognorm = 0
    if is_lower_bounded and is_upper_bounded:
        lognorm = logdiffexp(upper_logcdf, lower_logcdf)
    elif is_lower_bounded:
        lognorm = at.log1mexp(lower_logcdf)
    elif is_upper_bounded:
        lognorm = upper_logcdf

    logp = logp - lognorm

    if is_lower_bounded:
        logp = at.switch(value < lower_bound, -np.inf, logp)

    if is_upper_bounded:
        logp = at.switch(value <= upper_bound, logp, -np.inf)

    if is_lower_bounded and is_upper_bounded:
        logp = CheckParameterValue("lower_bound <= upper_bound")(
            logp, at.all(at.le(lower_bound, upper_bound))
        )

    return logp


@_truncated.register(arb.UniformRV)
def uniform_truncated(op, lower, upper, srng, size, dtype, lower_orig, upper_orig):
    truncated_uniform = srng.gen(
        op,
        at.max((lower_orig, lower)),
        at.min((upper_orig, upper)),
        size=size,
        dtype=dtype,
    )
    return truncated_uniform, {
        truncated_uniform.owner.inputs[0]: truncated_uniform.owner.outputs[0]
    }
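
As a quick illustration of the new interface, the sketch below draws from a truncated normal. The variable names, seed, and bounds are illustrative and not part of the commit. Since `normal` has no specialized `_truncated` dispatch here, `truncate` builds either the inverse-CDF or the rejection-sampling graph; in both cases the returned `updates` must be passed to `aesara.function` so the RNG state advances between calls:

    import aesara
    from aesara.tensor.random import RandomStream

    from aeppl.truncation import truncate

    srng = RandomStream(seed=2032)
    x_rv = srng.normal(0.0, 1.0, size=(5,))

    # Truncate the normal to [-1, 2] and compile a sampling function
    xt_rv, updates = truncate(x_rv, lower=-1.0, upper=2.0, srng=srng)
    draw_fn = aesara.function([], xt_rv, updates=updates)
    samples = draw_fn()  # every draw falls inside [-1, 2]

Because `TruncatedRV` is registered as a `MeasurableVariable` with a `_logprob` implementation, the log-density of `xt_rv` is then available through aeppl's usual log-probability machinery (e.g. `joint_logprob`).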
