Skip to content

Commit c05676f

Browse files
authored
Merge pull request #437 from hyanwong/extras
Allow unary nodes
2 parents bd9b072 + 40d6231 commit c05676f

File tree

4 files changed

+60
-46
lines changed

4 files changed

+60
-46
lines changed

docs/usage.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ command-line interface. See {ref}`sec_cli` for more details.
319319

320320
### Numerical stability and preprocessing
321321

322-
Numerical stability issues witll manifest themselves by raising an error when dating.
322+
Numerical stability issues will manifest themselves by raising an error when dating.
323323
They are usually caused by "bad" tree sequences (i.e.
324324
pathological combinations of topologies and mutations). These can be caused,
325325
for example, by long deep branches with very few mutations, such as samples attaching directly
@@ -343,7 +343,5 @@ increase or decrease its stringency.
343343

344344
:::{note}
345345
If unary regions are *correctly* estimated, they can help improve dating slightly.
346-
There is therefore a specific route to date a tree sequence containing locally unary
347-
nodes. For example, for discrete time methods, you can use the `allow_unary` option
348-
when {ref}`building a prior<sec_priors>`.
346+
You can set the `allow_unary=True` option to run tsdate on such tree sequences.
349347
:::

tests/test_inference.py

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import pytest
3232
import tsinfer
3333
import tskit
34-
import utility_functions
34+
import utility_functions as util
3535

3636
import tsdate
3737
from tsdate.demography import PopulationSizeHistory
@@ -50,24 +50,24 @@ class TestPrebuilt:
5050
"""
5151

5252
def test_invalid_method_failure(self):
53-
ts = utility_functions.two_tree_mutation_ts()
53+
ts = util.two_tree_mutation_ts()
5454
with pytest.raises(ValueError, match="method must be one of"):
5555
tsdate.date(ts, population_size=1, mutation_rate=None, method="foo")
5656

5757
def test_no_mutations_failure(self):
58-
ts = utility_functions.single_tree_ts_n2()
58+
ts = util.single_tree_ts_n2()
5959
with pytest.raises(ValueError, match="No mutations present"):
6060
tsdate.variational_gamma(ts, mutation_rate=1)
6161

6262
def test_no_population_size(self):
63-
ts = utility_functions.two_tree_mutation_ts()
63+
ts = util.two_tree_mutation_ts()
6464
with pytest.raises(ValueError, match="Must specify population size"):
6565
tsdate.inside_outside(ts, mutation_rate=None)
6666

6767
def test_no_mutation(self):
6868
for ts in (
69-
utility_functions.two_tree_mutation_ts(),
70-
utility_functions.single_tree_ts_mutation_n3(),
69+
util.two_tree_mutation_ts(),
70+
util.single_tree_ts_mutation_n3(),
7171
):
7272
with pytest.raises(ValueError, match="method requires mutation rate"):
7373
tsdate.date(
@@ -86,53 +86,59 @@ def test_no_mutation(self):
8686
)
8787

8888
def test_not_needed_population_size(self):
89-
ts = utility_functions.two_tree_mutation_ts()
89+
ts = util.two_tree_mutation_ts()
9090
prior = tsdate.build_prior_grid(ts, population_size=1, timepoints=10)
9191
with pytest.raises(ValueError, match="Cannot specify population size"):
9292
tsdate.inside_outside(ts, population_size=1, mutation_rate=None, priors=prior)
9393

9494
def test_bad_population_size(self):
95-
ts = utility_functions.two_tree_mutation_ts()
95+
ts = util.two_tree_mutation_ts()
9696
for Ne in [0, -1]:
9797
with pytest.raises(ValueError, match="greater than 0"):
9898
tsdate.inside_outside(ts, mutation_rate=None, population_size=Ne)
9999

100100
def test_both_ne_and_population_size_specified(self):
101-
ts = utility_functions.two_tree_mutation_ts()
101+
ts = util.two_tree_mutation_ts()
102102
with pytest.raises(ValueError, match="Only provide one of Ne"):
103103
tsdate.inside_outside(
104104
ts, mutation_rate=1, population_size=PopulationSizeHistory(1), Ne=1
105105
)
106106
tsdate.inside_outside(ts, mutation_rate=1, Ne=PopulationSizeHistory(1))
107107

108108
def test_inside_outside_dangling_failure(self):
109-
ts = utility_functions.single_tree_ts_n2_dangling()
109+
ts = util.single_tree_ts_n2_dangling()
110110
with pytest.raises(ValueError, match="simplified"):
111111
tsdate.inside_outside(ts, mutation_rate=None, population_size=1)
112112

113113
def test_variational_gamma_dangling(self):
114114
# Dangling nodes are fine for the variational gamma method
115-
ts = utility_functions.single_tree_ts_n2_dangling()
115+
ts = util.single_tree_ts_n2_dangling()
116116
ts = msprime.sim_mutations(ts, rate=2, random_seed=1)
117117
assert ts.num_mutations > 1
118118
tsdate.variational_gamma(ts, mutation_rate=2)
119119

120120
def test_inside_outside_unary_failure(self):
121-
ts = utility_functions.single_tree_ts_with_unary()
121+
ts = util.single_tree_ts_with_unary()
122122
with pytest.raises(ValueError, match="unary"):
123123
tsdate.inside_outside(ts, mutation_rate=None, population_size=1)
124124

125-
@pytest.mark.skip("V_gamma should fail with unary nodes, but doesn't currently")
126-
def test_variational_gamma_unary_failure(self):
127-
ts = utility_functions.single_tree_ts_with_unary()
125+
@pytest.mark.parametrize("method", tsdate.estimation_methods.keys())
126+
@pytest.mark.parametrize(
127+
"ts", [util.single_tree_ts_with_unary(), util.two_tree_ts_with_unary_n3()]
128+
)
129+
def test_allow_unary(self, method, ts):
130+
Ne = None if method == "variational_gamma" else 1
128131
ts = msprime.sim_mutations(ts, rate=1, random_seed=1)
129132
with pytest.raises(ValueError, match="unary"):
130-
tsdate.variational_gamma(ts, mutation_rate=1)
133+
tsdate.date(ts, method=method, population_size=Ne, mutation_rate=1)
134+
tsdate.date(
135+
ts, method=method, population_size=Ne, mutation_rate=1, allow_unary=True
136+
)
131137

132138
@pytest.mark.parametrize("probability_space", [LOG_GRID, LIN_GRID])
133139
@pytest.mark.parametrize("mu", [None, 1])
134140
def test_fails_with_recombination(self, probability_space, mu):
135-
ts = utility_functions.two_tree_mutation_ts()
141+
ts = util.two_tree_mutation_ts()
136142
with pytest.raises(NotImplementedError):
137143
tsdate.inside_outside(
138144
ts,
@@ -143,27 +149,27 @@ def test_fails_with_recombination(self, probability_space, mu):
143149
)
144150

145151
def test_default_time_units(self):
146-
ts = utility_functions.two_tree_mutation_ts()
152+
ts = util.two_tree_mutation_ts()
147153
ts = tsdate.date(ts, mutation_rate=1)
148154
assert ts.time_units == "generations"
149155

150156
def test_default_alternative_time_units(self):
151-
ts = utility_functions.two_tree_mutation_ts()
157+
ts = util.two_tree_mutation_ts()
152158
ts = tsdate.date(ts, mutation_rate=1, time_units="years")
153159
assert ts.time_units == "years"
154160

155161
def test_deprecated_return_posteriors(self):
156-
ts = utility_functions.two_tree_mutation_ts()
162+
ts = util.two_tree_mutation_ts()
157163
with pytest.raises(ValueError, match="deprecated"):
158164
tsdate.date(ts, return_posteriors=True, mutation_rate=1)
159165

160166
def test_return_fit(self):
161-
ts = utility_functions.two_tree_mutation_ts()
167+
ts = util.two_tree_mutation_ts()
162168
_, fit = tsdate.date(ts, return_fit=True, mutation_rate=1)
163169
assert hasattr(fit, "node_posteriors")
164170

165171
def test_no_maximization_posteriors(self):
166-
ts = utility_functions.two_tree_mutation_ts()
172+
ts = util.two_tree_mutation_ts()
167173
_, fit = tsdate.date(
168174
ts,
169175
population_size=1,
@@ -175,7 +181,7 @@ def test_no_maximization_posteriors(self):
175181
fit.node_posteriors()
176182

177183
def test_discretised_posteriors(self):
178-
ts = utility_functions.two_tree_mutation_ts()
184+
ts = util.two_tree_mutation_ts()
179185
ts, fit = tsdate.inside_outside(
180186
ts, mutation_rate=1, population_size=1, return_fit=True
181187
)
@@ -190,7 +196,7 @@ def test_discretised_posteriors(self):
190196
assert np.isclose(np.sum(nd_vals), 1)
191197

192198
def test_variational_node_posteriors(self):
193-
ts = utility_functions.two_tree_mutation_ts()
199+
ts = util.two_tree_mutation_ts()
194200
ts, fit = tsdate.date(
195201
ts,
196202
mutation_rate=1e-2,
@@ -206,7 +212,7 @@ def test_variational_node_posteriors(self):
206212
assert np.isclose(nd.metadata["vr"], vr)
207213

208214
def test_variational_mutation_posteriors(self):
209-
ts = utility_functions.two_tree_mutation_ts()
215+
ts = util.two_tree_mutation_ts()
210216
ts, fit = tsdate.date(
211217
ts,
212218
mutation_rate=1e-2,
@@ -223,7 +229,7 @@ def test_variational_mutation_posteriors(self):
223229

224230
def test_variational_mean_edge_logconst(self):
225231
# This should give a guide to EP convergence
226-
ts = utility_functions.two_tree_mutation_ts()
232+
ts = util.two_tree_mutation_ts()
227233
ts, fit = tsdate.date(
228234
ts,
229235
mutation_rate=1e-2,
@@ -239,7 +245,7 @@ def test_variational_mean_edge_logconst(self):
239245
assert np.all(obs[5:] == test_vals[-1])
240246

241247
def test_marginal_likelihood(self):
242-
ts = utility_functions.two_tree_mutation_ts()
248+
ts = util.two_tree_mutation_ts()
243249
_, _, marg_lik = tsdate.inside_outside(
244250
ts,
245251
mutation_rate=1,
@@ -253,8 +259,8 @@ def test_marginal_likelihood(self):
253259
assert marg_lik == marg_lik_again
254260

255261
def test_intervals(self):
256-
ts = utility_functions.two_tree_ts()
257-
long_ts = utility_functions.two_tree_ts_extra_length()
262+
ts = util.two_tree_ts()
263+
long_ts = util.two_tree_ts_extra_length()
258264
keep_ts = long_ts.keep_intervals([[0.0, 1.0]])
259265
del_ts = long_ts.delete_intervals([[1.0, 1.5]])
260266
dat_ts = tsdate.inside_outside(ts, mutation_rate=1, population_size=1)
@@ -414,9 +420,7 @@ def test_truncated_ts(self):
414420
mutation_rate=mu,
415421
random_seed=12,
416422
)
417-
truncated_ts = utility_functions.truncate_ts_samples(
418-
ts, average_span=200, random_seed=123
419-
)
423+
truncated_ts = util.truncate_ts_samples(ts, average_span=200, random_seed=123)
420424
dated_ts = tsdate.date(truncated_ts, population_size=Ne, mutation_rate=mu)
421425
# We should ideally test whether *haplotypes* are the same here
422426
# in case allele encoding has changed. But haplotypes() doesn't currently

tsdate/core.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(
8484
priors=None,
8585
return_likelihood=None,
8686
return_fit=None,
87+
allow_unary=None,
8788
record_provenance=None,
8889
constr_iterations=None,
8990
progress=None,
@@ -140,6 +141,8 @@ def __init__(
140141
)
141142
self.constr_iterations = constr_iterations
142143

144+
self.allow_unary = False if allow_unary is None else allow_unary
145+
143146
if self.prior_grid_func_name is None:
144147
if priors is not None:
145148
raise ValueError(f"Priors are not used for method {self.name}")
@@ -157,7 +160,11 @@ def __init__(
157160
# greater than DEFAULT_APPROX_PRIOR_SIZE samples
158161
approx = ts.num_samples > prior.DEFAULT_APPROX_PRIOR_SIZE
159162
self.priors = mk_prior(
160-
ts, Ne, approximate_priors=approx, progress=progress
163+
ts,
164+
Ne,
165+
approximate_priors=approx,
166+
allow_unary=self.allow_unary,
167+
progress=progress,
161168
)
162169
else:
163170
logger.info("Using user-specified priors")
@@ -444,6 +451,7 @@ def run(
444451
fit_obj = variational.ExpectationPropagation(
445452
self.ts,
446453
mutation_rate=self.mutation_rate,
454+
allow_unary=self.allow_unary,
447455
singletons_phased=singletons_phased,
448456
)
449457
fit_obj.infer(
@@ -552,7 +560,7 @@ def maximization(
552560
"linear" space (fast, may overflow). Default: None treated as "logarithmic"
553561
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
554562
function, notably ``mutation_rate``, and ``population_size`` or ``priors``.
555-
Further arguments include ``time_units``, ``progress``, and
563+
Further arguments include ``time_units``, ``progress``, ``allow_unary`` and
556564
``record_provenance``. The additional arguments ``return_fit`` and
557565
``return_likelihood`` can be used to return additional information (see below).
558566
:return:
@@ -685,7 +693,7 @@ def inside_outside(
685693
"linear" space (fast, may overflow). Default: "logarithmic"
686694
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
687695
function, notably ``mutation_rate``, and ``population_size`` or ``priors``.
688-
Further arguments include ``time_units``, ``progress``, and
696+
Further arguments include ``time_units``, ``progress``, ``allow_unary`` and
689697
``record_provenance``. The additional arguments ``return_fit`` and
690698
``return_likelihood`` can be used to return additional information (see below).
691699
:return:
@@ -784,9 +792,9 @@ def variational_gamma(
784792
length are approximately equal, which gives unbiased estimates when there
785793
are polytomies. Default ``False``.
786794
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
787-
function, including ``time_units``, ``progress``, and ``record_provenance``.
788-
The arguments ``return_fit`` and ``return_likelihood`` can be
789-
used to return additional information (see below).
795+
function, including ``time_units``, ``progress``, ``allow_unary`` and
796+
``record_provenance``. The arguments ``return_fit`` and ``return_likelihood``
797+
can be used to return additional information (see below).
790798
:return:
791799
- **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with
792800
updated node times based on the posterior mean, corrected where necessary to
@@ -866,6 +874,7 @@ def date(
866874
constr_iterations=None,
867875
return_fit=None,
868876
return_likelihood=None,
877+
allow_unary=None,
869878
progress=None,
870879
record_provenance=True,
871880
# Other kwargs documented in the functions for each specific estimation-method
@@ -919,6 +928,8 @@ def date(
919928
from the inside algorithm in addition to the dated tree sequence. If
920929
``return_fit`` is also ``True``, then the marginal likelihood
921930
will be the last element of the tuple. Default: None, treated as False.
931+
:param bool allow_unary: Allow nodes that are "locally unary" (i.e. have only
932+
one child in one or more local trees). Default: None, treated as False.
922933
:param bool progress: Show a progress bar. Default: None, treated as False.
923934
:param bool record_provenance: Should the tsdate command be appended to the
924935
provenance information in the returned tree sequence?
@@ -947,6 +958,7 @@ def date(
947958
constr_iterations=constr_iterations,
948959
return_fit=return_fit,
949960
return_likelihood=return_likelihood,
961+
allow_unary=allow_unary,
950962
record_provenance=record_provenance,
951963
**kwargs,
952964
)

tsdate/variational.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,10 @@ def _check_valid_constraints(constraints, edges_parent, edges_child):
156156
)
157157

158158
@staticmethod
159-
def _check_valid_inputs(ts, mutation_rate):
159+
def _check_valid_inputs(ts, mutation_rate, allow_unary):
160160
if not mutation_rate > 0.0:
161161
raise ValueError("Mutation rate must be positive")
162-
if contains_unary_nodes(ts):
162+
if not allow_unary and contains_unary_nodes(ts):
163163
raise ValueError("Tree sequence contains unary nodes, simplify first")
164164

165165
@staticmethod
@@ -185,7 +185,7 @@ def _check_valid_state(
185185
posterior_check += node_factors[:, CONSTRNT]
186186
np.testing.assert_allclose(posterior_check, posterior)
187187

188-
def __init__(self, ts, *, mutation_rate, singletons_phased=True):
188+
def __init__(self, ts, *, mutation_rate, allow_unary=None, singletons_phased=True):
189189
"""
190190
Initialize an expectation propagation algorithm for dating nodes
191191
in a tree sequence.
@@ -202,7 +202,7 @@ def __init__(self, ts, *, mutation_rate, singletons_phased=True):
202202
time unit.
203203
"""
204204

205-
self._check_valid_inputs(ts, mutation_rate)
205+
self._check_valid_inputs(ts, mutation_rate, allow_unary)
206206
self.edge_parents = ts.edges_parent
207207
self.edge_children = ts.edges_child
208208

0 commit comments

Comments
 (0)