Skip to content

Commit f797a4a

Browse files
committed
Record resources used in tsdate inference
1 parent 2b3412a commit f797a4a

File tree

5 files changed

+42
-10
lines changed

5 files changed

+42
-10
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
tskit>=0.5.8
1+
tskit>=0.6.0
22
tsinfer>=0.3.0
33
ruff
44
numpy

tests/test_provenance.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,19 @@ def test_date_params_recorded(self):
6161
assert np.isclose(rec["parameters"]["population_size"], Ne)
6262
assert rec["parameters"]["command"] == "maximization"
6363

64+
def test_date_time_recorded(self):
65+
ts = utility_functions.single_tree_ts_n2()
66+
mu = 0.123
67+
Ne = 9
68+
dated_ts = tsdate.date(
69+
ts, population_size=Ne, mutation_rate=mu, method="maximization"
70+
)
71+
rec = json.loads(dated_ts.provenance(-1).record)
72+
assert "resources" in rec
73+
assert rec["resources"]["elapsed_time"] >= 0
74+
assert rec["resources"]["user_time"] >= 0
75+
assert rec["resources"]["sys_time"] >= 0
76+
6477
@pytest.mark.parametrize(
6578
"popdict",
6679
[
@@ -119,6 +132,15 @@ def test_preprocess_interval_recorded(self):
119132
assert 40 < deleted_intervals[0][0] < 60
120133
assert 40 < deleted_intervals[0][1] < 60
121134

135+
def test_preprocess_time_recorded(self):
136+
ts = utility_functions.ts_w_data_desert(40, 60, 100)
137+
preprocessed_ts = tsdate.preprocess_ts(ts, minimum_gap=20)
138+
rec = json.loads(preprocessed_ts.provenance(-1).record)
139+
assert "resources" in rec
140+
assert rec["resources"]["elapsed_time"] >= 0
141+
assert rec["resources"]["user_time"] >= 0
142+
assert rec["resources"]["sys_time"] >= 0
143+
122144
@pytest.mark.parametrize("method", tsdate.core.estimation_methods.keys())
123145
def test_named_methods(self, method):
124146
ts = utility_functions.single_tree_ts_mutation_n3()

tsdate/core.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(
100100
"then access `fit.node_posteriors()` to obtain a transposed version "
101101
"of the matrix previously returned when ``return_posteriors=True.``"
102102
)
103+
self.start_time = time.time()
103104
self.ts = ts
104105
self.mutation_rate = mutation_rate
105106
self.recombination_rate = recombination_rate
@@ -193,8 +194,6 @@ def get_modified_ts(self, result, eps):
193194
nodes = tables.nodes
194195
mutations = tables.mutations
195196

196-
if self.provenance_params is not None:
197-
provenance.record_provenance(tables, self.name, **self.provenance_params)
198197
# Constrain node ages for positive branch lengths
199198
constr_timing = time.time()
200199
nodes.time = util.constrain_ages(ts, node_mean_t, eps, self.constr_iterations)
@@ -220,6 +219,11 @@ def get_modified_ts(self, result, eps):
220219
tables.compute_mutation_parents()
221220
sort_timing -= time.time()
222221
logger.info(f"Sorted tree sequence in {abs(sort_timing):.2f} seconds")
222+
if self.provenance_params is not None:
223+
# Note that the time recorded in provenance excludes numba compilation time
224+
provenance.record_provenance(
225+
tables, self.name, self.start_time, **self.provenance_params
226+
)
223227
return tables.tree_sequence()
224228

225229
def set_time_metadata(self, table, mean, var, default_schema, overwrite=False):

tsdate/provenance.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def get_environment():
6464
return env
6565

6666

67-
def get_provenance_dict(command, **kwargs):
67+
def get_provenance_dict(command, start_time=None, **kwargs):
6868
"""
6969
Returns a dictionary encoding an execution of tsdate conforming to the
7070
tskit provenance schema.
@@ -78,14 +78,15 @@ def get_provenance_dict(command, **kwargs):
7878
"software": {"name": "tsdate", "version": __version__},
7979
"parameters": parameters,
8080
"environment": get_environment(),
81+
"resources": tskit.provenance.get_resources(start_time),
8182
}
8283
return document
8384

8485

85-
def record_provenance(tables, command=None, **kwargs):
86+
def record_provenance(tables, command=None, start_time=None, **kwargs):
8687
"""
8788
Adds provenance information to this table collection using the
8889
tskit provenances schema.
8990
"""
90-
record = get_provenance_dict(command=command, **kwargs)
91+
record = get_provenance_dict(command=command, start_time=start_time, **kwargs)
9192
tables.provenances.add_row(record=json.dumps(record))

tsdate/util.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import json
2828
import logging
29+
import time
2930

3031
import numba
3132
import numpy as np
@@ -115,6 +116,7 @@ def preprocess_ts(
115116
"""
116117

117118
logger.info("Beginning preprocessing")
119+
start_time = time.time()
118120
logger.info(f"Minimum_gap: {minimum_gap} and remove_telomeres: {remove_telomeres}")
119121
if split_disjoint is None:
120122
split_disjoint = True
@@ -184,10 +186,14 @@ def preprocess_ts(
184186
record_provenance=False,
185187
**kwargs,
186188
)
189+
if split_disjoint:
190+
ts = split_disjoint_nodes(tables.tree_sequence(), record_provenance=False)
191+
tables = ts.dump_tables()
187192
if record_provenance:
188193
provenance.record_provenance(
189194
tables,
190195
"preprocess_ts",
196+
start_time=start_time,
191197
minimum_gap=minimum_gap,
192198
remove_telomeres=remove_telomeres,
193199
split_disjoint=split_disjoint,
@@ -196,10 +202,7 @@ def preprocess_ts(
196202
filter_sites=filter_sites,
197203
delete_intervals=delete_intervals,
198204
)
199-
ts = tables.tree_sequence()
200-
if split_disjoint:
201-
ts = split_disjoint_nodes(ts, record_provenance=False)
202-
return ts
205+
return tables.tree_sequence()
203206

204207

205208
def nodes_time_unconstrained(tree_sequence):
@@ -526,6 +529,7 @@ def split_disjoint_nodes(ts, *, record_provenance=None):
526529
as ``True``).
527530
"""
528531
metadata_key = "unsplit_node_id"
532+
start_time = time.time()
529533
if record_provenance is None:
530534
record_provenance = True
531535
node_is_sample = np.bitwise_and(ts.nodes_flags, tskit.NODE_IS_SAMPLE).astype(bool)
@@ -578,6 +582,7 @@ def split_disjoint_nodes(ts, *, record_provenance=None):
578582
provenance.record_provenance(
579583
tables,
580584
"split_disjoint_nodes",
585+
start_time=start_time,
581586
)
582587
return tables.tree_sequence()
583588

0 commit comments

Comments
 (0)