Skip to content

Commit 18d6e59

Browse files
Ricardo Vieirabrandonwillard
authored andcommitted
Refer to specific censoring type in IR
1 parent db0b59d commit 18d6e59

File tree

3 files changed

+44
-29
lines changed

3 files changed

+44
-29
lines changed

aeppl/censoring.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,24 @@
1515
from aeppl.rewriting import measurable_ir_rewrites_db
1616

1717

18-
class CensoredRV(Elemwise):
19-
"""A placeholder used to specify a log-likelihood for a censored RV sub-graph."""
18+
class MeasurableClip(Elemwise):
19+
"""A placeholder used to specify a log-likelihood for a clipped RV sub-graph."""
2020

2121

22-
MeasurableVariable.register(CensoredRV)
22+
MeasurableVariable.register(MeasurableClip)
2323

2424

2525
@node_rewriter(tracks=[Elemwise])
26-
def find_censored_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[CensoredRV]]:
26+
def find_measurable_clips(
27+
fgraph: FunctionGraph, node: Node
28+
) -> Optional[List[MeasurableClip]]:
2729
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
2830

2931
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
3032
if rv_map_feature is None:
3133
return None # pragma: no cover
3234

33-
if isinstance(node.op, CensoredRV):
35+
if isinstance(node.op, MeasurableClip):
3436
return None # pragma: no cover
3537

3638
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Clip)):
@@ -47,35 +49,48 @@ def find_censored_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Censor
4749
return None
4850

4951
# Replace bounds by `+-inf` if `y = clip(x, x, ?)` or `y=clip(x, ?, x)`
50-
# This is used in `censor_logprob` to generate a more succint logprob graph
51-
# for one-sided censored random variables
52+
# This is used in `clip_logprob` to generate a more succint logprob graph
53+
# for one-sided clipped random variables
5254
lower_bound = lower_bound if (lower_bound is not base_var) else at.constant(-np.inf)
5355
upper_bound = upper_bound if (upper_bound is not base_var) else at.constant(np.inf)
5456

55-
censored_op = CensoredRV(scalar_clip)
57+
clipped_op = MeasurableClip(scalar_clip)
5658
# Make base_var unmeasurable
5759
unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner)
58-
censored_rv_node = censored_op.make_node(
60+
clipped_rv_node = clipped_op.make_node(
5961
unmeasurable_base_var, lower_bound, upper_bound
6062
)
61-
censored_rv = censored_rv_node.outputs[0]
63+
clipped_rv = clipped_rv_node.outputs[0]
6264

63-
censored_rv.name = clipped_var.name
65+
clipped_rv.name = clipped_var.name
6466

65-
return [censored_rv]
67+
return [clipped_rv]
6668

6769

6870
measurable_ir_rewrites_db.register(
69-
"find_censored_rvs",
70-
find_censored_rvs,
71+
"find_measurable_clips",
72+
find_measurable_clips,
7173
0,
7274
"basic",
7375
"censoring",
7476
)
7577

7678

77-
@_logprob.register(CensoredRV)
78-
def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
79+
@_logprob.register(MeasurableClip)
80+
def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
81+
r"""Logprob of a clipped censored distribution
82+
83+
The probability is given by
84+
.. math::
85+
\begin{cases}
86+
0 & \text{for } x < lower, \\
87+
\text{CDF}(lower, dist) & \text{for } x = lower, \\
88+
\text{P}(x, dist) & \text{for } lower < x < upper, \\
89+
1-\text{CDF}(upper, dist) & \text {for} x = upper, \\
90+
0 & \text{for } x > upper,
91+
\end{cases}
92+
93+
"""
7994
(value,) = values
8095

8196
base_rv_op = base_rv.owner.op
@@ -95,7 +110,7 @@ def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
95110
is_upper_bounded = True
96111

97112
logccdf = at.log1mexp(logcdf)
98-
# For right censored discrete RVs, we need to add an extra term
113+
# For right clipped discrete RVs, we need to add an extra term
99114
# corresponding to the pmf at the upper bound
100115
if base_rv.dtype.startswith("int"):
101116
logccdf = at.logaddexp(logccdf, logprob)

tests/test_censoring.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
@aesara.config.change_flags(compute_test_value="raise")
14-
def test_continuous_rv_censoring():
14+
def test_continuous_rv_clip():
1515
x_rv = at.random.normal(0.5, 1)
1616
cens_x_rv = at.clip(x_rv, -2, 2)
1717

@@ -32,7 +32,7 @@ def test_continuous_rv_censoring():
3232
assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))
3333

3434

35-
def test_discrete_rv_censoring():
35+
def test_discrete_rv_clip():
3636
x_rv = at.random.poisson(2)
3737
cens_x_rv = at.clip(x_rv, 1, 4)
3838

@@ -52,7 +52,7 @@ def test_discrete_rv_censoring():
5252
assert np.isclose(logp_fn(2), ref_scipy.logpmf(2))
5353

5454

55-
def test_one_sided_censoring():
55+
def test_one_sided_clip():
5656
x_rv = at.random.normal(0, 1)
5757
lb_cens_x_rv = at.clip(x_rv, -1, x_rv)
5858
ub_cens_x_rv = at.clip(x_rv, x_rv, 1)
@@ -74,7 +74,7 @@ def test_one_sided_censoring():
7474
np.testing.assert_almost_equal(logp_fn(1, -1), ref_scipy.logpdf(-1))
7575

7676

77-
def test_useless_censoring():
77+
def test_useless_clip():
7878
x_rv = at.random.normal(0.5, 1, size=3)
7979
cens_x_rv = at.clip(x_rv, x_rv, x_rv)
8080

@@ -89,7 +89,7 @@ def test_useless_censoring():
8989
np.testing.assert_allclose(logp_fn([-2, 0, 2]), ref_scipy.logpdf([-2, 0, 2]))
9090

9191

92-
def test_random_censoring():
92+
def test_random_clip():
9393
lb_rv = at.random.normal(0, 1, size=2)
9494
x_rv = at.random.normal(0, 2)
9595
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
@@ -105,7 +105,7 @@ def test_random_censoring():
105105
assert res[1] != -np.inf
106106

107107

108-
def test_broadcasted_censoring_constant():
108+
def test_broadcasted_clip_constant():
109109
lb_rv = at.random.uniform(0, 1)
110110
x_rv = at.random.normal(0, 2)
111111
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
@@ -117,7 +117,7 @@ def test_broadcasted_censoring_constant():
117117
assert_no_rvs(logp)
118118

119119

120-
def test_broadcasted_censoring_random():
120+
def test_broadcasted_clip_random():
121121
lb_rv = at.random.normal(0, 1)
122122
x_rv = at.random.normal(0, 2, size=2)
123123
cens_x_rv = at.clip(x_rv, lb_rv, 1)
@@ -129,7 +129,7 @@ def test_broadcasted_censoring_random():
129129
assert_no_rvs(logp)
130130

131131

132-
def test_fail_base_and_censored_have_values():
132+
def test_fail_base_and_clip_have_values():
133133
"""Test failure when both base_rv and clipped_rv are given value vars"""
134134
x_rv = at.random.normal(0, 1)
135135
cens_x_rv = at.clip(x_rv, x_rv, 1)
@@ -141,7 +141,7 @@ def test_fail_base_and_censored_have_values():
141141
factorized_joint_logprob({cens_x_rv: cens_x_vv, x_rv: x_vv})
142142

143143

144-
def test_fail_multiple_censored_single_base():
144+
def test_fail_multiple_clip_single_base():
145145
"""Test failure when multiple clipped_rvs share a single base_rv"""
146146
base_rv = at.random.normal(0, 1)
147147
cens_rv1 = at.clip(base_rv, -1, 1)
@@ -172,7 +172,7 @@ def test_deterministic_clipping():
172172
)
173173

174174

175-
def test_censored_transform():
175+
def test_clip_transform():
176176
x_rv = at.random.normal(0.5, 1)
177177
cens_x_rv = at.clip(x_rv, 0, x_rv)
178178

tests/test_composite_logprob.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import scipy.stats as st
55

66
from aeppl import joint_logprob
7+
from aeppl.censoring import MeasurableClip
78
from aeppl.rewriting import construct_ir_fgraph
8-
from aeppl.truncation import CensoredRV
99
from tests.utils import assert_no_rvs
1010

1111

@@ -97,7 +97,7 @@ def test_unvalued_ir_reversion():
9797
assert memo[y_rv] in z_fgraph.preserve_rv_mappings.measurable_conversions
9898

9999
measurable_y_rv = z_fgraph.preserve_rv_mappings.measurable_conversions[memo[y_rv]]
100-
assert isinstance(measurable_y_rv.owner.op, CensoredRV)
100+
assert isinstance(measurable_y_rv.owner.op, MeasurableClip)
101101

102102
# `construct_ir_fgraph` should've reverted the un-valued measurable IR
103103
# change

0 commit comments

Comments
 (0)