from functools import singledispatch
from typing import Dict, Optional, Tuple

import aesara.tensor as at
import aesara.tensor.random.basic as arb
import numpy as np
from aesara import scan
from aesara.compile.builders import OpFromGraph
from aesara.graph.basic import Variable
from aesara.graph.op import Op
from aesara.raise_op import CheckAndRaise
from aesara.scan import until
from aesara.tensor.random import RandomStream
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant, TensorVariable

from aeppl.abstract import MeasurableVariable, _get_measurable_outputs
from aeppl.logprob import (
    CheckParameterValue,
    _logcdf,
    _logprob,
    icdf,
    logcdf,
    logdiffexp,
)


class TruncatedRV(OpFromGraph):
    """An `Op` constructed from an Aesara graph that represents a truncated univariate random variable."""

    default_output = 0
    base_rv_op = None

    def __init__(self, base_rv_op: Op, *args, **kwargs):
        self.base_rv_op = base_rv_op
        super().__init__(*args, **kwargs)


MeasurableVariable.register(TruncatedRV)


@_get_measurable_outputs.register(TruncatedRV)
def _get_measurable_outputs_TruncatedRV(op, node):
    return [node.outputs[0]]


@singledispatch
def _truncated(op: Op, lower, upper, *params):
    """Return the truncated equivalent of another `RandomVariable`."""
    raise NotImplementedError(
        f"{op} does not have an equivalent truncated version implemented"
    )


class TruncationError(Exception):
    """Exception for errors generated from truncated graphs."""


class TruncationCheck(CheckAndRaise):
    """Implements a check in truncated graphs.

    Raises a `TruncationError` if the check is not True.
    """

    def __init__(self, msg=""):
        super().__init__(TruncationError, msg)

    def __str__(self):
        return f"TruncationCheck{{{self.msg}}}"


def truncate(
    rv: TensorVariable,
    lower=None,
    upper=None,
    max_n_steps: int = 10_000,
    srng: Optional[RandomStream] = None,
) -> Tuple[TensorVariable, Dict[Variable, Variable]]:
    """Truncate a univariate `RandomVariable` between `lower` and `upper`.

    If `lower` or `upper` is ``None``, the variable is not truncated on that side.

    Depending on whether a dispatch implementation is available, this function
    returns either a specialized `Op`, or an equivalent graph that represents
    the truncation process via inverse CDF sampling or rejection sampling.

    The argument `max_n_steps` controls the maximum number of resamples that are
    attempted when performing rejection sampling. A `TruncationError` is raised
    if convergence is not reached after that many steps.

    Returns
    -------
    A `TensorVariable` graph representing the truncated `RandomVariable`, and
    the respective random updates needed when compiling a function that draws
    from it.
    """

    if lower is None and upper is None:
        raise ValueError("lower and upper cannot both be None")

    if not (isinstance(rv.owner.op, RandomVariable) and rv.owner.op.ndim_supp == 0):
        raise NotImplementedError(
            f"Truncation is only implemented for univariate random variables, got {rv.owner.op}"
        )

    lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
    upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)

    if srng is None:
        srng = RandomStream()

    # Try to use a specialized `Op`
    try:
        truncated_rv, updates = _truncated(
            rv.owner.op, lower, upper, srng, *rv.owner.inputs[1:]
        )
        return truncated_rv, updates
    except NotImplementedError:
        pass

    # Variables with the `_` suffix identify dummy inputs for the OpFromGraph.
    # We use the shared RNG variable directly because Scan demands it, even
    # though it would not be necessary for the icdf OpFromGraph.
    graph_inputs = [*rv.owner.inputs[1:], lower, upper]
    graph_inputs_ = [inp.type() for inp in graph_inputs]
    size_, dtype_, *rv_inputs_, lower_, upper_ = graph_inputs_
    rv_ = srng.gen(rv.owner.op, *rv_inputs_, size=size_, dtype=dtype_)

    # Try to use inverse CDF sampling
    try:
        # For left-truncated discrete RVs, we need to include the whole lower
        # bound. This may result in draws below the truncation range if any
        # uniform draw is exactly 0.
        lower_value = lower_ - 1 if rv.owner.op.dtype.startswith("int") else lower_
        cdf_lower_ = at.exp(logcdf(rv_, lower_value))
        cdf_upper_ = at.exp(logcdf(rv_, upper_))
        uniform_ = srng.uniform(
            cdf_lower_,
            cdf_upper_,
            size=size_,
        )
        truncated_rv_ = icdf(rv_, uniform_)
        truncated_rv = TruncatedRV(
            base_rv_op=rv.owner.op,
            inputs=graph_inputs_,
            outputs=[truncated_rv_, uniform_.owner.outputs[0]],
            inline=True,
        )(*graph_inputs)
        updates = {truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1]}
        return truncated_rv, updates
    except NotImplementedError:
        pass
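
    # Note: the branch above is inverse transform sampling restricted to the
    # truncation bounds: if u ~ Uniform(CDF(lower), CDF(upper)), then icdf(u)
    # follows the base distribution truncated to [lower, upper].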

    # Fallback to rejection sampling
    # TODO: Handle potential broadcast by lower / upper
    def loop_fn(truncated_rv, reject_draws, lower, upper, size, dtype, *rv_inputs):
        new_truncated_rv = srng.gen(rv.owner.op, *rv_inputs, size=size, dtype=dtype)  # type: ignore
        truncated_rv = at.set_subtensor(
            truncated_rv[reject_draws],
            new_truncated_rv[reject_draws],
        )
        reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))

        return (truncated_rv, reject_draws), until(~at.any(reject_draws))

    (truncated_rv_, reject_draws_), updates = scan(
        loop_fn,
        outputs_info=[
            at.zeros_like(rv_),
            at.ones_like(rv_, dtype=bool),
        ],
        non_sequences=[lower_, upper_, size_, dtype_, *rv_inputs_],
        n_steps=max_n_steps,
        strict=True,
    )

    # Keep only the draws from the last Scan iteration
    truncated_rv_ = truncated_rv_[-1]
    convergence_ = ~at.any(reject_draws_[-1])
    truncated_rv_ = TruncationCheck(
        f"Truncation did not converge in {max_n_steps} steps"
    )(truncated_rv_, convergence_)

    truncated_rv = TruncatedRV(
        base_rv_op=rv.owner.op,
        inputs=graph_inputs_,
        # This will fail with `n_steps==1`, because in that case `Scan` won't
        # return any updates
        outputs=[truncated_rv_, rv_.owner.outputs[0], tuple(updates.values())[0]],
        inline=True,
    )(*graph_inputs)
    # TODO: Is the order of multiple shared variables deterministic?
    assert truncated_rv.owner.inputs[-2] is rv_.owner.inputs[0]
    updates = {
        truncated_rv.owner.inputs[-2]: truncated_rv.owner.outputs[-2],
        truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1],
    }
    return truncated_rv, updates
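

# A minimal usage sketch (an illustration, not part of this module; it assumes
# the caller compiles with the returned `updates`, and the variable names are
# hypothetical):
#
#     import aesara
#     from aesara.tensor.random import RandomStream
#
#     srng = RandomStream(seed=123)
#     x = srng.normal(0.0, 1.0)
#     xt, updates = truncate(x, lower=-1.0, upper=1.0, srng=srng)
#     draw = aesara.function([], xt, updates=updates)
#     draw()  # every draw falls within [-1.0, 1.0]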


@_logprob.register(TruncatedRV)
def truncated_logprob(op, values, *inputs, **kwargs):
    (value,) = values

    # The rejection sampling graph has two RNGs
    if len(op.shared_inputs) == 2:
        *rv_inputs, lower_bound, upper_bound, _, rng = inputs
    else:
        *rv_inputs, lower_bound, upper_bound, rng = inputs
    rv_inputs = [rng, *rv_inputs]

    base_rv_op = op.base_rv_op
    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
    # For left-truncated discrete RVs, we don't want to include the lower bound
    # in the normalization term
    lower_bound_value = (
        lower_bound - 1 if base_rv_op.dtype.startswith("int") else lower_bound
    )
    lower_logcdf = _logcdf(base_rv_op, lower_bound_value, *rv_inputs, **kwargs)
    upper_logcdf = _logcdf(base_rv_op, upper_bound, *rv_inputs, **kwargs)

    if base_rv_op.name:
        logp.name = f"{base_rv_op}_logprob"
        lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
        upper_logcdf.name = f"{base_rv_op}_upper_logcdf"

    is_lower_bounded = not (
        isinstance(lower_bound, TensorConstant)
        and np.all(np.isneginf(lower_bound.value))
    )
    is_upper_bounded = not (
        isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))
    )

    lognorm = 0
    if is_lower_bounded and is_upper_bounded:
        lognorm = logdiffexp(upper_logcdf, lower_logcdf)
    elif is_lower_bounded:
        lognorm = at.log1mexp(lower_logcdf)
    elif is_upper_bounded:
        lognorm = upper_logcdf

    logp = logp - lognorm

    if is_lower_bounded:
        logp = at.switch(value < lower_bound, -np.inf, logp)

    if is_upper_bounded:
        logp = at.switch(value <= upper_bound, logp, -np.inf)

    if is_lower_bounded and is_upper_bounded:
        logp = CheckParameterValue("lower_bound <= upper_bound")(
            logp, at.all(at.le(lower_bound, upper_bound))
        )

    return logp
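

# A numeric sanity check of the normalization term (a sketch; SciPy is assumed
# to be available and is not a dependency of this module):
#
#     import numpy as np
#     from scipy import stats
#
#     lower, upper, x = -1.0, 1.0, 0.3
#     z = stats.norm.cdf(upper) - stats.norm.cdf(lower)
#     ref = stats.norm.logpdf(x) - np.log(z)
#     # `ref` should match the logprob of a standard normal truncated to
#     # [lower, upper], evaluated at x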


@_truncated.register(arb.UniformRV)
def uniform_truncated(op, lower, upper, srng, size, dtype, lower_orig, upper_orig):
    # A uniform RV truncated to [lower, upper] is just another uniform RV with
    # the intersected bounds
    truncated_uniform = srng.gen(
        op,
        at.max((lower_orig, lower)),
        at.min((upper_orig, upper)),
        size=size,
        dtype=dtype,
    )
    return truncated_uniform, {
        truncated_uniform.owner.inputs[0]: truncated_uniform.owner.outputs[0]
    }
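

# For example (an illustrative note, not from the module): with this
# registration, `truncate(srng.uniform(0.0, 10.0), lower=2.0, upper=5.0,
# srng=srng)` dispatches here and simply draws from `Uniform(2.0, 5.0)`,
# bypassing the icdf and rejection-sampling machinery entirely.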