Skip to content

Commit 92791df

Browse files
committed
passing test
1 parent b42d20a commit 92791df

File tree

3 files changed

+57
-25
lines changed

3 files changed

+57
-25
lines changed

pyro/nn/module.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -233,29 +233,6 @@ def __get__(
233233
return value
234234

235235

236-
class PyroSamplePlateScope(pyro.poutine.messenger.Messenger):
237-
"""
238-
Handler for executing PyroSample statements in a more intuitive plate context.
239-
"""
240-
def __init__(self, allowed_plates: Iterable[str] = ()):
241-
self._inner_allowed_plates = frozenset(allowed_plates)
242-
243-
def __enter__(self):
244-
self._plates: frozenset[str] = frozenset(p.name for p in pyro.poutine.runtime.get_plates()) | self._inner_allowed_plates
245-
return super().__enter__()
246-
247-
def _is_local_plate(self, m: pyro.poutine.messenger.Messenger) -> bool:
248-
return isinstance(m, pyro.poutine.plate_messenger.PlateMessenger) and m.name not in self._plates
249-
250-
def _pyro_sample(self, msg):
251-
if not msg["infer"].get("_is_global_sample", False):
252-
return
253-
msg["stop"] = True
254-
msg["done"] = True
255-
with pyro.poutine.messenger.block_messenger(lambda m: m is self or self._is_local_plate(m)):
256-
msg["value"] = pyro.sample(msg["name"], msg["fn"], obs=msg["value"] if msg["is_observed"] else None, infer=msg["infer"])
257-
258-
259236
def _make_name(prefix: str, name: str) -> str:
260237
return "{}.{}".format(prefix, name) if prefix else name
261238

@@ -639,7 +616,7 @@ def __getattr__(self, name: str) -> Any:
639616
value = (
640617
pyro.deterministic(fullname, prior)
641618
if isinstance(prior, torch.Tensor)
642-
else pyro.sample(fullname, prior, infer={"_is_global_sample": True})
619+
else pyro.sample(fullname, prior, infer={"_original_pyrosample_dist": prior})
643620
)
644621
context.set(fullname, value)
645622
return value

pyro/poutine/plate_messenger.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from contextlib import contextmanager
5-
from typing import TYPE_CHECKING, Iterator, Optional
5+
from typing import TYPE_CHECKING, Iterable, Iterator, Optional
66

7+
import pyro
78
from pyro.poutine.broadcast_messenger import BroadcastMessenger
89
from pyro.poutine.messenger import Messenger, block_messengers
910
from pyro.poutine.subsample_messenger import SubsampleMessenger
@@ -88,3 +89,27 @@ def predicate(messenger: Messenger) -> bool:
8889
"setting strict=False."
8990
)
9091
yield
92+
93+
94+
class PyroSamplePlateScope(Messenger):
95+
"""
96+
Handler for executing PyroSample statements in a more intuitive plate context.
97+
"""
98+
def __init__(self, allowed_plates: Iterable[str] = ()):
99+
self._inner_allowed_plates = frozenset(allowed_plates)
100+
101+
def __enter__(self):
102+
self._plates: frozenset[str] = frozenset(p.name for p in pyro.poutine.runtime.get_plates()) | self._inner_allowed_plates
103+
return super().__enter__()
104+
105+
def _is_local_plate(self, m: Messenger) -> bool:
106+
return isinstance(m, PlateMessenger) and m.name not in self._plates
107+
108+
def _pyro_sample(self, msg):
109+
if not msg["infer"].get("_original_pyrosample_dist", None):
110+
return
111+
msg["stop"] = True
112+
msg["done"] = True
113+
with block_messengers(lambda m: m is self or self._is_local_plate(m)):
114+
d = msg["infer"].pop("_original_pyrosample_dist")
115+
msg["value"] = pyro.sample(msg["name"], d, obs=msg["value"] if msg["is_observed"] else None, infer=msg["infer"])

tests/nn/test_module.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,3 +1084,33 @@ def forward(self):
10841084
with pyro.settings.context(module_local_params=use_module_local_params):
10851085
model = Model()
10861086
pyro.render_model(model)
1087+
1088+
1089+
def test_pyrosample_platescope():
1090+
1091+
class Model(pyro.nn.PyroModule):
1092+
def __init__(self, num_inputs, num_outputs):
1093+
super().__init__()
1094+
self.num_inputs = num_inputs
1095+
self.num_outputs = num_outputs
1096+
self.linear = pyro.nn.PyroModule[torch.nn.Linear](num_inputs, num_outputs)
1097+
self.linear.weight = pyro.nn.PyroSample(dist.Normal(0, 1).expand([num_outputs, num_inputs]).to_event(2))
1098+
self.linear.bias = pyro.nn.PyroSample(dist.Normal(0, 1).expand([num_outputs]).to_event(1))
1099+
1100+
@pyro.nn.PyroSample
1101+
def scale(self):
1102+
return pyro.distributions.LogNormal(0, 1).expand([self.num_outputs]).to_event(1)
1103+
1104+
@pyro.poutine.plate_messenger.PyroSamplePlateScope()
1105+
def forward(self, x):
1106+
with pyro.plate("data", x.shape[-2], dim=-1):
1107+
assert len(self.linear.weight.shape) == 2 or self.linear.weight.shape[-3] != 1 # sampled outside data plate
1108+
loc = self.linear(x)
1109+
assert len(self.scale.shape) == 1 or self.scale.shape[-2] == 1 # sampled outside data plate
1110+
y = pyro.sample("y", dist.Normal(loc, self.scale).to_event(1))
1111+
assert y.shape[-2] == x.shape[-2] # ordinary pyro.sample statement
1112+
return y
1113+
1114+
model = Model(3, 2)
1115+
x = torch.randn(4, 3)
1116+
assert model(x).shape == (4, 2)

0 commit comments

Comments
 (0)