@@ -102,10 +102,7 @@ def add(self, samples, date, num_mismatches):
102
102
pkl_compressed ,
103
103
)
104
104
data .append (args )
105
- pango = sample .metadata .get ("Viridian_pangolin" , "Unknown" )
106
- logger .debug (
107
- f"MatchDB insert: { sample .strain } { date } { pango } hmm_cost={ hmm_cost [j ]} "
108
- )
105
+ logger .debug (f"MatchDB insert: hmm_cost={ hmm_cost [j ]} { sample .summary ()} " )
109
106
# Batch insert, for efficiency.
110
107
with self .conn :
111
108
self .conn .executemany (sql , data )
@@ -150,11 +147,7 @@ def get(self, where_clause):
150
147
for row in self .conn .execute (sql ):
151
148
pkl = row .pop ("pickle" )
152
149
sample = pickle .loads (bz2 .decompress (pkl ))
153
- pango = sample .metadata .get ("Viridian_pangolin" , "Unknown" )
154
- logger .debug (
155
- f"MatchDb got: { sample .strain } { sample .date } { pango } "
156
- f"hmm_cost={ row ['hmm_cost' ]} "
157
- )
150
+ logger .debug (f"MatchDb got: { sample .summary ()} hmm_cost={ row ['hmm_cost' ]} " )
158
151
# print(row)
159
152
yield sample
160
153
@@ -364,6 +357,18 @@ class Sample:
364
357
# def __str__(self):
365
358
# return f"{self.strain}: {self.path} + {self.mutations}"
366
359
360
def path_summary(self):
    """
    Return a compact one-line description of this sample's copying path.

    Each segment in ``self.path`` is rendered as ``(left:right, parent)``
    and the segments are joined with commas. An empty path gives an
    empty string.
    """
    return ",".join(f"({seg.left}:{seg.right}, {seg.parent})" for seg in self.path)
362
+
363
def mutation_summary(self):
    """
    Return a compact one-line description of this sample's mutations.

    Every entry of ``self.mutations`` is converted with ``str`` and the
    results are joined with commas inside square brackets. An empty
    mutation list gives ``"[]"``.
    """
    joined = ",".join(str(mutation) for mutation in self.mutations)
    return f"[{joined}]"
365
+
366
def summary(self):
    """
    Return a one-line description of this sample, suitable for log messages.

    The line holds the strain, the date, the "Viridian_pangolin" metadata
    value ("Unknown" when that key is absent), the path summary and the
    number and summary of the mutations.
    """
    pango = self.metadata.get("Viridian_pangolin", "Unknown")
    return (
        f"{self.strain} {self.date} {pango} path={self.path_summary()} "
        f"mutations({len(self.mutations)})={self.mutation_summary()}"
    )
371
+
367
372
@property
368
373
def breakpoints (self ):
369
374
breakpoints = [seg .left for seg in self .path ]
@@ -388,104 +393,62 @@ def asdict(self):
388
393
}
389
394
390
395
391
- # def daily_extend(
392
- # *,
393
- # alignment_store,
394
- # metadata_db,
395
- # base_ts,
396
- # match_db,
397
- # num_mismatches=None,
398
- # max_hmm_cost=None,
399
- # min_group_size=None,
400
- # num_past_days=None,
401
- # show_progress=False,
402
- # max_submission_delay=None,
403
- # max_daily_samples=None,
404
- # num_threads=None,
405
- # precision=None,
406
- # rng=None,
407
- # excluded_sample_dir=None,
408
- # ):
409
- # assert num_past_days is None
410
- # assert max_submission_delay is None
411
-
412
- # start_day = last_date(base_ts)
413
-
414
- # last_ts = base_ts
415
- # for date in metadata_db.get_days(start_day):
416
- # ts = extend(
417
- # alignment_store=alignment_store,
418
- # metadata_db=metadata_db,
419
- # date=date,
420
- # base_ts=last_ts,
421
- # match_db=match_db,
422
- # num_mismatches=num_mismatches,
423
- # max_hmm_cost=max_hmm_cost,
424
- # min_group_size=min_group_size,
425
- # show_progress=show_progress,
426
- # max_submission_delay=max_submission_delay,
427
- # max_daily_samples=max_daily_samples,
428
- # num_threads=num_threads,
429
- # precision=precision,
430
- # )
431
- # yield ts, date
432
-
433
- # last_ts = ts
434
-
435
-
436
396
def match_samples(
    date,
    samples,
    *,
    base_ts,
    num_mismatches=None,
    show_progress=False,
    num_threads=None,
):
    """
    Match the specified samples against ``base_ts`` using ``match_tsinfer``,
    retrying at increasing HMM precision until each match is cheap enough.

    All samples are first matched at low precision. Any sample whose HMM
    cost exceeds the threshold for that round has its path and mutations
    cleared and is carried over to the next, more precise, round. Samples
    still remaining after the scheduled rounds are matched one last time at
    precision 6, and that result is kept whatever its cost.

    :param date: Not used in the body of this function.
    :param samples: The samples to match. Their ``path`` and ``mutations``
        are cleared in place between rounds.
    :param base_ts: The tree sequence passed to ``match_tsinfer`` as ``ts``.
    :param num_mismatches: Passed unchanged to ``match_tsinfer`` and to
        ``Sample.get_hmm_cost``. NOTE(review): no fallback is applied here
        when this is None -- confirm that callers always supply a value.
    :param show_progress: Passed unchanged to ``match_tsinfer``.
    :param num_threads: Passed unchanged to ``match_tsinfer``.
    :return: The ``samples`` list that was passed in.
    """
    run_batch = samples

    # (precision, cost_threshold) pairs for the progressive passes.
    # Values based on https://github.com/jeromekelleher/sc2ts/issues/242,
    # but somewhat arbitrary.
    for precision, cost_threshold in [(0, 1), (1, 2), (2, 3)]:
        logger.info(f"Running batch of {len(run_batch)} at p={precision}")
        match_tsinfer(
            samples=run_batch,
            ts=base_ts,
            num_mismatches=num_mismatches,
            precision=precision,
            num_threads=num_threads,
            show_progress=show_progress,
        )

        exceeding_threshold = []
        for sample in run_batch:
            cost = sample.get_hmm_cost(num_mismatches)
            logger.debug(f"HMM@p={precision}: hmm_cost={cost} {sample.summary()}")
            if cost > cost_threshold:
                # Too expensive at this precision: discard the match so the
                # sample can be matched afresh in the next round.
                sample.path.clear()
                sample.mutations.clear()
                exceeding_threshold.append(sample)

        num_matches_found = len(run_batch) - len(exceeding_threshold)
        logger.info(
            f"{num_matches_found} final matches found for p={precision}; "
            f"{len(exceeding_threshold)} remain"
        )
        run_batch = exceeding_threshold

    # Whatever is left is matched at the highest precision used here, and
    # the result is accepted without any cost threshold.
    precision = 6
    logger.info(f"Running final batch of {len(run_batch)} at p={precision}")
    match_tsinfer(
        samples=run_batch,
        ts=base_ts,
        num_mismatches=num_mismatches,
        precision=precision,
        num_threads=num_threads,
        show_progress=show_progress,
    )
    for sample in run_batch:
        cost = sample.get_hmm_cost(num_mismatches)
        logger.debug(f"Final HMM pass hmm_cost={cost} {sample.summary()}")
    return samples
489
452
490
453
491
454
def check_base_ts (ts ):
@@ -561,7 +524,6 @@ def extend(
561
524
min_group_size = 10
562
525
563
526
# TMP
564
- precision = 6
565
527
check_base_ts (base_ts )
566
528
logger .info (
567
529
f"Extend { date } ; ts:nodes={ base_ts .num_nodes } ;samples={ base_ts .num_samples } ;"
@@ -584,17 +546,16 @@ def extend(
584
546
f"Got alignments for { len (samples )} of { len (metadata_matches )} in metadata"
585
547
)
586
548
587
- match_samples (
549
+ samples = match_samples (
588
550
date ,
589
551
samples ,
590
552
base_ts = base_ts ,
591
- match_db = match_db ,
592
553
num_mismatches = num_mismatches ,
593
554
show_progress = show_progress ,
594
555
num_threads = num_threads ,
595
- precision = precision ,
596
556
)
597
557
558
+ match_db .add (samples , date , num_mismatches )
598
559
match_db .create_mask_table (base_ts )
599
560
ts = increment_time (date , base_ts )
600
561
@@ -641,6 +602,18 @@ def update_top_level_metadata(ts, date):
641
602
return tables .tree_sequence ()
642
603
643
604
605
def add_sample_to_tables(sample, tables, flags=tskit.NODE_IS_SAMPLE, time=0):
    """
    Add a node for the specified sample to ``tables.nodes`` and return the
    value returned by ``tables.nodes.add_row``.

    The node metadata is the sample's own metadata plus an "sc2ts" entry
    holding the alignment QC, the copying path and the mutations, with the
    path segments and mutations converted via their ``asdict`` methods.

    :param sample: The sample to add.
    :param tables: The table collection whose node table is extended.
    :param flags: The flags for the new node.
    :param time: The time for the new node.
    """
    metadata = {
        **sample.metadata,
        "sc2ts": {
            "qc": sample.alignment_qc,
            "path": [x.asdict() for x in sample.path],
            "mutations": [x.asdict() for x in sample.mutations],
        },
    }
    return tables.nodes.add_row(flags=flags, time=time, metadata=metadata)
615
+
616
+
644
617
def match_path_ts (samples , ts , path , reversions ):
645
618
"""
646
619
Given the specified list of samples with equal copying paths,
@@ -659,17 +632,7 @@ def match_path_ts(samples, ts, path, reversions):
659
632
)
660
633
for sample in samples :
661
634
assert sample .path == path
662
- metadata = {
663
- ** sample .metadata ,
664
- "sc2ts" : {
665
- "qc" : sample .alignment_qc ,
666
- "path" : [x .asdict () for x in sample .path ],
667
- "mutations" : [x .asdict () for x in sample .mutations ],
668
- },
669
- }
670
- node_id = tables .nodes .add_row (
671
- flags = tskit .NODE_IS_SAMPLE , time = 0 , metadata = metadata
672
- )
635
+ node_id = add_sample_to_tables (sample , tables )
673
636
tables .edges .add_row (0 , ts .sequence_length , parent = 0 , child = node_id )
674
637
for mut in sample .mutations :
675
638
if mut .site_id not in site_id_map :
@@ -707,10 +670,10 @@ def add_exact_matches(match_db, ts, date):
707
670
for sample in samples :
708
671
assert len (sample .path ) == 1
709
672
assert len (sample .mutations ) == 0
710
- node_id = tables .nodes .add_row (
673
+ node_id = add_sample_to_tables (
674
+ sample ,
675
+ tables ,
711
676
flags = tskit .NODE_IS_SAMPLE | core .NODE_IS_EXACT_MATCH ,
712
- time = 0 ,
713
- metadata = sample .metadata ,
714
677
)
715
678
parent = sample .path [0 ].parent
716
679
logger .debug (f"ARG add exact match { sample .strain } :{ node_id } ->{ parent } " )
@@ -843,23 +806,21 @@ def solve_num_mismatches(ts, k):
843
806
NOTE! This is NOT taking into account the spatial distance along
844
807
the genome, and so is not a very good model in some ways.
845
808
"""
809
+ # We can match against any node in tsinfer
846
810
m = ts .num_sites
847
- n = ts .num_nodes # We can match against any node in tsinfer
848
- if k == 0 :
849
- # Pathological things happen when k=0
850
- r = 1e-3
851
- mu = 1e-20
852
- else :
853
- # NOTE: the magnitude of mu matters because it puts a limit
854
- # on how low we can push the HMM precision. We should be able to solve
855
- # for the optimal value of this parameter such that the magnitude of the
856
- # values within the HMM are as large as possible (so that we can truncate
857
- # usefully).
858
- mu = 1e-3
859
- denom = (1 - mu ) ** k + (n - 1 ) * mu ** k
860
- r = n * mu ** k / denom
861
- assert mu < 0.5
862
- assert r < 0.5
811
+ n = ts .num_nodes
812
+ # values of k <= 1 are not relevant for SC2 and lead to awkward corner cases
813
+ assert k > 1
814
+
815
+ # NOTE: the magnitude of mu matters because it puts a limit
816
+ # on how low we can push the HMM precision. We should be able to solve
817
+ # for the optimal value of this parameter such that the magnitude of the
818
+ # values within the HMM are as large as possible (so that we can truncate
819
+ # usefully).
820
+ # mu = 1e-2
821
+ mu = 0.125
822
+ denom = (1 - mu ) ** k + (n - 1 ) * mu ** k
823
+ r = n * mu ** k / denom
863
824
864
825
# Add a little bit of extra mass for recombination so that we deterministically
865
826
# choose to recombine over k mutations
@@ -1312,6 +1273,8 @@ def match_tsinfer(
1312
1273
show_progress = False ,
1313
1274
mirror_coordinates = False ,
1314
1275
):
1276
+ if len (samples ) == 0 :
1277
+ return
1315
1278
genotypes = np .array ([sample .alignment for sample in samples ], dtype = np .int8 ).T
1316
1279
input_ts = ts
1317
1280
if mirror_coordinates :
@@ -1478,7 +1441,7 @@ def get_closest_mutation(node, site_id):
1478
1441
sample .mutations .append (
1479
1442
MatchMutation (
1480
1443
site_id = site_id ,
1481
- site_position = site_pos ,
1444
+ site_position = int ( site_pos ) ,
1482
1445
derived_state = derived_state ,
1483
1446
inherited_state = inherited_state ,
1484
1447
is_reversion = is_reversion ,
0 commit comments