Skip to content

Commit 3619785

Browse files
committed
Add a test for multiple mutations on same edge/site
Also enable the truncated ts test
1 parent deaa1ea commit 3619785

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

tests/test_inference.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,26 @@ def test_mutation_times(self):
408408
dated = tsdate.date(ts, mutation_rate=1)
409409
assert np.all(~np.isnan(dated.tables.mutations.time))
410410

411-
@pytest.mark.skip("YAN to fix")
411+
def test_multiple_mutations_same_node_and_pos(self):
412+
tables = tskit.Tree.generate_comb(4).tree_sequence.dump_tables()
413+
indA = tables.individuals.add_row()
414+
indB = tables.individuals.add_row()
415+
nodes_individual = tables.nodes.individual
416+
nodes_individual[[0, 1, 2, 3]] = [indA, indA, indB, indB]
417+
tables.nodes.individual = nodes_individual
418+
s = tables.sites.add_row(position=0, ancestral_state="anc")
419+
tables.mutations.add_row(site=s, node=5, derived_state="0", time=2)
420+
tables.mutations.add_row(site=s, node=1, derived_state="1", time=1.5)
421+
tables.mutations.add_row(site=s, node=1, derived_state="2", time=1)
422+
tables.mutations.add_row(site=s, node=1, derived_state="3", time=0.5)
423+
ts = tsdate.date(tables.tree_sequence(), mutation_rate=1)
424+
assert np.all(np.diff(ts.mutations_time) < 0)
425+
for m in ts.mutations():
426+
assert m.derived_state == str(m.id), (
427+
m.derived_state,
428+
str(m.id),
429+
)
430+
412431
def test_truncated_ts(self):
413432
Ne = 1e2
414433
mu = 2e-4
@@ -421,11 +440,14 @@ def test_truncated_ts(self):
421440
random_seed=12,
422441
)
423442
truncated_ts = util.truncate_ts_samples(ts, average_span=200, random_seed=123)
424-
dated_ts = tsdate.date(truncated_ts, population_size=Ne, mutation_rate=mu)
425-
# We should ideally test whether *haplotypes* are the same here
426-
# in case allele encoding has changed. But haplotypes() doesn't currently
427-
# deal with missing data
443+
dated_ts = tsdate.date(truncated_ts, mutation_rate=mu, allow_unary=True)
428444
self.ts_equal_except_times(truncated_ts, dated_ts)
445+
has_missing = False
446+
for s1, s2 in zip(truncated_ts.haplotypes(), dated_ts.haplotypes()):
447+
if "N" in s1:
448+
has_missing = True
449+
assert s1 == s2
450+
assert has_missing
429451

430452

431453
class TestInferred:

0 commit comments

Comments
 (0)