@@ -1,17 +1,36 @@
-from typing import List, Optional
+from functools import singledispatch
+from typing import Dict, List, Optional, Tuple

 import aesara.tensor as at
+import aesara.tensor.random.basic as arb
 import numpy as np
+from aesara import scan, shared
+from aesara.compile.builders import OpFromGraph
 from aesara.graph.basic import Node
 from aesara.graph.fg import FunctionGraph
+from aesara.graph.op import Op
 from aesara.graph.opt import local_optimizer
+from aesara.raise_op import CheckAndRaise
 from aesara.scalar.basic import Clip
 from aesara.scalar.basic import clip as scalar_clip
+from aesara.scan import until
 from aesara.tensor.elemwise import Elemwise
-from aesara.tensor.var import TensorConstant
+from aesara.tensor.random.op import RandomVariable
+from aesara.tensor.var import TensorConstant, TensorVariable

-from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
-from aeppl.logprob import CheckParameterValue, _logcdf, _logprob
+from aeppl.abstract import (
+    MeasurableVariable,
+    _get_measurable_outputs,
+    assign_custom_measurable_outputs,
+)
+from aeppl.logprob import (
+    CheckParameterValue,
+    _logcdf,
+    _logprob,
+    icdf,
+    logcdf,
+    logdiffexp,
+)
 from aeppl.opt import measurable_ir_rewrites_db


@@ -122,3 +141,234 @@ def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
     )

     return logprob
+
+
+class TruncatedRV(OpFromGraph):
+    """An `Op` constructed from an Aesara graph that represents a truncated univariate random variable."""
+
+    default_output = 0
+    base_rv_op = None
+
+    def __init__(self, base_rv_op: Op, *args, **kwargs):
+        self.base_rv_op = base_rv_op
+        super().__init__(*args, **kwargs)
+
+
+MeasurableVariable.register(TruncatedRV)
+
+
+@_get_measurable_outputs.register(TruncatedRV)
+def _get_measurable_outputs_TruncatedRV(op, node):
+    return [node.outputs[0]]
+
+
+@singledispatch
+def _truncated(op: Op, lower, upper, *params):
+    """Return the truncated equivalent of another `RandomVariable`."""
+    raise NotImplementedError(
+        f"{op} does not have an equivalent truncated version implemented"
+    )
+
+
+class TruncationError(Exception):
+    """Exception for errors generated from truncated graphs."""
+
+
+class TruncationCheck(CheckAndRaise):
+    """Implements a check in truncated graphs.
+
+    Raises `TruncationError` if the check is not True.
+    """
+
+    def __init__(self, msg=""):
+        super().__init__(TruncationError, msg)
+
+    def __str__(self):
+        return f"TruncationCheck{{{self.msg}}}"
+
+
+def truncate(
+    rv: TensorVariable, lower=None, upper=None, max_n_steps: int = 10_000, rng=None
+) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]:
+    """Truncate a univariate `RandomVariable` between ``lower`` and ``upper``.
+
+    If ``lower`` or ``upper`` is ``None``, the variable is not truncated on that side.
+
+    Depending on the dispatched implementations, this function returns either a
+    specialized `Op`, or an equivalent graph that represents the truncation process
+    via inverse CDF sampling or rejection sampling.
+
+    The argument ``max_n_steps`` controls the maximum number of resamples that are
+    attempted when performing rejection sampling. A `TruncationError` is raised if
+    convergence is not reached after that many steps.
+
+    Returns
+    =======
+    `TensorVariable` graph representing the truncated `RandomVariable` and the respective updates
+    """
+
+    if lower is None and upper is None:
+        raise ValueError("lower and upper cannot both be None")
+
+    if not (isinstance(rv.owner.op, RandomVariable) and rv.owner.op.ndim_supp == 0):
+        raise NotImplementedError(
+            f"Truncation is only implemented for univariate random variables, got {rv.owner.op}"
+        )
+
+    lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
+    upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)
+
+    if rng is None:
+        rng = shared(np.random.RandomState(), borrow=True)
+
+    # Try to use a specialized Op
+    try:
+        truncated_rv, updates = _truncated(
+            rv.owner.op, lower, upper, rng, *rv.owner.inputs[1:]
+        )
+        return truncated_rv, updates
+    except NotImplementedError:
+        pass
+
+    # Variables with a `_` suffix identify dummy inputs for the OpFromGraph.
+    # We use the shared RNG variable directly because Scan demands it, even
+    # though it would not be necessary for the icdf OpFromGraph.
+    graph_inputs = [*rv.owner.inputs[1:], lower, upper]
+    graph_inputs_ = [inp.type() for inp in graph_inputs]
+    *rv_inputs_, lower_, upper_ = graph_inputs_
+    rv_ = rv.owner.op.make_node(rng, *rv_inputs_).default_output()

+    # Try to use inverse CDF sampling
+    try:
+        # For left-truncated discrete RVs, we need to include the whole lower bound.
+        # This may result in draws below the truncation range, if any uniform == 0
+        lower_value = lower_ - 1 if rv.owner.op.dtype.startswith("int") else lower_
+        cdf_lower_ = at.exp(logcdf(rv_, lower_value))
+        cdf_upper_ = at.exp(logcdf(rv_, upper_))
+        uniform_ = at.random.uniform(
+            cdf_lower_,
+            cdf_upper_,
+            rng=rng,
+            size=rv_inputs_[0],
+        )
+        truncated_rv_ = icdf(rv_, uniform_)
+        truncated_rv = TruncatedRV(
+            base_rv_op=rv.owner.op,
+            inputs=graph_inputs_,
+            outputs=[truncated_rv_, uniform_.owner.outputs[0]],
+            inline=True,
+        )(*graph_inputs)
+        updates = {truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1]}
+        return truncated_rv, updates
+    except NotImplementedError:
+        pass
+
+    # Fall back to rejection sampling
+    # TODO: Handle potential broadcast by lower / upper
+    def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
+        next_rng, new_truncated_rv = rv.owner.op.make_node(rng, *rv_inputs).outputs
+        truncated_rv = at.set_subtensor(
+            truncated_rv[reject_draws],
+            new_truncated_rv[reject_draws],
+        )
+        reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))
+
+        return (
+            (truncated_rv, reject_draws),
+            [(rng, next_rng)],
+            until(~at.any(reject_draws)),
+        )
+
+    (truncated_rv_, reject_draws_), updates = scan(
+        loop_fn,
+        outputs_info=[
+            at.zeros_like(rv_),
+            at.ones_like(rv_, dtype=bool),
+        ],
+        non_sequences=[lower_, upper_, rng, *rv_inputs_],
+        n_steps=max_n_steps,
+        strict=True,
+    )
+
+    truncated_rv_ = truncated_rv_[-1]
+    convergence_ = ~at.any(reject_draws_[-1])
+    truncated_rv_ = TruncationCheck(
+        f"Truncation did not converge in {max_n_steps} steps"
+    )(truncated_rv_, convergence_)
+
+    truncated_rv = TruncatedRV(
+        base_rv_op=rv.owner.op,
+        inputs=graph_inputs_,
+        outputs=[truncated_rv_, tuple(updates.values())[0]],
+        inline=True,
+    )(*graph_inputs)
+    updates = {truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1]}
+    return truncated_rv, updates
+
+
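The `truncate` helper returns both the truncated variable and the RNG updates, which have to be passed along when compiling. A minimal usage sketch (the module path `aeppl.truncation` is an assumption, not confirmed by this diff):

```python
import aesara
import aesara.tensor as at
import numpy as np

from aeppl.truncation import truncate  # assumed module path

# Absent a specialized `_truncated` dispatch or a registered icdf for the
# normal, this goes through the Scan-based rejection-sampling fallback.
x_rv = at.random.normal(0.0, 1.0, size=(5,))
xt_rv, updates = truncate(x_rv, lower=-1, upper=1, max_n_steps=100)

# The updates advance the shared RNG state between calls
fn = aesara.function([], xt_rv, updates=updates)
draws = fn()
assert np.all((draws >= -1) & (draws <= 1))
```

If no accepted draw is found within `max_n_steps` resamples, evaluating the compiled function raises the `TruncationError` defined above (via `TruncationCheck`).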
+@_logprob.register(TruncatedRV)
+def truncated_logprob(op, values, *inputs, **kwargs):
+    (value,) = values
+
+    *rv_inputs, lower_bound, upper_bound, rng = inputs
+    rv_inputs = [rng, *rv_inputs]
+
+    base_rv_op = op.base_rv_op
+    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
+    # For left-truncated discrete RVs, we don't want to include the lower bound
+    # in the normalization term
+    lower_bound_value = (
+        lower_bound - 1 if base_rv_op.dtype.startswith("int") else lower_bound
+    )
+    lower_logcdf = _logcdf(base_rv_op, lower_bound_value, *rv_inputs, **kwargs)
+    upper_logcdf = _logcdf(base_rv_op, upper_bound, *rv_inputs, **kwargs)
+
+    if base_rv_op.name:
+        logp.name = f"{base_rv_op}_logprob"
+        lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
+        upper_logcdf.name = f"{base_rv_op}_upper_logcdf"
+
+    is_lower_bounded = not (
+        isinstance(lower_bound, TensorConstant)
+        and np.all(np.isneginf(lower_bound.value))
+    )
+    is_upper_bounded = not (
+        isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))
+    )
+
+    lognorm = 0
+    if is_lower_bounded and is_upper_bounded:
+        lognorm = logdiffexp(upper_logcdf, lower_logcdf)
+    elif is_lower_bounded:
+        lognorm = at.log1mexp(lower_logcdf)
+    elif is_upper_bounded:
+        lognorm = upper_logcdf
+
+    logp = logp - lognorm
+
+    if is_lower_bounded:
+        logp = at.switch(value < lower_bound, -np.inf, logp)
+
+    if is_upper_bounded:
+        logp = at.switch(value <= upper_bound, logp, -np.inf)
+
+    if is_lower_bounded and is_upper_bounded:
+        logp = CheckParameterValue("lower_bound <= upper_bound")(
+            logp, at.all(at.le(lower_bound, upper_bound))
+        )
+
+    return logp
+
+
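With this registration, the density of a `TruncatedRV` can be requested like that of any other measurable variable. A sketch, assuming aeppl's top-level `joint_logprob` accepts a `{rv: value}` mapping as elsewhere in the library:

```python
import aesara.tensor as at

from aeppl import joint_logprob  # assumed import
from aeppl.truncation import truncate  # assumed module path

x_rv = at.random.normal(0.0, 1.0)
xt_rv, _ = truncate(x_rv, lower=0.0)
xt_vv = xt_rv.clone()

# logp(x) - log(1 - CDF(0)) for x >= 0, and -inf below the bound
logp = joint_logprob({xt_rv: xt_vv})
```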
+@_truncated.register(arb.UniformRV)
+def uniform_truncated(op, lower, upper, rng, size, dtype, lower_orig, upper_orig):
+    truncated_uniform = at.random.uniform(
+        at.max((lower_orig, lower)),
+        at.min((upper_orig, upper)),
+        rng=rng,
+        size=size,
+        dtype=dtype,
+    )
+    return truncated_uniform, {
+        truncated_uniform.owner.inputs[0]: truncated_uniform.owner.outputs[0]
+    }
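Because `_truncated` is a `singledispatch` function, `truncate` picks up `uniform_truncated` automatically for uniform base variables, skipping both the icdf and rejection-sampling paths. A short sketch (module path assumed as above):

```python
import aesara.tensor as at

from aeppl.truncation import truncate  # assumed module path

# Dispatches to `uniform_truncated`: the result is simply a new
# uniform RV over the clipped interval [1, 2], with no Scan involved.
x_rv = at.random.uniform(0.0, 2.0)
xt_rv, updates = truncate(x_rv, lower=1.0)
```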