Skip to content

Commit 46f6775

Browse files
Ricardo Vieira authored and brandonwillard committed
Implement logprob rewrites for rounding operations
1 parent be8ae3a commit 46f6775

File tree

3 files changed

+147
-2
lines changed

3 files changed

+147
-2
lines changed

aeppl/censoring.py

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from aesara.graph.basic import Node
66
from aesara.graph.fg import FunctionGraph
77
from aesara.graph.rewriting.basic import node_rewriter
8-
from aesara.scalar.basic import Clip
8+
from aesara.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
99
from aesara.scalar.basic import clip as scalar_clip
1010
from aesara.tensor.elemwise import Elemwise
1111
from aesara.tensor.var import TensorConstant
@@ -15,7 +15,7 @@
1515
MeasurableVariable,
1616
assign_custom_measurable_outputs,
1717
)
18-
from aeppl.logprob import CheckParameterValue, _logcdf, _logprob
18+
from aeppl.logprob import CheckParameterValue, _logcdf, _logprob, logdiffexp
1919
from aeppl.rewriting import measurable_ir_rewrites_db
2020

2121

@@ -142,3 +142,113 @@ def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
142142
)
143143

144144
return logprob
145+
146+
147+
class MeasurableRound(MeasurableElemwise):
    """A placeholder used to specify a log-likelihood for a rounded RV sub-graph."""

    # Scalar ops that this placeholder is allowed to wrap.
    valid_scalar_types = (RoundHalfToEven, Floor, Ceil)
151+
152+
153+
@node_rewriter(tracks=[Elemwise])
def find_measurable_roundings(
    fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableRound]]:
    """Replace a round/floor/ceil of a measurable variable with a `MeasurableRound`.

    No rewrite is performed (``None`` is returned) unless ``node`` is an
    `Elemwise` whose scalar op is one of `MeasurableRound.valid_scalar_types`
    and its single input is a float-typed output of a `MeasurableVariable` op
    that is not already listed in ``rv_map_feature.rv_values``.
    """
    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
    if rv_map_feature is None:
        return None  # pragma: no cover

    elemwise_op = node.op
    if isinstance(elemwise_op, MeasurableRound):
        return None  # pragma: no cover

    if not isinstance(elemwise_op, Elemwise):
        return None
    if not isinstance(elemwise_op.scalar_op, MeasurableRound.valid_scalar_types):
        return None

    (rounded_var,) = node.outputs
    (base_var,) = node.inputs

    base_node = base_var.owner
    if not base_node:
        return None
    if not isinstance(base_node.op, MeasurableVariable):
        return None
    if base_var in rv_map_feature.rv_values:
        return None
    # Rounding only makes sense for continuous variables
    if not base_var.dtype.startswith("float"):
        return None

    # Make base_var unmeasurable
    unmeasurable_base_var = assign_custom_measurable_outputs(base_node)

    # Rebuild the rounding on top of the unmeasurable input, keeping the
    # original output's name.
    measurable_op = MeasurableRound(elemwise_op.scalar_op)
    rounded_rv = measurable_op.make_node(unmeasurable_base_var).default_output()
    rounded_rv.name = rounded_var.name
    return [rounded_rv]
190+
191+
192+
# Register the rounding rewrite in the measurable IR rewrite database.
measurable_ir_rewrites_db.register(
    "find_measurable_roundings", find_measurable_roundings, 0, "basic", "censoring"
)
199+
200+
201+
@_logprob.register(MeasurableRound)
def round_logprob(op, values, base_rv, **kwargs):
    r"""Logprob of a rounded (round-to-nearest, floor or ceil) distribution.

    The value is first passed through the same rounding operation as ``op``
    and the log-probability is then the log of a CDF difference of the base
    distribution over the interval that maps onto that rounded value.

    For rounding to the nearest integer the interval is

    .. math::
        \text{CDF}(x+\frac{1}{2}, dist) - \text{CDF}(x-\frac{1}{2}, dist)

    for rounding up (ceil) it is

    .. math::
        \text{CDF}(x, dist) - \text{CDF}(x-1, dist)

    and for rounding down (floor) it is

    .. math::
        \text{CDF}(x+1, dist) - \text{CDF}(x, dist)

    where :math:`x` is the rounded value.

    Parameters
    ----------
    op
        The `MeasurableRound` op; its ``scalar_op`` selects the interval.
    values
        A one-element sequence holding the value variable.
    base_rv
        The variable being rounded; the log-CDF is dispatched on its owner's
        op and inputs.
    **kwargs
        Forwarded to `_logcdf`.

    Raises
    ------
    TypeError
        If ``op.scalar_op`` is not one of the supported rounding ops.
    """
    (value,) = values
    scalar_op = op.scalar_op

    # Snap the value onto the rounding grid and derive the interval bounds.
    if isinstance(scalar_op, RoundHalfToEven):
        snapped = at.round(value)
        value_upper, value_lower = snapped + 0.5, snapped - 0.5
    elif isinstance(scalar_op, Floor):
        snapped = at.floor(value)
        value_upper, value_lower = snapped + 1.0, snapped
    elif isinstance(scalar_op, Ceil):
        snapped = at.ceil(value)
        value_upper, value_lower = snapped, snapped - 1.0
    else:
        raise TypeError(f"Unsupported scalar_op {op.scalar_op}")  # pragma: no cover

    base_node = base_rv.owner
    base_rv_op = base_node.op
    base_rv_inputs = base_node.inputs

    logcdf_upper = _logcdf(base_rv_op, value_upper, *base_rv_inputs, **kwargs)
    logcdf_lower = _logcdf(base_rv_op, value_lower, *base_rv_inputs, **kwargs)

    # Label the intermediate log-CDF terms when the base op has a name.
    if base_rv_op.name:
        logcdf_upper.name = f"{base_rv_op}_logcdf_upper"
        logcdf_lower.name = f"{base_rv_op}_logcdf_lower"

    return logdiffexp(logcdf_upper, logcdf_lower)

aeppl/logprob.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ def binomln(n, k):
4040
return at.gammaln(n + 1) - at.gammaln(k + 1) - at.gammaln(n - k + 1)
4141

4242

43+
def logdiffexp(a, b):
    """Return ``log(exp(a) - exp(b))``, computed as ``a + log1mexp(b - a)``."""
    delta = b - a
    return a + at.log1mexp(delta)
46+
47+
4348
def xlogy0(m, x):
4449
# TODO: This should probably be a basic Aesara stabilization
4550
return at.switch(at.eq(x, 0), at.switch(at.eq(m, 0), 0.0, -np.inf), m * at.log(x))

tests/test_censoring.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,33 @@ def test_clip_transform():
188188
)
189189

190190
assert np.isclose(obs_logp, exp_logp)
191+
192+
193+
@pytest.mark.parametrize("rounding_op", (at.round, at.floor, at.ceil))
def test_rounding(rounding_op):
    """Compare the rounded-normal logprob against SciPy CDF differences."""
    loc, scale = 1, 2
    test_value = np.arange(-3, 4)

    x = at.random.normal(loc, scale, size=test_value.shape, name="x")
    xr = rounding_op(x)
    xr.name = "xr"

    xr_vv = xr.clone()
    logp = joint_logprob({xr: xr_vv}, sum=False)
    assert logp is not None

    # Interval of the base normal that each rounding op maps onto `test_value`.
    if rounding_op == at.round:
        upper, lower = test_value + 0.5, test_value - 0.5
    elif rounding_op == at.floor:
        upper, lower = test_value + 1.0, test_value
    elif rounding_op == at.ceil:
        upper, lower = test_value, test_value - 1.0
    else:
        raise NotImplementedError()

    x_sp = st.norm(loc, scale)
    expected_logp = np.log(x_sp.cdf(upper) - x_sp.cdf(lower))

    observed_logp = logp.eval({xr_vv: test_value})
    assert np.allclose(observed_logp, expected_logp)

0 commit comments

Comments
 (0)