Skip to content

Commit be8ae3a

Browse files
Ricardo Vieirabrandonwillard
authored andcommitted
Implement MeasurableElemwise
Subclasses are measurable by default, and they must specify what scalar op types are compatible.
1 parent 18d6e59 commit be8ae3a

File tree

4 files changed

+52
-11
lines changed

4 files changed

+52
-11
lines changed

aeppl/abstract.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import abc
22
from copy import copy
33
from functools import singledispatch
4-
from typing import Callable, List
4+
from typing import Callable, List, Tuple
55

66
from aesara.graph.basic import Apply, Variable
77
from aesara.graph.op import Op
88
from aesara.graph.utils import MetaType
9+
from aesara.tensor.elemwise import Elemwise
910
from aesara.tensor.random.op import RandomVariable
1011

1112

@@ -116,3 +117,20 @@ def assign_custom_measurable_outputs(
116117
_get_measurable_outputs.register(new_op_type)(measurable_outputs_fn)
117118

118119
return new_node
120+
121+
122+
class MeasurableElemwise(Elemwise):
123+
"""Base class for Measurable Elemwise variables"""
124+
125+
valid_scalar_types: Tuple[MetaType, ...] = ()
126+
127+
def __init__(self, scalar_op, *args, **kwargs):
128+
if not isinstance(scalar_op, self.valid_scalar_types):
129+
raise TypeError(
130+
f"scalar_op {scalar_op} is not valid for class {self.__class__}. "
131+
f"Acceptable types are {self.valid_scalar_types}"
132+
)
133+
super().__init__(scalar_op, *args, **kwargs)
134+
135+
136+
MeasurableVariable.register(MeasurableElemwise)

aeppl/censoring.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,22 @@
1010
from aesara.tensor.elemwise import Elemwise
1111
from aesara.tensor.var import TensorConstant
1212

13-
from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
13+
from aeppl.abstract import (
14+
MeasurableElemwise,
15+
MeasurableVariable,
16+
assign_custom_measurable_outputs,
17+
)
1418
from aeppl.logprob import CheckParameterValue, _logcdf, _logprob
1519
from aeppl.rewriting import measurable_ir_rewrites_db
1620

1721

18-
class MeasurableClip(Elemwise):
22+
class MeasurableClip(MeasurableElemwise):
1923
"""A placeholder used to specify a log-likelihood for a clipped RV sub-graph."""
2024

25+
valid_scalar_types = (Clip,)
26+
2127

22-
MeasurableVariable.register(MeasurableClip)
28+
measurable_clip = MeasurableClip(scalar_clip)
2329

2430

2531
@node_rewriter(tracks=[Elemwise])
@@ -54,10 +60,9 @@ def find_measurable_clips(
5460
lower_bound = lower_bound if (lower_bound is not base_var) else at.constant(-np.inf)
5561
upper_bound = upper_bound if (upper_bound is not base_var) else at.constant(np.inf)
5662

57-
clipped_op = MeasurableClip(scalar_clip)
5863
# Make base_var unmeasurable
5964
unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner)
60-
clipped_rv_node = clipped_op.make_node(
65+
clipped_rv_node = measurable_clip.make_node(
6166
unmeasurable_base_var, lower_bound, upper_bound
6267
)
6368
clipped_rv = clipped_rv_node.outputs[0]

aeppl/transforms.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from aesara.tensor.var import TensorVariable
2121

2222
from aeppl.abstract import (
23+
MeasurableElemwise,
2324
MeasurableVariable,
2425
_get_measurable_outputs,
2526
assign_custom_measurable_outputs,
@@ -213,9 +214,11 @@ def apply(self, fgraph: FunctionGraph):
213214
return self.default_transform_rewrite.rewrite(fgraph)
214215

215216

216-
class MeasurableTransform(Elemwise):
217+
class MeasurableTransform(MeasurableElemwise):
217218
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""
218219

220+
valid_scalar_types = (Exp, Log, Add, Mul)
221+
219222
# Cannot use `transform` as name because it would clash with the property added by
220223
# the `TransformValuesRewrite`
221224
transform_elemwise: RVTransform
@@ -229,9 +232,6 @@ def __init__(
229232
super().__init__(*args, **kwargs)
230233

231234

232-
MeasurableVariable.register(MeasurableTransform)
233-
234-
235235
@_get_measurable_outputs.register(MeasurableTransform)
236236
def _get_measurable_outputs_Transform(op, node):
237237
return [node.default_output()]
@@ -261,7 +261,7 @@ def find_measurable_transforms(
261261
) -> Optional[List[Node]]:
262262
"""Find measurable transformations from Elemwise operators."""
263263
scalar_op = node.op.scalar_op
264-
if not isinstance(scalar_op, (Exp, Log, Add, Mul)):
264+
if not isinstance(scalar_op, MeasurableTransform.valid_scalar_types):
265265
return None
266266

267267
# Node was already converted

tests/test_abstract.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
import re
2+
13
import aesara.tensor as at
24
import pytest
5+
from aesara.scalar import Exp, exp
36
from aesara.tensor.random.basic import NormalRV
47

58
from aeppl.abstract import (
9+
MeasurableElemwise,
10+
MeasurableVariable,
611
UnmeasurableVariable,
712
_get_measurable_outputs,
813
assign_custom_measurable_outputs,
@@ -92,3 +97,16 @@ def test_assign_custom_measurable_outputs():
9297

9398
with pytest.raises(ValueError):
9499
assign_custom_measurable_outputs(unmeas_X_rv.owner, lambda x: x)
100+
101+
102+
def test_measurable_elemwise():
103+
# Default does not accept any scalar_op
104+
with pytest.raises(TypeError, match=re.escape("scalar_op exp is not valid")):
105+
MeasurableElemwise(exp)
106+
107+
class TestMeasurableElemwise(MeasurableElemwise):
108+
valid_scalar_types = (Exp,)
109+
110+
measurable_exp_op = TestMeasurableElemwise(scalar_op=exp)
111+
measurable_exp = measurable_exp_op(0.0)
112+
assert isinstance(measurable_exp.owner.op, MeasurableVariable)

0 commit comments

Comments
 (0)