diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b2401fc2..f41b15250 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,10 @@ ## [1.3.5] - 2025-XX-XX +**New features** + +- Demography objects can now be created from provenance entries ({pr}`{2369}, {user}`hyanwong`) + **Breaking changes**: - The `.asdict()` methods for Demography, Population, and Event classes in the diff --git a/msprime/provenance.py b/msprime/provenance.py index 4f99970fd..9c10cb6ce 100644 --- a/msprime/provenance.py +++ b/msprime/provenance.py @@ -31,6 +31,7 @@ import tskit from . import ancestry +from . import demography from msprime import _msprime __version__ = "undefined" @@ -188,10 +189,21 @@ def hook(obj): elif "__npgeneric__" in obj: return numpy.array([obj["__npgeneric__"]]).astype(obj["dtype"])[0] elif "__class__" in obj: - module, cls = obj["__class__"].rsplit(".", 1) + module, cls = obj.pop("__class__").rsplit(".", 1) module = importlib.import_module(module) - del obj["__class__"] - return getattr(module, cls)(**obj) + cls = getattr(module, cls) + if cls == demography.Demography: + # the Demography class normally generates its own population IDs, but + # allow fixed IDs here to ensure they match e.g. IDs in event objects + demography_object = object.__new__(cls) + try: + demography_object.__init__(**obj) + except ValueError as e: + if str(e).startswith("Population ID should not be set"): + return demography_object # This is the allowed fixed ID err + else: + raise + return cls(**obj) return obj return json.JSONDecoder(object_hook=hook).decode(s) diff --git a/tests/test_provenance.py b/tests/test_provenance.py index 5eba25b85..b5a5ed9a0 100644 --- a/tests/test_provenance.py +++ b/tests/test_provenance.py @@ -244,6 +244,39 @@ def test_current_ts(self): assert command == "sim_mutations" assert prov["tree_sequence"] == ts1 + def test_demography(self): + demography = msprime.Demography.island_model([1, 1], 1 / 3) + ts = msprime.sim_ancestry( + demography=demography, + samples=[ + msprime.SampleSet(1, population=0), + msprime.SampleSet(1, population=1), + ], + random_seed=3, + ) + command, prov = msprime.provenance.parse_provenance(ts.provenance(-1), ts) + assert command == "sim_ancestry" + assert prov["demography"] == demography + + def test_bad_demography(self): + demography = msprime.Demography.island_model([1, 1], 1 / 3) + ts = msprime.sim_ancestry( + demography=demography, + samples=[ + msprime.SampleSet(1, population=0), + msprime.SampleSet(1, population=1), + ], + random_seed=3, + ) + prov = ts.provenance(-1) + record = json.loads(prov.record) + # Corrupt the provenance record + assert len(record["parameters"]["demography"]["migration_matrix"]) == 2 + record["parameters"]["demography"]["migration_matrix"] = [1, 0, 0] + prov.record = json.dumps(record) + with pytest.raises(ValueError, match="Migration matrix must be square"): + msprime.provenance.parse_provenance(prov, ts) + class TestRoundTrip: """