31
31
import pytest
32
32
import tsinfer
33
33
import tskit
34
- import utility_functions
34
+ import utility_functions as util
35
35
36
36
import tsdate
37
37
from tsdate .demography import PopulationSizeHistory
@@ -50,24 +50,24 @@ class TestPrebuilt:
50
50
"""
51
51
52
52
def test_invalid_method_failure (self ):
53
- ts = utility_functions .two_tree_mutation_ts ()
53
+ ts = util .two_tree_mutation_ts ()
54
54
with pytest .raises (ValueError , match = "method must be one of" ):
55
55
tsdate .date (ts , population_size = 1 , mutation_rate = None , method = "foo" )
56
56
57
57
def test_no_mutations_failure (self ):
58
- ts = utility_functions .single_tree_ts_n2 ()
58
+ ts = util .single_tree_ts_n2 ()
59
59
with pytest .raises (ValueError , match = "No mutations present" ):
60
60
tsdate .variational_gamma (ts , mutation_rate = 1 )
61
61
62
62
def test_no_population_size (self ):
63
- ts = utility_functions .two_tree_mutation_ts ()
63
+ ts = util .two_tree_mutation_ts ()
64
64
with pytest .raises (ValueError , match = "Must specify population size" ):
65
65
tsdate .inside_outside (ts , mutation_rate = None )
66
66
67
67
def test_no_mutation (self ):
68
68
for ts in (
69
- utility_functions .two_tree_mutation_ts (),
70
- utility_functions .single_tree_ts_mutation_n3 (),
69
+ util .two_tree_mutation_ts (),
70
+ util .single_tree_ts_mutation_n3 (),
71
71
):
72
72
with pytest .raises (ValueError , match = "method requires mutation rate" ):
73
73
tsdate .date (
@@ -86,53 +86,59 @@ def test_no_mutation(self):
86
86
)
87
87
88
88
def test_not_needed_population_size (self ):
89
- ts = utility_functions .two_tree_mutation_ts ()
89
+ ts = util .two_tree_mutation_ts ()
90
90
prior = tsdate .build_prior_grid (ts , population_size = 1 , timepoints = 10 )
91
91
with pytest .raises (ValueError , match = "Cannot specify population size" ):
92
92
tsdate .inside_outside (ts , population_size = 1 , mutation_rate = None , priors = prior )
93
93
94
94
def test_bad_population_size (self ):
95
- ts = utility_functions .two_tree_mutation_ts ()
95
+ ts = util .two_tree_mutation_ts ()
96
96
for Ne in [0 , - 1 ]:
97
97
with pytest .raises (ValueError , match = "greater than 0" ):
98
98
tsdate .inside_outside (ts , mutation_rate = None , population_size = Ne )
99
99
100
100
def test_both_ne_and_population_size_specified (self ):
101
- ts = utility_functions .two_tree_mutation_ts ()
101
+ ts = util .two_tree_mutation_ts ()
102
102
with pytest .raises (ValueError , match = "Only provide one of Ne" ):
103
103
tsdate .inside_outside (
104
104
ts , mutation_rate = 1 , population_size = PopulationSizeHistory (1 ), Ne = 1
105
105
)
106
106
tsdate .inside_outside (ts , mutation_rate = 1 , Ne = PopulationSizeHistory (1 ))
107
107
108
108
def test_inside_outside_dangling_failure (self ):
109
- ts = utility_functions .single_tree_ts_n2_dangling ()
109
+ ts = util .single_tree_ts_n2_dangling ()
110
110
with pytest .raises (ValueError , match = "simplified" ):
111
111
tsdate .inside_outside (ts , mutation_rate = None , population_size = 1 )
112
112
113
113
def test_variational_gamma_dangling (self ):
114
114
# Dangling nodes are fine for the variational gamma method
115
- ts = utility_functions .single_tree_ts_n2_dangling ()
115
+ ts = util .single_tree_ts_n2_dangling ()
116
116
ts = msprime .sim_mutations (ts , rate = 2 , random_seed = 1 )
117
117
assert ts .num_mutations > 1
118
118
tsdate .variational_gamma (ts , mutation_rate = 2 )
119
119
120
120
def test_inside_outside_unary_failure (self ):
121
- ts = utility_functions .single_tree_ts_with_unary ()
121
+ ts = util .single_tree_ts_with_unary ()
122
122
with pytest .raises (ValueError , match = "unary" ):
123
123
tsdate .inside_outside (ts , mutation_rate = None , population_size = 1 )
124
124
125
- @pytest .mark .skip ("V_gamma should fail with unary nodes, but doesn't currently" )
126
- def test_variational_gamma_unary_failure (self ):
127
- ts = utility_functions .single_tree_ts_with_unary ()
125
+ @pytest .mark .parametrize ("method" , tsdate .estimation_methods .keys ())
126
+ @pytest .mark .parametrize (
127
+ "ts" , [util .single_tree_ts_with_unary (), util .two_tree_ts_with_unary_n3 ()]
128
+ )
129
+ def test_allow_unary (self , method , ts ):
130
+ Ne = None if method == "variational_gamma" else 1
128
131
ts = msprime .sim_mutations (ts , rate = 1 , random_seed = 1 )
129
132
with pytest .raises (ValueError , match = "unary" ):
130
- tsdate .variational_gamma (ts , mutation_rate = 1 )
133
+ tsdate .date (ts , method = method , population_size = Ne , mutation_rate = 1 )
134
+ tsdate .date (
135
+ ts , method = method , population_size = Ne , mutation_rate = 1 , allow_unary = True
136
+ )
131
137
132
138
@pytest .mark .parametrize ("probability_space" , [LOG_GRID , LIN_GRID ])
133
139
@pytest .mark .parametrize ("mu" , [None , 1 ])
134
140
def test_fails_with_recombination (self , probability_space , mu ):
135
- ts = utility_functions .two_tree_mutation_ts ()
141
+ ts = util .two_tree_mutation_ts ()
136
142
with pytest .raises (NotImplementedError ):
137
143
tsdate .inside_outside (
138
144
ts ,
@@ -143,27 +149,27 @@ def test_fails_with_recombination(self, probability_space, mu):
143
149
)
144
150
145
151
def test_default_time_units (self ):
146
- ts = utility_functions .two_tree_mutation_ts ()
152
+ ts = util .two_tree_mutation_ts ()
147
153
ts = tsdate .date (ts , mutation_rate = 1 )
148
154
assert ts .time_units == "generations"
149
155
150
156
def test_default_alternative_time_units (self ):
151
- ts = utility_functions .two_tree_mutation_ts ()
157
+ ts = util .two_tree_mutation_ts ()
152
158
ts = tsdate .date (ts , mutation_rate = 1 , time_units = "years" )
153
159
assert ts .time_units == "years"
154
160
155
161
def test_deprecated_return_posteriors (self ):
156
- ts = utility_functions .two_tree_mutation_ts ()
162
+ ts = util .two_tree_mutation_ts ()
157
163
with pytest .raises (ValueError , match = "deprecated" ):
158
164
tsdate .date (ts , return_posteriors = True , mutation_rate = 1 )
159
165
160
166
def test_return_fit (self ):
161
- ts = utility_functions .two_tree_mutation_ts ()
167
+ ts = util .two_tree_mutation_ts ()
162
168
_ , fit = tsdate .date (ts , return_fit = True , mutation_rate = 1 )
163
169
assert hasattr (fit , "node_posteriors" )
164
170
165
171
def test_no_maximization_posteriors (self ):
166
- ts = utility_functions .two_tree_mutation_ts ()
172
+ ts = util .two_tree_mutation_ts ()
167
173
_ , fit = tsdate .date (
168
174
ts ,
169
175
population_size = 1 ,
@@ -175,7 +181,7 @@ def test_no_maximization_posteriors(self):
175
181
fit .node_posteriors ()
176
182
177
183
def test_discretised_posteriors (self ):
178
- ts = utility_functions .two_tree_mutation_ts ()
184
+ ts = util .two_tree_mutation_ts ()
179
185
ts , fit = tsdate .inside_outside (
180
186
ts , mutation_rate = 1 , population_size = 1 , return_fit = True
181
187
)
@@ -190,7 +196,7 @@ def test_discretised_posteriors(self):
190
196
assert np .isclose (np .sum (nd_vals ), 1 )
191
197
192
198
def test_variational_node_posteriors (self ):
193
- ts = utility_functions .two_tree_mutation_ts ()
199
+ ts = util .two_tree_mutation_ts ()
194
200
ts , fit = tsdate .date (
195
201
ts ,
196
202
mutation_rate = 1e-2 ,
@@ -206,7 +212,7 @@ def test_variational_node_posteriors(self):
206
212
assert np .isclose (nd .metadata ["vr" ], vr )
207
213
208
214
def test_variational_mutation_posteriors (self ):
209
- ts = utility_functions .two_tree_mutation_ts ()
215
+ ts = util .two_tree_mutation_ts ()
210
216
ts , fit = tsdate .date (
211
217
ts ,
212
218
mutation_rate = 1e-2 ,
@@ -223,7 +229,7 @@ def test_variational_mutation_posteriors(self):
223
229
224
230
def test_variational_mean_edge_logconst (self ):
225
231
# This should give a guide to EP convergence
226
- ts = utility_functions .two_tree_mutation_ts ()
232
+ ts = util .two_tree_mutation_ts ()
227
233
ts , fit = tsdate .date (
228
234
ts ,
229
235
mutation_rate = 1e-2 ,
@@ -239,7 +245,7 @@ def test_variational_mean_edge_logconst(self):
239
245
assert np .all (obs [5 :] == test_vals [- 1 ])
240
246
241
247
def test_marginal_likelihood (self ):
242
- ts = utility_functions .two_tree_mutation_ts ()
248
+ ts = util .two_tree_mutation_ts ()
243
249
_ , _ , marg_lik = tsdate .inside_outside (
244
250
ts ,
245
251
mutation_rate = 1 ,
@@ -253,8 +259,8 @@ def test_marginal_likelihood(self):
253
259
assert marg_lik == marg_lik_again
254
260
255
261
def test_intervals (self ):
256
- ts = utility_functions .two_tree_ts ()
257
- long_ts = utility_functions .two_tree_ts_extra_length ()
262
+ ts = util .two_tree_ts ()
263
+ long_ts = util .two_tree_ts_extra_length ()
258
264
keep_ts = long_ts .keep_intervals ([[0.0 , 1.0 ]])
259
265
del_ts = long_ts .delete_intervals ([[1.0 , 1.5 ]])
260
266
dat_ts = tsdate .inside_outside (ts , mutation_rate = 1 , population_size = 1 )
@@ -414,9 +420,7 @@ def test_truncated_ts(self):
414
420
mutation_rate = mu ,
415
421
random_seed = 12 ,
416
422
)
417
- truncated_ts = utility_functions .truncate_ts_samples (
418
- ts , average_span = 200 , random_seed = 123
419
- )
423
+ truncated_ts = util .truncate_ts_samples (ts , average_span = 200 , random_seed = 123 )
420
424
dated_ts = tsdate .date (truncated_ts , population_size = Ne , mutation_rate = mu )
421
425
# We should ideally test whether *haplotypes* are the same here
422
426
# in case allele encoding has changed. But haplotypes() doesn't currently
0 commit comments