Commit 50f15b4

Author: Ricardo
Implement truncated variables
1 parent 6f37386

File tree

3 files changed: +474 −5 lines


aeppl/logprob.py

Lines changed: 6 additions & 0 deletions

@@ -45,6 +45,12 @@ def xlogy0(m, x):
     return at.switch(at.eq(x, 0), at.switch(at.eq(m, 0), 0.0, -np.inf), m * at.log(x))
 
 
+def logdiffexp(a, b):
+    """log(exp(a) - exp(b))"""
+    # TODO: This should be a basic Aesara stabilization
+    return a + at.log1mexp(b - a)
+
+
 def logprob(rv_var, *rv_values, **kwargs):
     """Create a graph for the log-probability of a ``RandomVariable``."""
     logprob = _logprob(rv_var.owner.op, rv_values, *rv_var.owner.inputs, **kwargs)
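
A quick sketch of why this identity is stable (illustrative only, not part of the commit): for a >= b, log(exp(a) - exp(b)) = a + log(1 - exp(b - a)), and evaluating the right-hand side avoids overflowing exp even when a and b are large. The NumPy check below uses np.log1p as a stand-in for Aesara's log1mexp.

import numpy as np

a, b = 800.0, 799.0
naive = np.log(np.exp(a) - np.exp(b))   # exp overflows: inf - inf -> nan
stable = a + np.log1p(-np.exp(b - a))   # ~799.54, no overflow
print(naive, stable)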

aeppl/truncation.py

Lines changed: 254 additions & 4 deletions

@@ -1,17 +1,36 @@
-from typing import List, Optional
+from functools import singledispatch
+from typing import List, Optional, Tuple
 
 import aesara.tensor as at
+import aesara.tensor.random.basic as arb
 import numpy as np
+from aesara import scan, shared
+from aesara.compile.builders import OpFromGraph
 from aesara.graph.basic import Node
 from aesara.graph.fg import FunctionGraph
+from aesara.graph.op import Op
 from aesara.graph.opt import local_optimizer
+from aesara.raise_op import CheckAndRaise
 from aesara.scalar.basic import Clip
 from aesara.scalar.basic import clip as scalar_clip
+from aesara.scan import until
 from aesara.tensor.elemwise import Elemwise
-from aesara.tensor.var import TensorConstant
+from aesara.tensor.random.op import RandomVariable
+from aesara.tensor.var import TensorConstant, TensorVariable
 
-from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
-from aeppl.logprob import CheckParameterValue, _logcdf, _logprob
+from aeppl.abstract import (
+    MeasurableVariable,
+    _get_measurable_outputs,
+    assign_custom_measurable_outputs,
+)
+from aeppl.logprob import (
+    CheckParameterValue,
+    _logcdf,
+    _logprob,
+    icdf,
+    logcdf,
+    logdiffexp,
+)
 from aeppl.opt import measurable_ir_rewrites_db
 
 
@@ -122,3 +141,234 @@ def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
     )
 
     return logprob
+
+
+class TruncatedRV(OpFromGraph):
+    """An `Op` constructed from an Aesara graph that represents a truncated univariate random variable."""
+
+    default_output = 0
+    base_rv_op = None
+
+    def __init__(self, base_rv_op: Op, *args, **kwargs):
+        self.base_rv_op = base_rv_op
+        super().__init__(*args, **kwargs)
+
+
+MeasurableVariable.register(TruncatedRV)
+
+
+@_get_measurable_outputs.register(TruncatedRV)
+def _get_measurable_outputs_TruncatedRV(op, node):
+    return [node.outputs[0]]
+
+
+@singledispatch
+def _truncated(op: Op, lower, upper, *params):
+    """Return the truncated equivalent of another `RandomVariable`."""
+    raise NotImplementedError(
+        f"{op} does not have an equivalent truncated version implemented"
+    )
+
+
+class TruncationError(Exception):
+    """Exception for errors generated from truncated graphs"""
+
+
+class TruncationCheck(CheckAndRaise):
+    """Implements a check in truncated graphs.
+
+    Raises `TruncationError` if the check is not True.
+    """
+
+    def __init__(self, msg=""):
+        super().__init__(TruncationError, msg)
+
+    def __str__(self):
+        return f"TruncationCheck{{{self.msg}}}"
+
+
+def truncate(
+    rv: TensorVariable, lower=None, upper=None, max_n_steps: int = 10_000, rng=None
+) -> Tuple[TensorVariable, Tuple[TensorVariable, TensorVariable]]:
+    """Truncate a univariate `RandomVariable` between lower and upper.
+
+    If lower or upper is ``None``, the variable is not truncated on that side.
+
+    Depending on the dispatched implementations, this function returns either a
+    specialized `Op`, or an equivalent graph that represents the truncation
+    process via inverse CDF sampling or rejection sampling.
+
+    The argument `max_n_steps` controls the maximum number of resamples that are
+    attempted when performing rejection sampling. A `TruncationError` is raised if
+    convergence is not reached after that many steps.
+
+    Returns
+    =======
+    `TensorVariable` graph representing the truncated `RandomVariable` and the
+    respective updates
+    """
+
+    if lower is None and upper is None:
+        raise ValueError("lower and upper cannot both be None")
+
+    if not (isinstance(rv.owner.op, RandomVariable) and rv.owner.op.ndim_supp == 0):
+        raise NotImplementedError(
+            f"Truncation is only implemented for univariate random variables, got {rv.owner.op}"
+        )
+
+    lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
+    upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)
+
+    if rng is None:
+        rng = shared(np.random.RandomState(), borrow=True)
+
+    # Try to use a specialized Op
+    try:
+        truncated_rv, updates = _truncated(
+            rv.owner.op, lower, upper, rng, *rv.owner.inputs[1:]
+        )
+        return truncated_rv, updates
+    except NotImplementedError:
+        pass
+
+    # Variables with `_` suffix identify dummy inputs for the OpFromGraph.
+    # We use the shared RNG variable directly because Scan demands it, even
+    # though it would not be necessary for the icdf OpFromGraph
+    graph_inputs = [*rv.owner.inputs[1:], lower, upper]
+    graph_inputs_ = [inp.type() for inp in graph_inputs]
+    *rv_inputs_, lower_, upper_ = graph_inputs_
+    rv_ = rv.owner.op.make_node(rng, *rv_inputs_).default_output()
+
+    # Try to use inverse CDF sampling
+    try:
+        # For left truncated discrete RVs, we need to include the whole lower bound.
+        # This may result in draws below the truncation range, if any uniform == 0
+        lower_value = lower_ - 1 if rv.owner.op.dtype.startswith("int") else lower_
+        cdf_lower_ = at.exp(logcdf(rv_, lower_value))
+        cdf_upper_ = at.exp(logcdf(rv_, upper_))
+        uniform_ = at.random.uniform(
+            cdf_lower_,
+            cdf_upper_,
+            rng=rng,
+            size=rv_inputs_[0],
+        )
+        truncated_rv_ = icdf(rv_, uniform_)
+        truncated_rv = TruncatedRV(
+            base_rv_op=rv.owner.op,
+            inputs=graph_inputs_,
+            outputs=[truncated_rv_, uniform_.owner.outputs[0]],
+            inline=True,
+        )(*graph_inputs)
+        updates = {truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1]}
+        return truncated_rv, updates
+    except NotImplementedError:
+        pass
+
+    # Fall back to rejection sampling
+    # TODO: Handle potential broadcast by lower / upper
+    def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
+        next_rng, new_truncated_rv = rv.owner.op.make_node(rng, *rv_inputs).outputs
+        truncated_rv = at.set_subtensor(
+            truncated_rv[reject_draws],
+            new_truncated_rv[reject_draws],
+        )
+        reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))
+
+        return (
+            (truncated_rv, reject_draws),
+            [(rng, next_rng)],
+            until(~at.any(reject_draws)),
+        )
+
+    (truncated_rv_, reject_draws_), updates = scan(
+        loop_fn,
+        outputs_info=[
+            at.zeros_like(rv_),
+            at.ones_like(rv_, dtype=bool),
+        ],
+        non_sequences=[lower_, upper_, rng, *rv_inputs_],
+        n_steps=max_n_steps,
+        strict=True,
+    )
+
+    truncated_rv_ = truncated_rv_[-1]
+    convergence_ = ~at.any(reject_draws_[-1])
+    truncated_rv_ = TruncationCheck(
+        f"Truncation did not converge in {max_n_steps} steps"
+    )(truncated_rv_, convergence_)
+
+    truncated_rv = TruncatedRV(
+        base_rv_op=rv.owner.op,
+        inputs=graph_inputs_,
+        outputs=[truncated_rv_, tuple(updates.values())[0]],
+        inline=True,
+    )(*graph_inputs)
+    updates = {truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1]}
+    return truncated_rv, updates
+
+
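As a usage sketch (assuming aeppl is importable; this snippet is not part of the commit), truncating a base RV and compiling it with the returned updates looks like the following. Which of the three strategies runs depends on what is registered for the base RV's `Op`: a specialized `_truncated` implementation, inverse CDF sampling when `logcdf`/`icdf` are available, or the scan-based rejection sampler otherwise.

import aesara
import aesara.tensor as at
from aeppl.truncation import truncate

x = at.random.normal(0, 1, size=(5,))
xt, updates = truncate(x, lower=-1, upper=1, max_n_steps=1_000)

# The updates map the shared RNG to its next state, so repeated calls
# produce fresh draws
fn = aesara.function([], xt, updates=updates)
draws = fn()  # every draw lies in [-1, 1]; a TruncationError is raised
              # if rejection sampling fails to converge in max_n_steps
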
+@_logprob.register(TruncatedRV)
+def truncated_logprob(op, values, *inputs, **kwargs):
+    (value,) = values
+
+    *rv_inputs, lower_bound, upper_bound, rng = inputs
+    rv_inputs = [rng, *rv_inputs]
+
+    base_rv_op = op.base_rv_op
+    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
+    # For left truncated RVs, we don't want to include the lower bound in the
+    # normalization term
+    lower_bound_value = (
+        lower_bound - 1 if base_rv_op.dtype.startswith("int") else lower_bound
+    )
+    lower_logcdf = _logcdf(base_rv_op, lower_bound_value, *rv_inputs, **kwargs)
+    upper_logcdf = _logcdf(base_rv_op, upper_bound, *rv_inputs, **kwargs)
+
+    if base_rv_op.name:
+        logp.name = f"{base_rv_op}_logprob"
+        lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
+        upper_logcdf.name = f"{base_rv_op}_upper_logcdf"
+
+    is_lower_bounded = not (
+        isinstance(lower_bound, TensorConstant)
+        and np.all(np.isneginf(lower_bound.value))
+    )
+    is_upper_bounded = not (
+        isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))
+    )
+
+    lognorm = 0
+    if is_lower_bounded and is_upper_bounded:
+        lognorm = logdiffexp(upper_logcdf, lower_logcdf)
+    elif is_lower_bounded:
+        lognorm = at.log1mexp(lower_logcdf)
+    elif is_upper_bounded:
+        lognorm = upper_logcdf
+
+    logp = logp - lognorm
+
+    if is_lower_bounded:
+        logp = at.switch(value < lower_bound, -np.inf, logp)
+
+    if is_upper_bounded:
+        logp = at.switch(value <= upper_bound, logp, -np.inf)
+
+    if is_lower_bounded and is_upper_bounded:
+        logp = CheckParameterValue("lower_bound <= upper_bound")(
+            logp, at.all(at.le(lower_bound, upper_bound))
+        )
+
+    return logp
+
+
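For reference, the branches above implement the standard truncated-density normalization: for lower <= x <= upper, p_trunc(x) = p(x) / (F(upper) - F(lower)), so the log-density subtracts logdiffexp(upper_logcdf, lower_logcdf). When only the lower bound is finite the normalizer reduces to 1 - F(lower), computed stably as log1mexp(lower_logcdf); when only the upper bound is finite it is F(upper) directly. The final switches assign -inf outside the truncation range, with a strict comparison at the lower bound so that discrete values equal to it retain mass, matching the lower_bound - 1 shift used for integer dtypes.
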
+@_truncated.register(arb.UniformRV)
+def uniform_truncated(op, lower, upper, rng, size, dtype, lower_orig, upper_orig):
+    truncated_uniform = at.random.uniform(
+        at.max((lower_orig, lower)),
+        at.min((upper_orig, upper)),
+        rng=rng,
+        size=size,
+        dtype=dtype,
+    )
+    return truncated_uniform, {
+        truncated_uniform.owner.inputs[0]: truncated_uniform.owner.outputs[0]
+    }
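
The uniform registration illustrates the specialized dispatch path: for a uniform base RV, truncate() short-circuits before any icdf or rejection sampling and simply returns another uniform RV with tightened bounds. A minimal sketch (not part of the commit):

import aesara.tensor as at
from aeppl.truncation import truncate

x = at.random.uniform(0, 10)
xt, updates = truncate(x, lower=2, upper=4)
# xt is just another uniform RV, equivalent to at.random.uniform(2, 4)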
