|
5 | 5 | from aesara.graph.basic import Node
|
6 | 6 | from aesara.graph.fg import FunctionGraph
|
7 | 7 | from aesara.graph.rewriting.basic import node_rewriter
|
8 |
| -from aesara.scalar.basic import Clip |
| 8 | +from aesara.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven |
9 | 9 | from aesara.scalar.basic import clip as scalar_clip
|
10 | 10 | from aesara.tensor.elemwise import Elemwise
|
11 | 11 | from aesara.tensor.var import TensorConstant
|
|
15 | 15 | MeasurableVariable,
|
16 | 16 | assign_custom_measurable_outputs,
|
17 | 17 | )
|
18 |
| -from aeppl.logprob import CheckParameterValue, _logcdf, _logprob |
| 18 | +from aeppl.logprob import CheckParameterValue, _logcdf, _logprob, logdiffexp |
19 | 19 | from aeppl.rewriting import measurable_ir_rewrites_db
|
20 | 20 |
|
21 | 21 |
|
@@ -142,3 +142,113 @@ def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
|
142 | 142 | )
|
143 | 143 |
|
144 | 144 | return logprob
|
| 145 | + |
| 146 | + |
| 147 | +class MeasurableRound(MeasurableElemwise): |
| 148 | + """A placeholder used to specify a log-likelihood for a clipped RV sub-graph.""" |
| 149 | + |
| 150 | + valid_scalar_types = (RoundHalfToEven, Floor, Ceil) |
| 151 | + |
| 152 | + |
| 153 | +@node_rewriter(tracks=[Elemwise]) |
| 154 | +def find_measurable_roundings( |
| 155 | + fgraph: FunctionGraph, node: Node |
| 156 | +) -> Optional[List[MeasurableRound]]: |
| 157 | + |
| 158 | + rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) |
| 159 | + if rv_map_feature is None: |
| 160 | + return None # pragma: no cover |
| 161 | + |
| 162 | + if isinstance(node.op, MeasurableRound): |
| 163 | + return None # pragma: no cover |
| 164 | + |
| 165 | + if not ( |
| 166 | + isinstance(node.op, Elemwise) |
| 167 | + and isinstance(node.op.scalar_op, MeasurableRound.valid_scalar_types) |
| 168 | + ): |
| 169 | + return None |
| 170 | + |
| 171 | + (rounded_var,) = node.outputs |
| 172 | + (base_var,) = node.inputs |
| 173 | + |
| 174 | + if not ( |
| 175 | + base_var.owner |
| 176 | + and isinstance(base_var.owner.op, MeasurableVariable) |
| 177 | + and base_var not in rv_map_feature.rv_values |
| 178 | + # Rounding only makes sense for continuous variables |
| 179 | + and base_var.dtype.startswith("float") |
| 180 | + ): |
| 181 | + return None |
| 182 | + |
| 183 | + # Make base_var unmeasurable |
| 184 | + unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner) |
| 185 | + |
| 186 | + rounded_op = MeasurableRound(node.op.scalar_op) |
| 187 | + rounded_rv = rounded_op.make_node(unmeasurable_base_var).default_output() |
| 188 | + rounded_rv.name = rounded_var.name |
| 189 | + return [rounded_rv] |
| 190 | + |
| 191 | + |
| 192 | +measurable_ir_rewrites_db.register( |
| 193 | + "find_measurable_roundings", |
| 194 | + find_measurable_roundings, |
| 195 | + 0, |
| 196 | + "basic", |
| 197 | + "censoring", |
| 198 | +) |
| 199 | + |
| 200 | + |
| 201 | +@_logprob.register(MeasurableRound) |
| 202 | +def round_logprob(op, values, base_rv, **kwargs): |
| 203 | + r"""Logprob of a rounded censored distribution |
| 204 | +
|
| 205 | + The probability of a distribution rounded to the nearest integer is given by |
| 206 | + .. math:: |
| 207 | + \begin{cases} |
| 208 | + \text{CDF}(x+\frac{1}{2}, dist) - \text{CDF}(x-\frac{1}{2}, dist) & \text{for } x \in \mathbb{Z}, \\ |
| 209 | + 0 & \text{otherwise}, |
| 210 | + \end{cases} |
| 211 | +
|
| 212 | + The probability of a distribution rounded up is given by |
| 213 | + .. math:: |
| 214 | + \begin{cases} |
| 215 | + \text{CDF}(x, dist) - \text{CDF}(x-1, dist) & \text{for } x \in \mathbb{Z}, \\ |
| 216 | + 0 & \text{otherwise}, |
| 217 | + \end{cases} |
| 218 | +
|
| 219 | + The probability of a distribution rounded down is given by |
| 220 | + .. math:: |
| 221 | + \begin{cases} |
| 222 | + \text{CDF}(x+1, dist) - \text{CDF}(x, dist) & \text{for } x \in \mathbb{Z}, \\ |
| 223 | + 0 & \text{otherwise}, |
| 224 | + \end{cases} |
| 225 | +
|
| 226 | + """ |
| 227 | + (value,) = values |
| 228 | + |
| 229 | + if isinstance(op.scalar_op, RoundHalfToEven): |
| 230 | + value = at.round(value) |
| 231 | + value_upper = value + 0.5 |
| 232 | + value_lower = value - 0.5 |
| 233 | + elif isinstance(op.scalar_op, Floor): |
| 234 | + value = at.floor(value) |
| 235 | + value_upper = value + 1.0 |
| 236 | + value_lower = value |
| 237 | + elif isinstance(op.scalar_op, Ceil): |
| 238 | + value = at.ceil(value) |
| 239 | + value_upper = value |
| 240 | + value_lower = value - 1.0 |
| 241 | + else: |
| 242 | + raise TypeError(f"Unsupported scalar_op {op.scalar_op}") # pragma: no cover |
| 243 | + |
| 244 | + base_rv_op = base_rv.owner.op |
| 245 | + base_rv_inputs = base_rv.owner.inputs |
| 246 | + |
| 247 | + logcdf_upper = _logcdf(base_rv_op, value_upper, *base_rv_inputs, **kwargs) |
| 248 | + logcdf_lower = _logcdf(base_rv_op, value_lower, *base_rv_inputs, **kwargs) |
| 249 | + |
| 250 | + if base_rv_op.name: |
| 251 | + logcdf_upper.name = f"{base_rv_op}_logcdf_upper" |
| 252 | + logcdf_lower.name = f"{base_rv_op}_logcdf_lower" |
| 253 | + |
| 254 | + return logdiffexp(logcdf_upper, logcdf_lower) |
0 commit comments