Skip to content

Commit 95a4ceb

Browse files
Preserve the identities of valued/observed variables
1 parent dd8c4a9 commit 95a4ceb

File tree

6 files changed

+70
-66
lines changed

6 files changed

+70
-66
lines changed

aeppl/joint_logprob.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,16 @@ def conditional_logprob(
141141
# maps to the logprob graphs and value variables before returning them.
142142
rv_values = {**original_rv_values, **realized}
143143

144-
fgraph, _, memo = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)
145-
146-
if extra_rewrites is not None:
147-
extra_rewrites.add_requirements(fgraph, rv_values, memo)
148-
extra_rewrites.apply(fgraph)
144+
fgraph, new_rv_values = construct_ir_fgraph(
145+
rv_values, ir_rewriter=ir_rewriter, extra_rewrites=extra_rewrites
146+
)
149147

150148
# We assign log-densities on a per-node basis, and not per-output/variable.
151149
realized_vars = set()
152150
new_to_old_rvs = {}
153151
nodes_to_vals: Dict["Apply", List[Tuple["Variable", "Variable"]]] = {}
154152

155-
for bnd_var, (old_mvar, old_val) in zip(fgraph.outputs, rv_values.items()):
153+
for bnd_var, (old_mvar, val) in zip(fgraph.outputs, new_rv_values.items()):
156154
mnode = bnd_var.owner
157155
assert mnode and isinstance(mnode.op, ValuedVariable)
158156

@@ -165,11 +163,7 @@ def conditional_logprob(
165163
if old_mvar in realized:
166164
realized_vars.add(rv_var)
167165

168-
# Do this just in case a value variable was changed. (Some transforms
169-
# do this.)
170-
new_val = memo[old_val]
171-
172-
nodes_to_vals.setdefault(rv_node, []).append((val_var, new_val))
166+
nodes_to_vals.setdefault(rv_node, []).append((val_var, val))
173167

174168
new_to_old_rvs[rv_var] = old_mvar
175169

aeppl/rewriting.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1-
from typing import Dict, Optional, Tuple
1+
from typing import Dict, Optional, Tuple, Union
22

33
import aesara.tensor as at
44
from aesara.compile.mode import optdb
55
from aesara.graph.basic import Apply, Variable
66
from aesara.graph.features import Feature
77
from aesara.graph.fg import FunctionGraph
8-
from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
8+
from aesara.graph.rewriting.basic import (
9+
GraphRewriter,
10+
NodeRewriter,
11+
in2out,
12+
node_rewriter,
13+
)
914
from aesara.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery, SequenceDB
1015
from aesara.tensor.elemwise import DimShuffle, Elemwise
1116
from aesara.tensor.extra_ops import BroadcastTo
@@ -180,9 +185,10 @@ def incsubtensor_rv_replace(fgraph, node):
180185

181186

182187
def construct_ir_fgraph(
183-
rv_values: Dict[Variable, Variable],
188+
rvs_to_values: Dict[Variable, Variable],
184189
ir_rewriter: Optional[GraphRewriter] = None,
185-
) -> Tuple[FunctionGraph, Dict[Variable, Variable], Dict[Variable, Variable]]:
190+
extra_rewrites: Optional[Union[GraphRewriter, NodeRewriter]] = None,
191+
) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
186192
r"""Construct a `FunctionGraph` in measurable IR form for the keys in `rv_values`.
187193
188194
A custom IR rewriter can be specified. By default,
@@ -215,9 +221,8 @@ def construct_ir_fgraph(
215221
Returns
216222
-------
217223
A `FunctionGraph` of the measurable IR, a copy of `rv_values` containing
218-
the new, cloned versions of the original variables in `rv_values`, and
219-
a ``dict`` mapping all the original variables to their cloned values in
220-
the `FunctionGraph`.
224+
the new, cloned versions of the original variables in `rv_values`.
225+
221226
"""
222227

223228
# We're going to create a `FunctionGraph` that effectively represents the
@@ -233,16 +238,20 @@ def construct_ir_fgraph(
233238
# so that they're distinct nodes in the graph. This allows us to replace
234239
# all instances of the original random variables with their value
235240
# variables, while leaving the output clones untouched.
236-
rv_value_clones = {}
241+
rv_clone_to_value_clone = {}
242+
rv_to_value_clone = {}
243+
value_clone_to_value = {}
237244
measured_outputs = {}
238-
memo = {}
239-
for rv, val in rv_values.items():
245+
memo: Dict[Variable, Variable] = {}
246+
for rv, val in rvs_to_values.items():
240247
rv_node_clone = rv.owner.clone()
241248
rv_clone = rv_node_clone.outputs[rv.owner.outputs.index(rv)]
242-
rv_value_clones[rv_clone] = val
243-
measured_outputs[rv] = valued_variable(rv_clone, val)
244-
# Prevent value variables from being cloned
245-
memo[val] = val
249+
val_clone = val.clone()
250+
val_clone.name = "val_clone"
251+
rv_clone_to_value_clone[rv_clone] = val_clone
252+
rv_to_value_clone[rv] = val_clone
253+
value_clone_to_value[val_clone] = val
254+
measured_outputs[rv] = valued_variable(rv_clone, val_clone)
246255

247256
# We add `ShapeFeature` because it will get rid of references to the old
248257
# `RandomVariable`s that have been lifted; otherwise, it will be difficult
@@ -257,9 +266,6 @@ def construct_ir_fgraph(
257266
copy_inputs=False,
258267
)
259268

260-
# Update `rv_values` so that it uses the new cloned variables
261-
rv_value_clones = {memo[k]: v for k, v in rv_value_clones.items()}
262-
263269
# Replace valued non-output variables with their values
264270
fgraph.replace_all(
265271
[(memo[rv], val) for rv, val in measured_outputs.items() if rv in memo],
@@ -272,11 +278,22 @@ def construct_ir_fgraph(
272278

273279
ir_rewriter.rewrite(fgraph)
274280

281+
if extra_rewrites is not None:
282+
# Expect `value_clone_to_value` to be updated in-place
283+
extra_rewrites.add_requirements(fgraph, rv_to_value_clone, value_clone_to_value)
284+
extra_rewrites.apply(fgraph)
285+
275286
# Undo un-valued measurable IR rewrites
276287
new_to_old = tuple((v, k) for k, v in fgraph.measurable_conversions.items())
277-
fgraph.replace_all(new_to_old, reason="undo-unvalued-measurables")
288+
# and add the original value variables back in
289+
new_to_old += tuple(value_clone_to_value.items())
290+
fgraph.replace_all(
291+
new_to_old, reason="undo-unvalued-measurables", import_missing=True
292+
)
293+
294+
new_rvs_to_values = dict(zip(rvs_to_values.keys(), value_clone_to_value.values()))
278295

279-
return fgraph, rv_value_clones, memo
296+
return fgraph, new_rvs_to_values
280297

281298

282299
@register_useless

aeppl/transforms.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -161,14 +161,10 @@ def transform_values(fgraph: FunctionGraph, node: Apply):
161161
4. Replace the old `ValuedVariable` with a new one containing a
162162
`TransformedVariable` value.
163163
164-
Step 3. is currently accomplished by updating the `memo` dictionary
165-
associated with the `FunctionGraph`. Our main entry-point,
164+
Step 3. is currently accomplished by updating the `rvs_to_values`
165+
dictionary associated with the `FunctionGraph`. Our main entry-point,
166166
`conditional_logprob`, checks this dictionary for value variable changes.
167167
168-
TODO: This approach is less than ideal, because it puts awkward demands on
169-
users/callers of this rewrite to check with `memo`; let's see if we can do
170-
something better.
171-
172168
The new value variable mentioned in Step 2. may be of a different `Type`
173169
(e.g. extra/fewer dimensions) than the original value variable; this is why
174170
we must replace the corresponding original value variables before we
@@ -235,8 +231,8 @@ def transform_values(fgraph: FunctionGraph, node: Apply):
235231

236232
# This effectively lets the caller know that a value variable has been
237233
# replaced (i.e. they should filter all their old value variables through
238-
# the memo/replacements map).
239-
fgraph.memo[value_var] = trans_value_var
234+
# the replacements map).
235+
fgraph.value_clone_to_value[value_var] = trans_value_var
240236

241237
trans_var = trans_node.outputs[rv_var_out_idx]
242238
new_var = valued_variable(trans_var, untrans_value_var)
@@ -252,7 +248,7 @@ class TransformValuesMapping(Feature):
252248
253249
"""
254250

255-
def __init__(self, values_to_transforms, memo):
251+
def __init__(self, values_to_transforms, value_clone_to_value):
256252
"""
257253
Parameters
258254
==========
@@ -261,20 +257,19 @@ def __init__(self, values_to_transforms, memo):
261257
value variable can be assigned one of `RVTransform`,
262258
`DEFAULT_TRANSFORM`, or ``None``. Random variables with no
263259
transform specified remain unchanged.
264-
memo
265-
Mapping from variables to their clones. This is updated
266-
in-place whenever a value variable is transformed.
267-
260+
value_clone_to_value
261+
Mapping between random variable value clones and their original
262+
value variables.
268263
"""
269264
self.values_to_transforms = values_to_transforms
270-
self.memo = memo
265+
self.value_clone_to_value = value_clone_to_value
271266

272267
def on_attach(self, fgraph):
273268
if hasattr(fgraph, "values_to_transforms"):
274269
raise AlreadyThere()
275270

276271
fgraph.values_to_transforms = self.values_to_transforms
277-
fgraph.memo = self.memo
272+
fgraph.value_clone_to_value = self.value_clone_to_value
278273

279274

280275
class TransformValuesRewrite(GraphRewriter):
@@ -322,6 +317,7 @@ def __init__(
322317
measurable variable can be assigned an `RVTransform` instance,
323318
`DEFAULT_TRANSFORM`, or ``None``. Measurable variables with no
324319
transform specified remain unchanged.
320+
rvs_to_values
325321
326322
"""
327323

@@ -330,14 +326,16 @@ def __init__(
330326
def add_requirements(
331327
self,
332328
fgraph,
333-
rv_to_values: Dict[TensorVariable, TensorVariable],
334-
memo: Dict[TensorVariable, TensorVariable],
329+
rvs_to_values: Dict[TensorVariable, TensorVariable],
330+
value_clone_to_value: Dict[TensorVariable, TensorVariable],
335331
):
336332
values_to_transforms = {
337-
rv_to_values[rv]: transform
333+
rvs_to_values[rv]: transform
338334
for rv, transform in self.rvs_to_transforms.items()
339335
}
340-
values_transforms_feature = TransformValuesMapping(values_to_transforms, memo)
336+
values_transforms_feature = TransformValuesMapping(
337+
values_to_transforms, value_clone_to_value
338+
)
341339
fgraph.attach_feature(values_transforms_feature)
342340

343341
def apply(self, fgraph: FunctionGraph):

tests/test_composite_logprob.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -79,25 +79,20 @@ def test_unvalued_ir_reversion():
7979
"""Make sure that un-valued IR rewrites are reverted."""
8080
srng = at.random.RandomStream(0)
8181

82-
x_rv = srng.normal()
82+
x_rv = srng.normal(name="X")
8383
y_rv = at.clip(x_rv, 0, 1)
84-
z_rv = srng.normal(y_rv, 1, name="z")
84+
y_rv.name = "Y"
85+
z_rv = srng.normal(y_rv, 1, name="Z")
8586
z_vv = z_rv.clone()
87+
z_vv.name = "z"
8688

8789
# Only the `z_rv` is "valued", so `y_rv` doesn't need to be converted into
8890
# measurable IR.
8991
rv_values = {z_rv: z_vv}
9092

91-
z_fgraph, _, memo = construct_ir_fgraph(rv_values)
93+
z_fgraph, new_rvs_to_values = construct_ir_fgraph(rv_values)
9294

93-
assert memo[y_rv] in z_fgraph.measurable_conversions
94-
95-
measurable_y_rv = z_fgraph.measurable_conversions[memo[y_rv]]
96-
assert isinstance(measurable_y_rv.owner.op, MeasurableClip)
97-
98-
# `construct_ir_fgraph` should've reverted the un-valued measurable IR
99-
# change
100-
assert measurable_y_rv not in z_fgraph
95+
assert not any(isinstance(node.op, MeasurableClip) for node in z_fgraph.apply_nodes)
10196

10297

10398
def test_shifted_cumsum():

tests/test_convolutions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_add_independent_normals(mu_x, mu_y, sigma_x, sigma_y, x_shape, y_shape,
7474
Z_rv.name = "Z"
7575
z_vv = Z_rv.clone()
7676

77-
fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv})
77+
fgraph, *_ = construct_ir_fgraph({Z_rv: z_vv})
7878

7979
(valued_var_out_node) = fgraph.outputs[0].owner
8080
# The convolution should be applied, and not the transform
@@ -108,7 +108,7 @@ def test_normal_add_input_valued():
108108
Z_rv.name = "Z"
109109
z_vv = Z_rv.clone()
110110

111-
fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv, X_rv: x_vv})
111+
fgraph, *_ = construct_ir_fgraph({Z_rv: z_vv, X_rv: x_vv})
112112

113113
valued_var_out_node = fgraph.outputs[0].owner
114114
# We should not expect the convolution to be applied; instead, the
@@ -136,7 +136,7 @@ def test_normal_add_three_inputs():
136136
Z_rv.name = "Z"
137137
z_vv = Z_rv.clone()
138138

139-
fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv})
139+
fgraph, *_ = construct_ir_fgraph({Z_rv: z_vv})
140140

141141
valued_var_out_node = fgraph.outputs[0].owner
142142
# The convolution should be applied, and not the transform

tests/test_mixture.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def test_switch_mixture():
685685
z_vv = Z1_rv.clone()
686686
z_vv.name = "z1"
687687

688-
fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
688+
fgraph, *_ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
689689

690690
out_rv = fgraph.outputs[0].owner.inputs[0]
691691
assert isinstance(out_rv.owner.op, MixtureRV)
@@ -696,7 +696,7 @@ def test_switch_mixture():
696696

697697
Z1_rv.name = "Z1"
698698

699-
fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
699+
fgraph, *_ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
700700

701701
out_rv = fgraph.outputs[0].owner.inputs[0]
702702
assert out_rv.name == "Z1-mixture"
@@ -705,7 +705,7 @@ def test_switch_mixture():
705705

706706
Z2_rv = at.stack((X_rv, Y_rv))[I_rv]
707707

708-
fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})
708+
fgraph2, *_ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})
709709

710710
assert equal_computations(fgraph.outputs, fgraph2.outputs)
711711

0 commit comments

Comments
 (0)