Skip to content

Remove tscompare routines #458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 0 additions & 148 deletions tests/test_evaluation.py

This file was deleted.

232 changes: 5 additions & 227 deletions tsdate/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Tools for comparing node times between tree sequences with different node sets
Tools for comparing node times between tree sequences with different node sets.

NB: functionality such as CladeMap, shared_node_spans, and match_node_ages is
now in the tscompare repository: https://github.com/tskit-dev/tscompare/
"""

import copy
import json
from collections import defaultdict
from itertools import groupby, product
from itertools import groupby
from math import isqrt

import matplotlib.pyplot as plt
Expand All @@ -37,230 +39,6 @@
from .phasing import mutation_frequency
from .rescaling import count_mutations


class CladeMap:
    """
    An iterator across trees that maintains a mapping from a clade (a `frozenset` of
    sample IDs) to a `set` of nodes. When there are unary nodes, there may be multiple
    nodes associated with each clade.

    State kept between trees:

    - ``_nodes[clade]``: set of node ids currently subtending ``clade``
    - ``_clades[node]``: the clade currently subtended by ``node``
    - ``_prev[node]``: the clade ``node`` subtended in the previous tree
    """

    def __init__(self, ts):
        """
        Position the map on the first tree of `ts` and record the clade of
        every node in that tree.
        """
        self._nil = frozenset()
        self._nodes = defaultdict(set)  # nodes[clade] = {node ids}
        self._clades = defaultdict(frozenset)  # clades[node] = {sample ids}
        self.tree_sequence = ts
        self.tree = ts.first(sample_lists=True)
        for node in self.tree.nodes():
            clade = frozenset(self.tree.samples(node))
            self._nodes[clade].add(node)
            self._clades[node] = clade
        self._prev = copy.deepcopy(self._clades)
        self._diff = ts.edge_diffs()
        # Discard the first set of edge differences: the first tree has
        # already been loaded in full above.
        next(self._diff)

    def _propagate(self, edge, downdate=False):
        """
        Traverse path from `edge.parent` to root, either adding or removing the
        state (clade) associated with `edge.child` from the state of each
        visited node. Return a set with the node ids encountered during
        traversal.
        """
        nodes = set()
        node = edge.parent
        clade = self._clades[edge.child]
        while node != tskit.NULL:
            last = self._clades[node]
            self._clades[node] = last - clade if downdate else last | clade
            # Move `node` from its old clade's entry to its new clade's entry,
            # dropping the old entry altogether once no node refers to it.
            if len(last):
                self._nodes[last].remove(node)
                if len(self._nodes[last]) == 0:
                    del self._nodes[last]
            self._nodes[self._clades[node]].add(node)
            nodes.add(node)
            node = self.tree.parent(node)
        return nodes

    def next(self):
        """
        Advance to the next tree, returning the difference between trees as a
        dictionary of the form `node : (last_clade, next_clade)`

        Returns `None` (without advancing) if the current tree is the last one.
        """
        nodes = set()  # nodes with potentially altered clades
        diff = {}  # diff[node] = (prev_clade, curr_clade)

        if self.tree.index + 1 == self.tree_sequence.num_trees:
            return None

        # Subtract clades subtended by outgoing edges
        edge_diff = next(self._diff)
        for eo in edge_diff.edges_out:
            nodes |= self._propagate(eo, downdate=True)

        # Prune nodes that are no longer in tree
        for node in self._nodes[self._nil]:
            diff[node] = (self._prev[node], self._nil)
            del self._clades[node]
        nodes -= self._nodes[self._nil]
        self._nodes[self._nil].clear()

        # Add clades subtended by incoming edges
        self.tree.next()
        for ei in edge_diff.edges_in:
            nodes |= self._propagate(ei, downdate=False)

        # Find difference in clades between adjacent trees
        for node in nodes:
            diff[node] = (self._prev[node], self._clades[node])
            if self._prev[node] == self._clades[node]:
                del diff[node]

        # Sync previous and current states
        for node, (_, curr) in diff.items():
            if curr == self._nil:
                del self._prev[node]
            else:
                self._prev[node] = curr

        return diff

    @property
    def interval(self):
        """
        Return interval spanned by tree
        """
        return self.tree.interval

    def clades(self):
        """
        Return set of clades in tree. The empty clade is never included.
        """
        # Subtract a set *containing* the empty clade. Subtracting the empty
        # frozenset itself removes nothing, which would leak the placeholder
        # `_nodes[nil]` entry created by `next()` into the result.
        return self._nodes.keys() - {self._nil}

    def __getitem__(self, clade):
        """
        Return set of nodes associated with a given clade.

        `clade` may be any iterable of sample ids; it is converted to a
        `frozenset` before lookup. An empty `frozenset` is returned if the
        clade is not present.
        """
        # Normalise once so the membership test and the lookup use the same key
        key = frozenset(clade)
        return frozenset(self._nodes[key]) if key in self else self._nil

    def __contains__(self, clade):
        """
        Check if a clade is present in the tree
        """
        return clade in self._nodes


def shared_node_spans(ts, other):
    """
    Compute, for every pair of nodes (one from `ts`, one from `other`), the
    total span over which both nodes are ancestral to identical sets of
    samples.

    The result is a sparse CSR matrix of shape
    `(ts.num_nodes, other.num_nodes)`: rows index nodes in `ts`, columns
    index nodes in `other`.

    Raises `ValueError` if the two tree sequences differ in sequence length
    or in number of samples.
    """
    if ts.sequence_length != other.sequence_length:
        raise ValueError("Tree sequences must be of equal sequence length.")
    if ts.num_samples != other.num_samples:
        raise ValueError("Tree sequences must have the same numbers of samples.")

    empty = frozenset()

    # One clade tracker per tree sequence
    query = CladeMap(ts)
    target = CladeMap(other)

    # Clades whose node sets may have changed since the previous interval;
    # on the first pass this is every clade in either first tree.
    changed = query.clades() | target.clades()

    # pending[clade] = (query_nodes, target_nodes, left_coord) for matches
    # that are currently open
    pending = {}

    # Triplet (COO) representation of the output matrix
    rows, cols, lengths = [], [], []

    def record(match, end):
        # Close a match at coordinate `end`, emitting one entry per node pair
        q_nodes, t_nodes, begin = match
        length = end - begin
        for q, t in product(q_nodes, t_nodes):
            rows.append(q)
            cols.append(t)
            lengths.append(length)

    right = 0
    while True:
        left, right = right, min(query.interval[1], target.interval[1])

        # Close matches for clades that were touched by the last advance
        for clade in changed:
            if clade in pending:
                record(pending.pop(clade), left)

        # Open matches for touched clades present in both current trees
        for clade in changed:
            assert clade not in pending
            if clade in query and clade in target:
                pending[clade] = (query[clade], target[clade], left)

        if right == ts.sequence_length:
            break

        # Advance whichever tree(s) end at `right`, collecting affected clades
        changed.clear()
        for clade_map in (query, target):
            if clade_map.interval[1] != right:
                continue
            for prev, curr in clade_map.next().values():
                for clade in (prev, curr):
                    if clade != empty:
                        changed.add(clade)

    # Close everything still open at the end of the sequence
    for match in pending.values():
        record(match, right)

    return scipy.sparse.coo_matrix(
        (lengths, (rows, cols)),
        shape=(ts.num_nodes, other.num_nodes),
    ).tocsr()


def match_node_ages(ts, other):
    """
    Find, for every node in `ts`, the best-matching node in `other` and
    report that node's age. Two nodes match over the intervals where they
    subtend the same set of samples (see `shared_node_spans`); the best
    match is the node in `other` with the largest total shared span.

    Returns a tuple of three vectors of length `ts.num_nodes`:

    - the time of the best match in `other` (`nan` where nothing matches);
    - the span shared with the best match (the row maximum of the matrix
      from `shared_node_spans`; zero where nothing matches);
    - the node id of the best match in `other` (`tskit.NULL` where nothing
      matches).

    If either tree sequence contains unary nodes, several nodes may tie for
    the largest span. In that case the match with the smallest integer id
    is returned.
    """
    spans = shared_node_spans(ts, other)

    # Row-wise maximum and its column index. On ties, argmax picks the
    # lowest column index, i.e. the smallest node id.
    matched_span = spans.max(axis=1).todense().A1
    best_match = spans.argmax(axis=1).A1
    matched_time = other.nodes_time[best_match]

    # Rows with no shared span have no genuine match: blank them out
    unmatched = matched_span == 0
    best_match[unmatched] = tskit.NULL
    matched_time[unmatched] = np.nan

    return matched_time, matched_span, best_match


# --- infrastructure for testing against polytomies --- #


Expand Down
Loading