34
34
import utility_functions
35
35
36
36
import tsdate
37
- from tsdate .base import LIN , LOG
38
37
from tsdate .demography import PopulationSizeHistory
39
38
from tsdate .evaluation import remove_edges , unsupported_edges
39
+ from tsdate .node_time_class import LIN_GRID , LOG_GRID
40
40
41
41
42
42
class TestConstants :
@@ -129,7 +129,7 @@ def test_variational_gamma_unary_failure(self):
129
129
with pytest .raises (ValueError , match = "unary" ):
130
130
tsdate .variational_gamma (ts , mutation_rate = 1 )
131
131
132
- @pytest .mark .parametrize ("probability_space" , [LOG , LIN ])
132
+ @pytest .mark .parametrize ("probability_space" , [LOG_GRID , LIN_GRID ])
133
133
@pytest .mark .parametrize ("mu" , [None , 1 ])
134
134
def test_fails_with_recombination (self , probability_space , mu ):
135
135
ts = utility_functions .two_tree_mutation_ts ()
@@ -152,51 +152,60 @@ def test_default_alternative_time_units(self):
152
152
ts = tsdate .date (ts , mutation_rate = 1 , time_units = "years" )
153
153
assert ts .time_units == "years"
154
154
155
+ def test_deprecated_return_posteriors (self ):
156
+ ts = utility_functions .two_tree_mutation_ts ()
157
+ with pytest .raises (ValueError , match = "deprecated" ):
158
+ tsdate .date (ts , return_posteriors = True , mutation_rate = 1 )
159
+
155
160
def test_no_posteriors (self ):
156
161
ts = utility_functions .two_tree_mutation_ts ()
157
- with pytest . raises ( ValueError , match = "Cannot return posterior" ):
158
- tsdate . date (
159
- ts ,
160
- population_size = 1 ,
161
- return_posteriors = True ,
162
- method = "maximization" ,
163
- mutation_rate = 1 ,
164
- )
162
+ _ , model = tsdate . date (
163
+ ts ,
164
+ population_size = 1 ,
165
+ return_model = True ,
166
+ method = "maximization" ,
167
+ mutation_rate = 1 ,
168
+ )
169
+ assert model . node_posteriors () is None
165
170
166
171
def test_discretised_posteriors (self ):
167
172
ts = utility_functions .two_tree_mutation_ts ()
168
- ts , posteriors = tsdate .inside_outside (
169
- ts , mutation_rate = 1 , population_size = 1 , return_posteriors = True
173
+ ts , model = tsdate .inside_outside (
174
+ ts , mutation_rate = 1 , population_size = 1 , return_model = True
170
175
)
171
- assert len (posteriors ) == ts .num_nodes - ts .num_samples + 1
172
- assert len (posteriors ["time" ]) > 0
176
+ posteriors = model .node_posteriors ()
177
+ assert len (posteriors ) == ts .num_nodes
178
+ assert len (posteriors [0 ]) > 0
173
179
for node in ts .nodes ():
174
- if not node .is_sample ():
175
- assert node .id in posteriors
176
- assert len (posteriors [node .id ]) == len (posteriors ["time" ])
177
- assert np .isclose (np .sum (posteriors [node .id ]), 1 )
180
+ nd_vals = np .array (list (posteriors [node .id ]))
181
+ if node .is_sample ():
182
+ assert np .all (np .isnan (nd_vals ))
183
+ else :
184
+ assert np .isclose (np .sum (nd_vals ), 1 )
178
185
179
186
def test_variational_posteriors (self ):
180
- """
181
- There are no time-gridded posteriors returned by variational gamma,
182
- Output is currently None, but see https://github.com/tskit-dev/tsdate/issues/388
183
- """
184
187
ts = utility_functions .two_tree_mutation_ts ()
185
- ts , posteriors = tsdate .date (
188
+ ts , model = tsdate .date (
186
189
ts ,
187
190
mutation_rate = 1e-2 ,
188
191
method = "variational_gamma" ,
189
- return_posteriors = True ,
192
+ return_model = True ,
190
193
)
191
- assert posteriors is None
194
+ posteriors = model .node_posteriors ()
195
+ for nd in ts .nodes ():
196
+ mn , vr = posteriors [nd .id ]
197
+ assert posteriors ["mean" ][nd .id ] == mn
198
+ assert posteriors ["variance" ][nd .id ] == vr
199
+ assert np .isclose (nd .metadata ["mn" ], mn )
200
+ assert np .isclose (nd .metadata ["vr" ], vr )
192
201
193
202
def test_marginal_likelihood (self ):
194
203
ts = utility_functions .two_tree_mutation_ts ()
195
204
_ , _ , marg_lik = tsdate .inside_outside (
196
205
ts ,
197
206
mutation_rate = 1 ,
198
207
population_size = 1 ,
199
- return_posteriors = True ,
208
+ return_model = True ,
200
209
return_likelihood = True ,
201
210
)
202
211
_ , marg_lik_again = tsdate .inside_outside (
@@ -290,13 +299,13 @@ def test_linear_space(self):
290
299
ts , population_size = 10000 , timepoints = 10 , approximate_priors = None
291
300
)
292
301
dated_ts = tsdate .inside_outside (
293
- ts , mutation_rate = 1e-8 , priors = priors , probability_space = LIN
302
+ ts , mutation_rate = 1e-8 , priors = priors , probability_space = LIN_GRID
294
303
)
295
304
maximized_ts = tsdate .maximization (
296
305
ts ,
297
306
mutation_rate = 1e-8 ,
298
307
priors = priors ,
299
- probability_space = LIN ,
308
+ probability_space = LIN_GRID ,
300
309
)
301
310
self .ts_equal_except_times (ts , dated_ts )
302
311
self .ts_equal_except_times (ts , maximized_ts )
0 commit comments