
Commit 6956365

Merge pull request #256 from jeromekelleher/dynamic-precision
Dynamic precision
2 parents e996bfd + dffa62c commit 6956365

File tree: 4 files changed, +290 -191 lines changed

run.sh

Lines changed: 3 additions & 3 deletions
@@ -9,17 +9,17 @@ num_threads=8
 
 # Paths
 datadir=testrun
-run_id=tmp-dev
+run_id=tmp-dev-hp
 # run_id=upgma-mds-$max_daily_samples-md-$max_submission_delay-mm-$mismatches
 resultsdir=results/$run_id
 results_prefix=$resultsdir/$run_id-
 logfile=logs/$run_id.log
 
 alignments=$datadir/alignments.db
 metadata=$datadir/metadata.db
-matches=$resultsdir/matces.db
+matches=$resultsdir/matches.db
 
-dates=`python3 -m sc2ts list-dates $metadata | grep -v 2021-12-31`
+dates=`python3 -m sc2ts list-dates $metadata | grep -v 2021-12-31 | head -n 14`
 echo $dates
 
 options="--num-threads $num_threads -vv -l $logfile "

sc2ts/inference.py

Lines changed: 88 additions & 125 deletions
@@ -102,10 +102,7 @@ def add(self, samples, date, num_mismatches):
                 pkl_compressed,
             )
             data.append(args)
-            pango = sample.metadata.get("Viridian_pangolin", "Unknown")
-            logger.debug(
-                f"MatchDB insert: {sample.strain} {date} {pango} hmm_cost={hmm_cost[j]}"
-            )
+            logger.debug(f"MatchDB insert: hmm_cost={hmm_cost[j]} {sample.summary()}")
         # Batch insert, for efficiency.
         with self.conn:
             self.conn.executemany(sql, data)
@@ -150,11 +147,7 @@ def get(self, where_clause):
         for row in self.conn.execute(sql):
             pkl = row.pop("pickle")
             sample = pickle.loads(bz2.decompress(pkl))
-            pango = sample.metadata.get("Viridian_pangolin", "Unknown")
-            logger.debug(
-                f"MatchDb got: {sample.strain} {sample.date} {pango} "
-                f"hmm_cost={row['hmm_cost']}"
-            )
+            logger.debug(f"MatchDb got: {sample.summary()} hmm_cost={row['hmm_cost']}")
             # print(row)
             yield sample
 
@@ -364,6 +357,18 @@ class Sample:
     # def __str__(self):
     #     return f"{self.strain}: {self.path} + {self.mutations}"
 
+    def path_summary(self):
+        return ",".join(f"({seg.left}:{seg.right}, {seg.parent})" for seg in self.path)
+
+    def mutation_summary(self):
+        return "[" + ",".join(str(mutation) for mutation in self.mutations) + "]"
+
+    def summary(self):
+        pango = self.metadata.get("Viridian_pangolin", "Unknown")
+        return (f"{self.strain} {self.date} {pango} path={self.path_summary()} "
+            f"mutations({len(self.mutations)})={self.mutation_summary()}"
+        )
+
     @property
     def breakpoints(self):
         breakpoints = [seg.left for seg in self.path]
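The new Sample.summary() helpers above drive the slimmed-down logging throughout this commit. The following is a minimal, self-contained sketch of the same formatting with stand-in Sample and PathSegment classes; the strain, date, path and mutation values are hypothetical examples, chosen only to show what the log lines look like in isolation:

# Sketch of the Sample.summary() output format; the stand-in classes and the
# example values below are hypothetical, not part of the commit itself.
import dataclasses


@dataclasses.dataclass
class PathSegment:
    left: int
    right: int
    parent: int


@dataclasses.dataclass
class Sample:
    strain: str
    date: str
    metadata: dict
    path: list
    mutations: list

    def path_summary(self):
        return ",".join(f"({seg.left}:{seg.right}, {seg.parent})" for seg in self.path)

    def mutation_summary(self):
        return "[" + ",".join(str(m) for m in self.mutations) + "]"

    def summary(self):
        pango = self.metadata.get("Viridian_pangolin", "Unknown")
        return (
            f"{self.strain} {self.date} {pango} path={self.path_summary()} "
            f"mutations({len(self.mutations)})={self.mutation_summary()}"
        )


sample = Sample(
    strain="ERR0000001",
    date="2020-03-01",
    metadata={"Viridian_pangolin": "B.1"},
    path=[PathSegment(0, 29904, 42)],
    mutations=["C241T"],
)
print(sample.summary())
# ERR0000001 2020-03-01 B.1 path=(0:29904, 42) mutations(1)=[C241T]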
@@ -388,104 +393,62 @@ def asdict(self):
         }
 
 
-# def daily_extend(
-#     *,
-#     alignment_store,
-#     metadata_db,
-#     base_ts,
-#     match_db,
-#     num_mismatches=None,
-#     max_hmm_cost=None,
-#     min_group_size=None,
-#     num_past_days=None,
-#     show_progress=False,
-#     max_submission_delay=None,
-#     max_daily_samples=None,
-#     num_threads=None,
-#     precision=None,
-#     rng=None,
-#     excluded_sample_dir=None,
-# ):
-#     assert num_past_days is None
-#     assert max_submission_delay is None
-
-#     start_day = last_date(base_ts)
-
-#     last_ts = base_ts
-#     for date in metadata_db.get_days(start_day):
-#         ts = extend(
-#             alignment_store=alignment_store,
-#             metadata_db=metadata_db,
-#             date=date,
-#             base_ts=last_ts,
-#             match_db=match_db,
-#             num_mismatches=num_mismatches,
-#             max_hmm_cost=max_hmm_cost,
-#             min_group_size=min_group_size,
-#             show_progress=show_progress,
-#             max_submission_delay=max_submission_delay,
-#             max_daily_samples=max_daily_samples,
-#             num_threads=num_threads,
-#             precision=precision,
-#         )
-#         yield ts, date
-
-#         last_ts = ts
-
-
 def match_samples(
     date,
     samples,
     *,
-    match_db,
     base_ts,
     num_mismatches=None,
     show_progress=False,
     num_threads=None,
-    precision=None,
-    mirror_coordinates=False,
 ):
-    if num_mismatches is None:
-        # Default to no recombination
-        num_mismatches = 1000
-
-    match_tsinfer(
-        samples=samples,
-        ts=base_ts,
-        num_mismatches=num_mismatches,
-        precision=2,
-        num_threads=num_threads,
-        show_progress=show_progress,
-        mirror_coordinates=mirror_coordinates,
-    )
-    samples_to_rerun = []
-    for sample in samples:
-        hmm_cost = sample.get_hmm_cost(num_mismatches)
-        logger.debug(
-            f"First sketch: {sample.strain} hmm_cost={hmm_cost} path={sample.path}"
-        )
-        if hmm_cost >= 2:
-            sample.path.clear()
-            sample.mutations.clear()
-            samples_to_rerun.append(sample)
-
-    if len(samples_to_rerun) > 0:
+    # First pass, compute the matches at precision=0.
+    run_batch = samples
+
+    # Values based on https://github.com/jeromekelleher/sc2ts/issues/242,
+    # but somewhat arbitrary.
+    for precision, cost_threshold in [(0, 1), (1, 2), (2, 3)]:
+        logger.info(f"Running batch of {len(run_batch)} at p={precision}")
         match_tsinfer(
-            samples=samples_to_rerun,
+            samples=run_batch,
             ts=base_ts,
             num_mismatches=num_mismatches,
             precision=precision,
             num_threads=num_threads,
             show_progress=show_progress,
-            mirror_coordinates=mirror_coordinates,
         )
-        for sample in samples_to_rerun:
-            hmm_cost = sample.get_hmm_cost(num_mismatches)
-            logger.debug(
-                f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}"
-            )
 
-    match_db.add(samples, date, num_mismatches)
+        exceeding_threshold = []
+        for sample in run_batch:
+            cost = sample.get_hmm_cost(num_mismatches)
+            logger.debug(f"HMM@p={precision}: hmm_cost={cost} {sample.summary()}")
+            if cost > cost_threshold:
+                sample.path.clear()
+                sample.mutations.clear()
+                exceeding_threshold.append(sample)
+
+        num_matches_found = len(run_batch) - len(exceeding_threshold)
+        logger.info(
+            f"{num_matches_found} final matches for found p={precision}; "
+            f"{len(exceeding_threshold)} remain"
+        )
+        run_batch = exceeding_threshold
+
+    precision = 6
+    logger.info(f"Running final batch of {len(run_batch)} at p={precision}")
+    match_tsinfer(
+        samples=run_batch,
+        ts=base_ts,
+        num_mismatches=num_mismatches,
+        precision=precision,
+        num_threads=num_threads,
+        show_progress=show_progress,
+    )
+    for sample in run_batch:
+        cost = sample.get_hmm_cost(num_mismatches)
+        # print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}")
+        logger.debug(f"Final HMM pass hmm_cost={cost} {sample.summary()}")
+    return samples
 
 
 def check_base_ts(ts):
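The core of the change is the escalating-precision loop above: run the HMM matching cheaply at precision 0 first, accept any sample whose cost is within the threshold for that precision, and only re-run the remainder at the next (more expensive) level, finishing with a precision-6 pass for whatever is left. A minimal sketch of that control flow, where run_hmm_match and get_cost are hypothetical stand-ins for match_tsinfer and Sample.get_hmm_cost:

# Sketch of the dynamic-precision schedule; run_hmm_match and get_cost are
# hypothetical stand-ins, and the (precision, threshold) pairs follow the diff.
def dynamic_precision_match(samples, run_hmm_match, get_cost):
    run_batch = samples
    # Cheap, low-precision passes first.
    for precision, cost_threshold in [(0, 1), (1, 2), (2, 3)]:
        run_hmm_match(run_batch, precision=precision)
        exceeding_threshold = []
        for sample in run_batch:
            if get_cost(sample) > cost_threshold:
                # Discard this result and retry the sample at higher precision.
                exceeding_threshold.append(sample)
        run_batch = exceeding_threshold
    # Final, expensive pass for anything still above threshold.
    run_hmm_match(run_batch, precision=6)
    return samples


# Toy usage: pretend each sample's cost is fixed and precision-independent.
if __name__ == "__main__":
    costs = {"s1": 0, "s2": 2, "s3": 5}
    matched_at = {}

    def run_hmm_match(batch, precision):
        for s in batch:
            matched_at[s] = precision

    dynamic_precision_match(list(costs), run_hmm_match, costs.get)
    print(matched_at)  # {'s1': 0, 's2': 1, 's3': 6}

Cheap matches (the vast majority of daily samples) never pay for high precision; only the difficult ones reach the final pass.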
@@ -561,7 +524,6 @@ def extend(
         min_group_size = 10
 
     # TMP
-    precision = 6
     check_base_ts(base_ts)
     logger.info(
         f"Extend {date}; ts:nodes={base_ts.num_nodes};samples={base_ts.num_samples};"
@@ -584,17 +546,16 @@ def extend(
         f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
     )
 
-    match_samples(
+    samples = match_samples(
         date,
         samples,
         base_ts=base_ts,
-        match_db=match_db,
         num_mismatches=num_mismatches,
         show_progress=show_progress,
         num_threads=num_threads,
-        precision=precision,
     )
 
+    match_db.add(samples, date, num_mismatches)
     match_db.create_mask_table(base_ts)
     ts = increment_time(date, base_ts)
 
@@ -641,6 +602,18 @@ def update_top_level_metadata(ts, date):
     return tables.tree_sequence()
 
 
+def add_sample_to_tables(sample, tables, flags=tskit.NODE_IS_SAMPLE, time=0):
+    metadata = {
+        **sample.metadata,
+        "sc2ts": {
+            "qc": sample.alignment_qc,
+            "path": [x.asdict() for x in sample.path],
+            "mutations": [x.asdict() for x in sample.mutations],
+        },
+    }
+    return tables.nodes.add_row(flags=flags, time=time, metadata=metadata)
+
+
 def match_path_ts(samples, ts, path, reversions):
     """
     Given the specified list of samples with equal copying paths,
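add_sample_to_tables centralises how a sample node and its "sc2ts" metadata block are written; the same helper now serves both match_path_ts and add_exact_matches below. A self-contained sketch of the metadata layout it produces, using stand-in sample and path objects (all field values are hypothetical) and a tskit node table with a JSON metadata schema:

# Sketch of the node metadata written by add_sample_to_tables; the stand-in
# classes and example values are hypothetical, the structure mirrors the diff.
import dataclasses
import tskit


@dataclasses.dataclass
class FakePathSegment:
    left: int = 0
    right: int = 29904  # SARS-CoV-2 genome length
    parent: int = 42

    def asdict(self):
        return dataclasses.asdict(self)


@dataclasses.dataclass
class FakeSample:
    metadata: dict
    alignment_qc: dict
    path: list
    mutations: list


def add_sample_to_tables(sample, tables, flags=tskit.NODE_IS_SAMPLE, time=0):
    # Original metadata plus an "sc2ts" block holding QC, copying path and mutations.
    metadata = {
        **sample.metadata,
        "sc2ts": {
            "qc": sample.alignment_qc,
            "path": [x.asdict() for x in sample.path],
            "mutations": [x.asdict() for x in sample.mutations],
        },
    }
    return tables.nodes.add_row(flags=flags, time=time, metadata=metadata)


tables = tskit.TableCollection(sequence_length=29904)
tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json()
sample = FakeSample(
    metadata={"strain": "ERR0000001"},
    alignment_qc={},
    path=[FakePathSegment()],
    mutations=[],
)
node_id = add_sample_to_tables(sample, tables)
print(node_id, tables.nodes[node_id].metadata)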
@@ -659,17 +632,7 @@ def match_path_ts(samples, ts, path, reversions):
     )
     for sample in samples:
         assert sample.path == path
-        metadata = {
-            **sample.metadata,
-            "sc2ts": {
-                "qc": sample.alignment_qc,
-                "path": [x.asdict() for x in sample.path],
-                "mutations": [x.asdict() for x in sample.mutations],
-            },
-        }
-        node_id = tables.nodes.add_row(
-            flags=tskit.NODE_IS_SAMPLE, time=0, metadata=metadata
-        )
+        node_id = add_sample_to_tables(sample, tables)
         tables.edges.add_row(0, ts.sequence_length, parent=0, child=node_id)
         for mut in sample.mutations:
             if mut.site_id not in site_id_map:
@@ -707,10 +670,10 @@ def add_exact_matches(match_db, ts, date):
     for sample in samples:
         assert len(sample.path) == 1
         assert len(sample.mutations) == 0
-        node_id = tables.nodes.add_row(
+        node_id = add_sample_to_tables(
+            sample,
+            tables,
             flags=tskit.NODE_IS_SAMPLE | core.NODE_IS_EXACT_MATCH,
-            time=0,
-            metadata=sample.metadata,
         )
         parent = sample.path[0].parent
         logger.debug(f"ARG add exact match {sample.strain}:{node_id}->{parent}")
@@ -843,23 +806,21 @@ def solve_num_mismatches(ts, k):
     NOTE! This is NOT taking into account the spatial distance along
     the genome, and so is not a very good model in some ways.
     """
+    # We can match against any node in tsinfer
     m = ts.num_sites
-    n = ts.num_nodes  # We can match against any node in tsinfer
-    if k == 0:
-        # Pathological things happen when k=0
-        r = 1e-3
-        mu = 1e-20
-    else:
-        # NOTE: the magnitude of mu matters because it puts a limit
-        # on how low we can push the HMM precision. We should be able to solve
-        # for the optimal value of this parameter such that the magnitude of the
-        # values within the HMM are as large as possible (so that we can truncate
-        # usefully).
-        mu = 1e-3
-        denom = (1 - mu) ** k + (n - 1) * mu**k
-        r = n * mu**k / denom
-    assert mu < 0.5
-    assert r < 0.5
+    n = ts.num_nodes
+    # values of k <= 1 are not relevant for SC2 and lead to awkward corner cases
+    assert k > 1
+
+    # NOTE: the magnitude of mu matters because it puts a limit
+    # on how low we can push the HMM precision. We should be able to solve
+    # for the optimal value of this parameter such that the magnitude of the
+    # values within the HMM are as large as possible (so that we can truncate
+    # usefully).
+    # mu = 1e-2
+    mu = 0.125
+    denom = (1 - mu) ** k + (n - 1) * mu**k
+    r = n * mu**k / denom
 
     # Add a little bit of extra mass for recombination so that we deterministically
     # chose to recombine over k mutations
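With the k == 0 special case gone and mu fixed at 0.125, the recombination weight r follows directly from the node count n and the mismatch threshold k via r = n * mu**k / ((1 - mu)**k + (n - 1) * mu**k). A small standalone sketch of that calculation; the n and k values here are hypothetical, chosen only to show the magnitudes involved:

# Standalone sketch of the mu/r calculation above; n and k are hypothetical.
def solve_num_mismatches_sketch(n, k):
    # k <= 1 is excluded in the new code, so keep the same guard here.
    assert k > 1
    mu = 0.125  # per-site mismatch probability used by the commit
    denom = (1 - mu) ** k + (n - 1) * mu ** k
    r = n * mu ** k / denom  # recombination weight, before the extra mass added below
    return mu, r


if __name__ == "__main__":
    for k in (2, 3, 4):
        mu, r = solve_num_mismatches_sketch(n=1000, k=k)
        print(f"k={k}: mu={mu}, r={r:.4g}")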
@@ -1312,6 +1273,8 @@ def match_tsinfer(
     show_progress=False,
     mirror_coordinates=False,
 ):
+    if len(samples) == 0:
+        return
     genotypes = np.array([sample.alignment for sample in samples], dtype=np.int8).T
     input_ts = ts
     if mirror_coordinates:
@@ -1478,7 +1441,7 @@ def get_closest_mutation(node, site_id):
             sample.mutations.append(
                 MatchMutation(
                     site_id=site_id,
-                    site_position=site_pos,
+                    site_position=int(site_pos),
                     derived_state=derived_state,
                     inherited_state=inherited_state,
                     is_reversion=is_reversion,
