Skip to content

Commit e2a2e32

Browse files
committed
Rework file struture to allow returning method
1 parent 573483f commit e2a2e32

File tree

8 files changed

+339
-293
lines changed

8 files changed

+339
-293
lines changed

tests/test_accuracy.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,7 @@ def test_basic(
108108
assert sim_mutations_parameters["command"] == "sim_mutations"
109109
mu = sim_mutations_parameters["rate"]
110110

111-
dts, posteriors = tsdate.inside_outside(
112-
ts, population_size=Ne, mutation_rate=mu, return_posteriors=True
113-
)
111+
dts = tsdate.inside_outside(ts, population_size=Ne, mutation_rate=mu)
114112
# make sure we can read node metadata - old tsdate versions didn't set a schema
115113
if dts.table_metadata_schemas.node.schema is None:
116114
tables = dts.dump_tables()

tests/test_functions.py

Lines changed: 70 additions & 77 deletions
Large diffs are not rendered by default.

tests/test_inference.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
import utility_functions
3535

3636
import tsdate
37-
from tsdate.base import LIN, LOG
3837
from tsdate.demography import PopulationSizeHistory
3938
from tsdate.evaluation import remove_edges, unsupported_edges
39+
from tsdate.node_time_class import LIN_GRID, LOG_GRID
4040

4141

4242
class TestConstants:
@@ -129,7 +129,7 @@ def test_variational_gamma_unary_failure(self):
129129
with pytest.raises(ValueError, match="unary"):
130130
tsdate.variational_gamma(ts, mutation_rate=1)
131131

132-
@pytest.mark.parametrize("probability_space", [LOG, LIN])
132+
@pytest.mark.parametrize("probability_space", [LOG_GRID, LIN_GRID])
133133
@pytest.mark.parametrize("mu", [None, 1])
134134
def test_fails_with_recombination(self, probability_space, mu):
135135
ts = utility_functions.two_tree_mutation_ts()
@@ -152,51 +152,60 @@ def test_default_alternative_time_units(self):
152152
ts = tsdate.date(ts, mutation_rate=1, time_units="years")
153153
assert ts.time_units == "years"
154154

155+
def test_deprecated_return_posteriors(self):
156+
ts = utility_functions.two_tree_mutation_ts()
157+
with pytest.raises(ValueError, match="deprecated"):
158+
tsdate.date(ts, return_posteriors=True, mutation_rate=1)
159+
155160
def test_no_posteriors(self):
156161
ts = utility_functions.two_tree_mutation_ts()
157-
with pytest.raises(ValueError, match="Cannot return posterior"):
158-
tsdate.date(
159-
ts,
160-
population_size=1,
161-
return_posteriors=True,
162-
method="maximization",
163-
mutation_rate=1,
164-
)
162+
_, model = tsdate.date(
163+
ts,
164+
population_size=1,
165+
return_model=True,
166+
method="maximization",
167+
mutation_rate=1,
168+
)
169+
assert model.node_posteriors() is None
165170

166171
def test_discretised_posteriors(self):
167172
ts = utility_functions.two_tree_mutation_ts()
168-
ts, posteriors = tsdate.inside_outside(
169-
ts, mutation_rate=1, population_size=1, return_posteriors=True
173+
ts, model = tsdate.inside_outside(
174+
ts, mutation_rate=1, population_size=1, return_model=True
170175
)
171-
assert len(posteriors) == ts.num_nodes - ts.num_samples + 1
172-
assert len(posteriors["time"]) > 0
176+
posteriors = model.node_posteriors()
177+
assert len(posteriors) == ts.num_nodes
178+
assert len(posteriors[0]) > 0
173179
for node in ts.nodes():
174-
if not node.is_sample():
175-
assert node.id in posteriors
176-
assert len(posteriors[node.id]) == len(posteriors["time"])
177-
assert np.isclose(np.sum(posteriors[node.id]), 1)
180+
nd_vals = np.array(list(posteriors[node.id]))
181+
if node.is_sample():
182+
assert np.all(np.isnan(nd_vals))
183+
else:
184+
assert np.isclose(np.sum(nd_vals), 1)
178185

179186
def test_variational_posteriors(self):
180-
"""
181-
There are no time-gridded posteriors returned by variational gamma,
182-
Output is currently None, but see https://github.com/tskit-dev/tsdate/issues/388
183-
"""
184187
ts = utility_functions.two_tree_mutation_ts()
185-
ts, posteriors = tsdate.date(
188+
ts, model = tsdate.date(
186189
ts,
187190
mutation_rate=1e-2,
188191
method="variational_gamma",
189-
return_posteriors=True,
192+
return_model=True,
190193
)
191-
assert posteriors is None
194+
posteriors = model.node_posteriors()
195+
for nd in ts.nodes():
196+
mn, vr = posteriors[nd.id]
197+
assert posteriors["mean"][nd.id] == mn
198+
assert posteriors["variance"][nd.id] == vr
199+
assert np.isclose(nd.metadata["mn"], mn)
200+
assert np.isclose(nd.metadata["vr"], vr)
192201

193202
def test_marginal_likelihood(self):
194203
ts = utility_functions.two_tree_mutation_ts()
195204
_, _, marg_lik = tsdate.inside_outside(
196205
ts,
197206
mutation_rate=1,
198207
population_size=1,
199-
return_posteriors=True,
208+
return_model=True,
200209
return_likelihood=True,
201210
)
202211
_, marg_lik_again = tsdate.inside_outside(
@@ -290,13 +299,13 @@ def test_linear_space(self):
290299
ts, population_size=10000, timepoints=10, approximate_priors=None
291300
)
292301
dated_ts = tsdate.inside_outside(
293-
ts, mutation_rate=1e-8, priors=priors, probability_space=LIN
302+
ts, mutation_rate=1e-8, priors=priors, probability_space=LIN_GRID
294303
)
295304
maximized_ts = tsdate.maximization(
296305
ts,
297306
mutation_rate=1e-8,
298307
priors=priors,
299-
probability_space=LIN,
308+
probability_space=LIN_GRID,
300309
)
301310
self.ts_equal_except_times(ts, dated_ts)
302311
self.ts_equal_except_times(ts, maximized_ts)

0 commit comments

Comments
 (0)