diff --git a/pyro/contrib/easyguide/easyguide.py b/pyro/contrib/easyguide/easyguide.py index 714615637e..6fced7dc58 100644 --- a/pyro/contrib/easyguide/easyguide.py +++ b/pyro/contrib/easyguide/easyguide.py @@ -230,7 +230,7 @@ def __init__(self, guide, sites): self.event_shape = torch.Size([sum(self._site_sizes.values())]) def __getstate__(self): - state = getattr(super(), "__getstate__", self.__dict__.copy)() + state = getattr(super(), "__getstate__", lambda: self.__dict__)().copy() state["_guide"] = state["_guide"]() # weakref -> ref return state diff --git a/tests/contrib/easyguide/test_easyguide.py b/tests/contrib/easyguide/test_easyguide.py index 4d15db167e..b4ee78d6fb 100644 --- a/tests/contrib/easyguide/test_easyguide.py +++ b/tests/contrib/easyguide/test_easyguide.py @@ -79,7 +79,6 @@ def guide(self, batch, subsample, full_size): self.group(match="state_[0-9]*").map_estimate() -@pytest.mark.xfail(reason="https://github.com/pyro-ppl/pyro/issues/3430") def test_serialize(): guide = PickleGuide(model) check_guide(guide)