diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 612e5200d0..eb60670e33 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -331,7 +331,7 @@ verify_trees(tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *parents) uint32_t mutation_index, site_index; tsk_size_t k, l, tree_sites_length; const tsk_site_t *sites = NULL; - tsk_tree_t tree; + tsk_tree_t tree, skip_tree; tsk_size_t num_edges; tsk_size_t num_nodes = tsk_treeseq_get_num_nodes(ts); tsk_size_t num_sites = tsk_treeseq_get_num_sites(ts); @@ -340,6 +340,7 @@ verify_trees(tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *parents) ret = tsk_tree_init(&tree, ts, 0); CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_init(&skip_tree, ts, 0); CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(ts), num_trees); CU_ASSERT_EQUAL(tree.index, -1); @@ -372,6 +373,21 @@ verify_trees(tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *parents) } site_index++; } + /* Check the skip tree */ + ret = tsk_tree_first(&skip_tree); + CU_ASSERT_EQUAL(ret, TSK_TREE_OK); + ret = tsk_tree_seek(&skip_tree, breakpoints[j], TSK_SEEK_SKIP); + CU_ASSERT_EQUAL(ret, 0); + /* Calling print_state here also verifies the integrity of the tree */ + tsk_tree_print_state(&skip_tree, _devnull); + check_trees_equal(&tree, &skip_tree); + ret = tsk_tree_last(&skip_tree); + CU_ASSERT_EQUAL(ret, TSK_TREE_OK); + ret = tsk_tree_seek(&skip_tree, breakpoints[j], TSK_SEEK_SKIP); + CU_ASSERT_EQUAL(ret, 0); + tsk_tree_print_state(&skip_tree, _devnull); + check_trees_equal(&tree, &skip_tree); + j++; } CU_ASSERT_EQUAL(ret, 0); @@ -381,6 +397,7 @@ verify_trees(tsk_treeseq_t *ts, tsk_size_t num_trees, tsk_id_t *parents) CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(ts), breakpoints[j]); tsk_tree_free(&tree); + tsk_tree_free(&skip_tree); verify_tree_pos(ts, num_trees, parents); } @@ -811,7 +828,8 @@ typedef struct { } sample_count_test_t; static void -verify_sample_counts(tsk_treeseq_t *ts, tsk_size_t num_tests, sample_count_test_t *tests) +verify_sample_counts(tsk_treeseq_t *ts, tsk_size_t num_tests, sample_count_test_t *tests, + tsk_flags_t seek_options) { int ret; tsk_size_t j, num_samples, n, k; @@ -826,13 +844,9 @@ verify_sample_counts(tsk_treeseq_t *ts, tsk_size_t num_tests, sample_count_test_ ret = tsk_tree_init(&tree, ts, TSK_NO_SAMPLE_COUNTS); CU_ASSERT_EQUAL(ret, 0); - ret = tsk_tree_first(&tree); - CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK); for (j = 0; j < num_tests; j++) { - while (tree.index < tests[j].tree_index) { - ret = tsk_tree_next(&tree); - CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK); - } + ret = tsk_tree_seek_index(&tree, tests[j].tree_index, seek_options); + CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_get_num_samples(&tree, tests[j].node, &num_samples); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL(tests[j].count, num_samples); @@ -850,10 +864,8 @@ verify_sample_counts(tsk_treeseq_t *ts, tsk_size_t num_tests, sample_count_test_ ret = tsk_tree_first(&tree); CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK); for (j = 0; j < num_tests; j++) { - while (tree.index < tests[j].tree_index) { - ret = tsk_tree_next(&tree); - CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK); - } + ret = tsk_tree_seek_index(&tree, tests[j].tree_index, seek_options); + CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_get_num_samples(&tree, tests[j].node, &num_samples); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL(tests[j].count, num_samples); @@ -872,10 +884,8 @@ verify_sample_counts(tsk_treeseq_t *ts, tsk_size_t num_tests, sample_count_test_ ret = tsk_tree_first(&tree); CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK); for (j = 0; j < num_tests; j++) { - while (tree.index < tests[j].tree_index) { - ret = tsk_tree_next(&tree); - CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK); - } + ret = tsk_tree_seek_index(&tree, tests[j].tree_index, seek_options); + CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_get_num_samples(&tree, tests[j].node, &num_samples); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL(tests[j].count, num_samples); @@ -908,10 +918,8 @@ verify_sample_counts(tsk_treeseq_t *ts, tsk_size_t num_tests, sample_count_test_ ret = tsk_tree_first(&tree); CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK); for (j = 0; j < num_tests; j++) { - while (tree.index < tests[j].tree_index) { - ret = tsk_tree_next(&tree); - CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK); - } + ret = tsk_tree_seek_index(&tree, tests[j].tree_index, seek_options); + CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_get_num_samples(&tree, tests[j].node, &num_samples); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL(tests[j].count, num_samples); @@ -1008,6 +1016,7 @@ verify_sample_sets(tsk_treeseq_t *ts) { int ret; tsk_tree_t t; + tsk_id_t j; ret = tsk_tree_init(&t, ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL(ret, 0); @@ -1021,6 +1030,20 @@ verify_sample_sets(tsk_treeseq_t *ts) } CU_ASSERT_EQUAL_FATAL(ret, 0); + for (j = 0; j < (tsk_id_t) tsk_treeseq_get_num_trees(ts); j++) { + ret = tsk_tree_first(&t); + CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK); + ret = tsk_tree_seek_index(&t, j, TSK_SEEK_SKIP); + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_sample_sets_for_tree(&t); + + ret = tsk_tree_last(&t); + CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK); + ret = tsk_tree_seek_index(&t, j, TSK_SEEK_SKIP); + CU_ASSERT_EQUAL_FATAL(ret, 0); + verify_sample_sets_for_tree(&t); + } + tsk_tree_free(&t); } @@ -5870,7 +5893,8 @@ test_simple_sample_sets(void) tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, NULL, NULL, paper_ex_individuals, NULL, 0); - verify_sample_counts(&ts, num_tests, tests); + verify_sample_counts(&ts, num_tests, tests, 0); + verify_sample_counts(&ts, num_tests, tests, TSK_SEEK_SKIP); verify_sample_sets(&ts); tsk_treeseq_free(&ts); @@ -5889,7 +5913,8 @@ test_nonbinary_sample_sets(void) tsk_treeseq_from_text(&ts, 100, nonbinary_ex_nodes, nonbinary_ex_edges, NULL, NULL, NULL, NULL, NULL, 0); - verify_sample_counts(&ts, num_tests, tests); + verify_sample_counts(&ts, num_tests, tests, 0); + verify_sample_counts(&ts, num_tests, tests, TSK_SEEK_SKIP); verify_sample_sets(&ts); tsk_treeseq_free(&ts); @@ -5909,7 +5934,8 @@ test_internal_sample_sample_sets(void) tsk_treeseq_from_text(&ts, 10, internal_sample_ex_nodes, internal_sample_ex_edges, NULL, NULL, NULL, NULL, NULL, 0); - verify_sample_counts(&ts, num_tests, tests); + verify_sample_counts(&ts, num_tests, tests, 0); + verify_sample_counts(&ts, num_tests, tests, TSK_SEEK_SKIP); verify_sample_sets(&ts); tsk_treeseq_free(&ts); @@ -6283,7 +6309,7 @@ test_multiroot_tree_traversal(void) } static void -test_seek_multi_tree(void) +verify_seek_multi_tree(tsk_flags_t seek_options) { int ret; tsk_treeseq_t ts; @@ -6299,29 +6325,29 @@ test_seek_multi_tree(void) CU_ASSERT_EQUAL_FATAL(ret, 0); for (j = 0; j < num_trees; j++) { - ret = tsk_tree_seek(&t, breakpoints[j], 0); + ret = tsk_tree_seek(&t, breakpoints[j], seek_options); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, j); - ret = tsk_tree_seek_index(&t, j, 0); + ret = tsk_tree_seek_index(&t, j, seek_options); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, j); for (k = 0; k < num_trees; k++) { - ret = tsk_tree_seek(&t, breakpoints[k], 0); + ret = tsk_tree_seek(&t, breakpoints[k], seek_options); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, k); - ret = tsk_tree_seek_index(&t, k, 0); + ret = tsk_tree_seek_index(&t, k, seek_options); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, k); } } - ret = tsk_tree_seek(&t, 1.99999, 0); + ret = tsk_tree_seek(&t, 1.99999, seek_options); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, 0); - ret = tsk_tree_seek(&t, 6.99999, 0); + ret = tsk_tree_seek(&t, 6.99999, seek_options); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, 1); - ret = tsk_tree_seek(&t, 9.99999, 0); + ret = tsk_tree_seek(&t, 9.99999, seek_options); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, 2); @@ -6331,7 +6357,7 @@ test_seek_multi_tree(void) for (j = 0; j < num_trees; j++) { ret = tsk_tree_init(&t, &ts, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_tree_seek(&t, breakpoints[j], 0); + ret = tsk_tree_seek(&t, breakpoints[j], seek_options); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, j); tsk_tree_free(&t); @@ -6341,12 +6367,12 @@ test_seek_multi_tree(void) ret = tsk_tree_init(&t, &ts, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); for (j = 0; j < num_trees; j++) { - ret = tsk_tree_seek(&t, 0, 0); + ret = tsk_tree_seek(&t, 0, seek_options); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_prev(&t); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, -1); - ret = tsk_tree_seek(&t, breakpoints[j], 0); + ret = tsk_tree_seek(&t, breakpoints[j], seek_options); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, j); } @@ -6355,6 +6381,13 @@ test_seek_multi_tree(void) tsk_treeseq_free(&ts); } +static void +test_seek_multi_tree(void) +{ + verify_seek_multi_tree(0); + verify_seek_multi_tree(TSK_SEEK_SKIP); +} + static void test_seek_errors(void) { diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 2b778f195b..02f12f053a 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6146,17 +6146,18 @@ tsk_tree_print_state(const tsk_tree_t *self, FILE *out) fprintf(out, "left = %f\n", self->interval.left); fprintf(out, "right = %f\n", self->interval.right); fprintf(out, "index = %lld\n", (long long) self->index); - fprintf(out, "node\tparent\tlchild\trchild\tlsib\trsib"); + fprintf(out, "num_edges = %d\n", (int) self->num_edges); + fprintf(out, "node\tedge\tparent\tlchild\trchild\tlsib\trsib"); if (self->options & TSK_SAMPLE_LISTS) { fprintf(out, "\thead\ttail"); } fprintf(out, "\n"); for (j = 0; j < self->num_nodes + 1; j++) { - fprintf(out, "%lld\t%lld\t%lld\t%lld\t%lld\t%lld", (long long) j, - (long long) self->parent[j], (long long) self->left_child[j], - (long long) self->right_child[j], (long long) self->left_sib[j], - (long long) self->right_sib[j]); + fprintf(out, "%lld\t%lld\t%lld\t%lld\t%lld\t%lld\t%lld", (long long) j, + (long long) self->edge[j], (long long) self->parent[j], + (long long) self->left_child[j], (long long) self->right_child[j], + (long long) self->left_sib[j], (long long) self->right_sib[j]); if (self->options & TSK_SAMPLE_LISTS) { fprintf(out, "\t%lld\t%lld\t", (long long) self->left_sample[j], (long long) self->right_sample[j]); @@ -6284,7 +6285,8 @@ tsk_tree_remove_root(tsk_tree_t *self, tsk_id_t root, tsk_id_t *restrict parent) } static void -tsk_tree_remove_edge(tsk_tree_t *self, tsk_id_t p, tsk_id_t c) +tsk_tree_remove_edge( + tsk_tree_t *self, tsk_id_t p, tsk_id_t c, tsk_id_t TSK_UNUSED(edge_id)) { tsk_id_t *restrict parent = self->parent; tsk_size_t *restrict num_samples = self->num_samples; @@ -6425,7 +6427,7 @@ tsk_tree_next(tsk_tree_t *self) if (valid) { for (j = tree_pos.out.start; j != tree_pos.out.stop; j++) { e = tree_pos.out.order[j]; - tsk_tree_remove_edge(self, edge_parent[e], edge_child[e]); + tsk_tree_remove_edge(self, edge_parent[e], edge_child[e], e); } for (j = tree_pos.in.start; j != tree_pos.in.stop; j++) { @@ -6457,7 +6459,7 @@ tsk_tree_prev(tsk_tree_t *self) if (valid) { for (j = tree_pos.out.start; j != tree_pos.out.stop; j--) { e = tree_pos.out.order[j]; - tsk_tree_remove_edge(self, edge_parent[e], edge_child[e]); + tsk_tree_remove_edge(self, edge_parent[e], edge_child[e], e); } for (j = tree_pos.in.start; j != tree_pos.in.stop; j--) { @@ -6532,6 +6534,90 @@ tsk_tree_seek_from_null(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(optio return ret; } +static int TSK_WARN_UNUSED +tsk_tree_seek_forward(tsk_tree_t *self, tsk_id_t index) +{ + int ret = 0; + tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t *restrict edge_parent = tables->edges.parent; + const tsk_id_t *restrict edge_child = tables->edges.child; + const double *restrict edge_left = tables->edges.left; + const double *restrict edge_right = tables->edges.right; + double interval_left, e_left; + const double old_right = self->interval.right; + tsk_id_t j, e; + tsk_tree_position_t tree_pos; + + ret = tsk_tree_position_seek_forward(&self->tree_pos, index); + if (ret != 0) { + goto out; + } + tree_pos = self->tree_pos; + interval_left = tree_pos.interval.left; + + for (j = tree_pos.out.start; j != tree_pos.out.stop; j++) { + e = tree_pos.out.order[j]; + e_left = edge_left[e]; + if (e_left < old_right) { + tsk_bug_assert(edge_parent[e] != TSK_NULL); + tsk_tree_remove_edge(self, edge_parent[e], edge_child[e], e); + } + tsk_bug_assert(e_left < interval_left); + } + + for (j = tree_pos.in.start; j != tree_pos.in.stop; j++) { + e = tree_pos.in.order[j]; + if (edge_left[e] <= interval_left && interval_left < edge_right[e]) { + tsk_tree_insert_edge(self, edge_parent[e], edge_child[e], e); + } + } + tsk_tree_update_index_and_interval(self); +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_tree_seek_backward(tsk_tree_t *self, tsk_id_t index) +{ + int ret = 0; + tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_id_t *restrict edge_parent = tables->edges.parent; + const tsk_id_t *restrict edge_child = tables->edges.child; + const double *restrict edge_left = tables->edges.left; + const double *restrict edge_right = tables->edges.right; + double interval_right, e_right; + const double old_right = self->interval.right; + tsk_id_t j, e; + tsk_tree_position_t tree_pos; + + ret = tsk_tree_position_seek_backward(&self->tree_pos, index); + if (ret != 0) { + goto out; + } + tree_pos = self->tree_pos; + interval_right = tree_pos.interval.right; + + for (j = tree_pos.out.start; j != tree_pos.out.stop; j--) { + e = tree_pos.out.order[j]; + e_right = edge_right[e]; + if (e_right >= old_right) { + tsk_bug_assert(edge_parent[e] != TSK_NULL); + tsk_tree_remove_edge(self, edge_parent[e], edge_child[e], e); + } + tsk_bug_assert(e_right > interval_right); + } + + for (j = tree_pos.in.start; j != tree_pos.in.stop; j--) { + e = tree_pos.in.order[j]; + if (edge_right[e] >= interval_right && interval_right > edge_left[e]) { + tsk_tree_insert_edge(self, edge_parent[e], edge_child[e], e); + } + } + tsk_tree_update_index_and_interval(self); +out: + return ret; +} + int TSK_WARN_UNUSED tsk_tree_seek_index(tsk_tree_t *self, tsk_id_t tree, tsk_flags_t options) { @@ -6549,7 +6635,7 @@ tsk_tree_seek_index(tsk_tree_t *self, tsk_id_t tree, tsk_flags_t options) } static int TSK_WARN_UNUSED -tsk_tree_seek_linear(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) +tsk_tree_seek_linear(tsk_tree_t *self, double x) { const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); const double t_l = self->interval.left; @@ -6588,6 +6674,29 @@ tsk_tree_seek_linear(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options) return ret; } +static int TSK_WARN_UNUSED +tsk_tree_seek_skip(tsk_tree_t *self, double x) +{ + const double t_l = self->interval.left; + int ret = 0; + tsk_id_t index; + const tsk_size_t num_trees = self->tree_sequence->num_trees; + const double *restrict breakpoints = self->tree_sequence->breakpoints; + + index = (tsk_id_t) tsk_search_sorted(breakpoints, num_trees + 1, x); + if (breakpoints[index] > x) { + index--; + } + + if (x < t_l) { + ret = tsk_tree_seek_backward(self, index); + } else { + ret = tsk_tree_seek_forward(self, index); + } + tsk_bug_assert(tsk_tree_position_in_interval(self, x)); + return ret; +} + int TSK_WARN_UNUSED tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t options) { @@ -6602,7 +6711,11 @@ tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t options) if (self->index == -1) { ret = tsk_tree_seek_from_null(self, x, options); } else { - ret = tsk_tree_seek_linear(self, x, options); + if (options & TSK_SEEK_SKIP) { + ret = tsk_tree_seek_skip(self, x); + } else { + ret = tsk_tree_seek_linear(self, x); + } } out: diff --git a/c/tskit/trees.h b/c/tskit/trees.h index ac4100f7b0..89d800f3b3 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1270,6 +1270,13 @@ int tsk_tree_copy(const tsk_tree_t *self, tsk_tree_t *dest, tsk_flags_t options) @{ */ +/** @brief Option to seek by skipping to the target tree, adding and removing as few + edges as possible. If not specified, a linear time algorithm is used instead. + + @ingroup TREE_API_SEEKING_GROUP +*/ +#define TSK_SEEK_SKIP (1 << 0) + /** @brief Seek to the first tree in the sequence. @@ -1375,12 +1382,22 @@ we will have ``position < tree.interval.right``. Seeking to a position currently covered by the tree is a constant time operation. + +Seeking to a position from a non-null tree uses a linear time +algorithm by default, unless the option :c:macro:`TSK_SEEK_SKIP` +is specified. In this case, a faster algorithm is employed which skips +to the target tree by removing and adding the minimal number of edges +possible. However, this approach does not guarantee that edges are +inserted and removed in time-sorted order. + +.. warning:: Using the :c:macro:`TSK_SEEK_SKIP` option + may lead to edges not being inserted or removed in time-sorted order. + @endrst @param self A pointer to an initialised tsk_tree_t object. @param position The position in genome coordinates -@param options Seek options. Currently unused. Set to 0 for compatibility - with future versions of tskit. +@param options Seek options. See the notes above for details. @return Return 0 on success or a negative value on failure. */ int tsk_tree_seek(tsk_tree_t *self, double position, tsk_flags_t options); diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index cab445e5d0..66776bd187 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -12103,16 +12103,21 @@ static PyObject * Tree_seek(Tree *self, PyObject *args) { PyObject *ret = NULL; + tsk_flags_t options = 0; + int skip = false; double position; int err; if (Tree_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTuple(args, "d", &position)) { + if (!PyArg_ParseTuple(args, "d|i", &position, &skip)) { goto out; } - err = tsk_tree_seek(self->tree, position, 0); + if (skip) { + options |= TSK_SEEK_SKIP; + } + err = tsk_tree_seek(self->tree, position, options); if (err != 0) { handle_library_error(err); goto out; @@ -12127,15 +12132,20 @@ Tree_seek_index(Tree *self, PyObject *args) { PyObject *ret = NULL; tsk_id_t index = 0; + tsk_flags_t options = 0; + int skip = false; int err; if (Tree_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTuple(args, "O&", tsk_id_converter, &index)) { + if (!PyArg_ParseTuple(args, "O&|i", tsk_id_converter, &index, &skip)) { goto out; } - err = tsk_tree_seek_index(self->tree, index, 0); + if (skip) { + options |= TSK_SEEK_SKIP; + } + err = tsk_tree_seek_index(self->tree, index, options); if (err != 0) { handle_library_error(err); goto out; diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 1f476be574..f33083bc9b 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -3838,7 +3838,8 @@ def test_deprecated_api_warnings(self): with pytest.warns(FutureWarning, match="Tree.tree_sequence.num_nodes"): t1.num_nodes - def test_seek_index(self): + @pytest.mark.parametrize("skip", [False, True]) + def test_seek_index(self, skip): ts = msprime.simulate(10, recombination_rate=3, length=5, random_seed=42) N = ts.num_trees assert ts.num_trees > 3 @@ -3847,18 +3848,23 @@ def test_seek_index(self): fresh_tree = tskit.Tree(ts) assert fresh_tree.index == -1 fresh_tree.seek_index(index) - tree.seek_index(index) assert fresh_tree.index == index - assert tree.index == index + tree.seek_index(index, skip) + assert_trees_equivalent(fresh_tree, tree) tree = tskit.Tree(ts) for index in [-1, -2, -N + 2, -N + 1, -N]: fresh_tree = tskit.Tree(ts) assert fresh_tree.index == -1 fresh_tree.seek_index(index) - tree.seek_index(index) + tree.seek_index(index, skip) assert fresh_tree.index == index + N assert tree.index == index + N + assert_trees_equivalent(fresh_tree, tree) + + def test_seek_index_errors(self): + tree = self.get_tree() + N = tree.tree_sequence.num_trees with pytest.raises(IndexError): tree.seek_index(N) with pytest.raises(IndexError): @@ -4373,11 +4379,21 @@ def test_nonbinary(self): def assert_trees_identical(t1, t2): assert t1.tree_sequence == t2.tree_sequence assert t1.index == t2.index - assert np.all(t1.parent_array == t2.parent_array) - assert np.all(t1.left_child_array == t2.left_child_array) - assert np.all(t1.left_sib_array == t2.left_sib_array) - assert np.all(t1.right_child_array == t2.right_child_array) - assert np.all(t1.right_sib_array == t2.right_sib_array) + assert_array_equal(t1.parent_array, t2.parent_array) + assert_array_equal(t1.left_child_array, t2.left_child_array) + assert_array_equal(t1.left_sib_array, t2.left_sib_array) + assert_array_equal(t1.right_child_array, t2.right_child_array) + assert_array_equal(t1.right_sib_array, t2.right_sib_array) + + +def assert_trees_equivalent(t1, t2): + assert t1.tree_sequence == t2.tree_sequence + assert t1.index == t2.index + assert_array_equal(t1.parent_array, t2.parent_array) + assert_array_equal(t1.edge_array, t2.edge_array) + for u in range(t1.tree_sequence.num_nodes): + # this isn't fully testing the data model, but that's done elsewhere + assert sorted(t1.children(u)) == sorted(t2.children(u)) def assert_same_tree_different_order(t1, t2): @@ -4431,8 +4447,9 @@ def get_tree_pair(self): ts = self.ts() t1 = tskit.Tree(ts) t2 = tskit.Tree(ts) - # Note: for development we can monkeypatch in the Python implementation - # above like this: + # # Note: for development we can monkeypatch in the Python implementation + # # above like this: + # import functools # t2.seek = functools.partial(seek, t2) return t1, t2 @@ -4453,64 +4470,80 @@ def test_seek_from_null(self, position): t1.clear() t1.seek(position) t2.first() - t2.seek(position) + t2.seek(position, skip=False) assert_trees_identical(t1, t2) + @pytest.mark.parametrize("position", [0, 1, 2, 3]) + def test_skip_from_null(self, position): + t1, t2 = self.get_tree_pair() + t1.clear() + t1.seek(position) + t2.first() + t2.seek(position, skip=True) + assert_trees_equivalent(t1, t2) + @pytest.mark.parametrize("index", range(3)) - def test_seek_next_tree(self, index): + @pytest.mark.parametrize("skip", [False, True]) + def test_seek_next_tree(self, index, skip): t1, t2 = self.get_tree_pair() while t1.index != index: t1.next() t2.next() t1.next() - t2.seek(index + 1) + t2.seek(index + 1, skip=skip) assert_trees_identical(t1, t2) @pytest.mark.parametrize("index", [3, 2, 1]) - def test_seek_prev_tree(self, index): + @pytest.mark.parametrize("skip", [False, True]) + def test_seek_prev_tree(self, index, skip): t1, t2 = self.get_tree_pair() while t1.index != index: t1.prev() t2.prev() t1.prev() - t2.seek(index - 1) + t2.seek(index - 1, skip=skip) assert_trees_identical(t1, t2) - def test_seek_1_from_0(self): + @pytest.mark.parametrize("skip", [False, True]) + def test_seek_1_from_0(self, skip): t1, t2 = self.get_tree_pair() t1.first() t1.next() t2.first() - t2.seek(1) + t2.seek(1, skip) assert_trees_identical(t1, t2) - def test_seek_1_5_from_0(self): + @pytest.mark.parametrize("skip", [False, True]) + def test_seek_1_5_from_0(self, skip): t1, t2 = self.get_tree_pair() t1.first() t1.next() t2.first() - t2.seek(1.5) + t2.seek(1.5, skip) assert_trees_identical(t1, t2) - def test_seek_1_5_from_1(self): + @pytest.mark.parametrize("skip", [False, True]) + def test_seek_1_5_from_1(self, skip): t1, t2 = self.get_tree_pair() for _ in range(2): t1.next() t2.next() - t2.seek(1.5) + t2.seek(1.5, skip) assert_trees_identical(t1, t2) - def test_seek_3_from_null(self): + @pytest.mark.parametrize("skip", [False, True]) + def test_seek_3_from_null(self, skip): t1, t2 = self.get_tree_pair() t1.last() - t2.seek(3) + t2.seek(3, skip) assert_trees_identical(t1, t2) - def test_seek_3_from_null_prev(self): + @pytest.mark.parametrize("skip", [False, True]) + def test_seek_3_from_null_prev(self, skip): t1, t2 = self.get_tree_pair() t1.last() t1.prev() - t2.seek(3) + t2.seek(3, skip) t2.prev() assert_trees_identical(t1, t2) @@ -4521,6 +4554,21 @@ def test_seek_3_from_0(self): t2.seek(3) assert_trees_identical(t1, t2) + def test_skip_3_from_0(self): + t1, t2 = self.get_tree_pair() + t1.last() + t2.first() + t2.seek(3, True) + assert_trees_equivalent(t1, t2) + + def test_skip_0_from_3(self): + t1, t2 = self.get_tree_pair() + t1.last() + t1.first() + t2.last() + t2.seek(0, True) + assert_trees_equivalent(t1, t2) + def test_seek_0_from_3(self): t1, t2 = self.get_tree_pair() t1.last() @@ -4545,8 +4593,18 @@ def test_seek_mid_null_and_middle(self, ts): else: while t2.index != index: t2.prev() - assert t1.index == t2.index - assert np.all(t1.parent_array == t2.parent_array) + assert_trees_equivalent(t1, t2) + + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) + def test_seek_skip_middle(self, ts): + breakpoints = ts.breakpoints(as_array=True) + mid = breakpoints[:-1] + np.diff(breakpoints) / 2 + for _, x in enumerate(mid[:-1]): + t1 = tskit.Tree(ts) + t1.seek(x, skip=False) + t2 = tskit.Tree(ts) + t2.seek(x, skip=True) + assert_trees_equivalent(t1, t2) @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_seek_last_then_prev(self, ts): diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 9cc8206cb1..ab373ccae5 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -3640,6 +3640,15 @@ def test_seek_errors(self): with pytest.raises(_tskit.LibraryError): tree.seek(bad_pos) + def seek_skip_errors(self): + ts = self.get_example_tree_sequence() + tree = _tskit.Tree(ts) + for bad_type in ["", "x", {}]: + with pytest.raises(TypeError): + tree.seek(0, bad_type) + with pytest.raises(TypeError): + tree.seek_index(0, bad_type) + def test_seek_index_errors(self): ts = self.get_example_tree_sequence() tree = _tskit.Tree(ts) @@ -3650,6 +3659,16 @@ def test_seek_index_errors(self): with pytest.raises(_tskit.LibraryError): tree.seek_index(bad_index) + @pytest.mark.parametrize("skip", [True, False]) + def test_seek_zero(self, skip): + ts = self.get_example_tree_sequence() + tree1 = _tskit.Tree(ts) + tree1.seek_index(0, skip) + assert tree1.get_left() == 0 + tree2 = _tskit.Tree(ts) + tree2.seek(0, skip) + assert tree2.get_left() == 0 + def test_root_threshold(self): for ts in self.get_example_tree_sequences(): tree = _tskit.Tree(ts) diff --git a/python/tests/test_tree_positioning.py b/python/tests/test_tree_positioning.py index b36c8c385c..e3e7e7a0be 100644 --- a/python/tests/test_tree_positioning.py +++ b/python/tests/test_tree_positioning.py @@ -32,9 +32,6 @@ from tests import tsutil from tests.tsutil import get_example_tree_sequences -# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when -# we can remove this. - class StatefulTree: """ diff --git a/python/tskit/trees.py b/python/tskit/trees.py index e740181f81..9a49c0da47 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -840,7 +840,7 @@ def clear(self): """ self._ll_tree.clear() - def seek_index(self, index): + def seek_index(self, index, skip=None): """ Sets the state to represent the tree at the specified index in the parent tree sequence. Negative indexes following the @@ -857,9 +857,10 @@ def seek_index(self, index): index += num_trees if index < 0 or index >= num_trees: raise IndexError("Index out of bounds") - self._ll_tree.seek_index(index) + skip = False if skip is None else skip + self._ll_tree.seek_index(index, skip) - def seek(self, position): + def seek(self, position, skip=None): """ Sets the state to represent the tree that covers the specified position in the parent tree sequence. After a successful return @@ -875,7 +876,8 @@ def seek(self, position): """ if position < 0 or position >= self.tree_sequence.sequence_length: raise ValueError("Position out of bounds") - self._ll_tree.seek(position) + skip = False if skip is None else skip + self._ll_tree.seek(position, skip) def rank(self) -> tskit.Rank: """