20
20
21
21
use derivative:: Derivative ;
22
22
use seq_io:: fasta:: OwnedRecord ;
23
+ use smallvec:: SmallVec ;
23
24
use std:: cell:: Cell ;
24
25
use std:: collections:: HashMap ;
25
26
use std:: fmt;
@@ -62,6 +63,17 @@ fn make_fitness_cache() -> Vec<OnceLock<f64>> {
62
63
// #[derive(Clone, Debug, Deref)]
63
64
// pub type Symbol = Option<u8>;
64
65
66
+ const SMALL_VEC_SIZE : usize = 1 ;
67
+
68
+ #[ derive( Debug , Clone ) ]
69
+ pub struct Change < S : Symbol > {
70
+ pub position : usize ,
71
+ pub from : S ,
72
+ pub to : S ,
73
+ }
74
+
75
+ pub type Changes < S > = [ Change < S > ; SMALL_VEC_SIZE ] ;
76
+
65
77
#[ derive( Debug ) ]
66
78
pub enum Haplotype < S : Symbol > {
67
79
Wildtype ( Wildtype < S > ) ,
@@ -71,44 +83,65 @@ pub enum Haplotype<S: Symbol> {
71
83
72
84
#[ derive( Debug ) ]
73
85
pub struct Wildtype < S : Symbol > {
86
+ // head
74
87
reference : HaplotypeWeak < S > ,
75
- sequence : Vec < S > ,
76
88
descendants : DescendantsCell < S > ,
89
+
90
+ // body
91
+ sequence : Vec < S > ,
92
+
93
+ // sync
77
94
// number of descendants that have died, we can replace their weak references
78
95
_dirty_descendants : AtomicIsize ,
79
96
}
80
97
81
98
#[ derive( Derivative ) ]
82
99
#[ derivative( Debug ) ]
83
100
pub struct Mutant < S : Symbol > {
101
+ // head
84
102
reference : HaplotypeWeak < S > ,
85
103
wildtype : HaplotypeWeak < S > ,
86
104
ancestor : HaplotypeRef < S > ,
87
- changes : HashMap < usize , ( S , S ) > ,
105
+ descendants : DescendantsCell < S > ,
106
+
107
+ // body
88
108
generation : usize ,
109
+ changes : SmallVec < Changes < S > > ,
89
110
fitness : Vec < OnceLock < f64 > > ,
90
- descendants : DescendantsCell < S > ,
91
- // number of descendants that have died, we can replace their weak references
111
+
112
+ // sync
113
+ // stores the number of descendants that have died
114
+ // this allows us to replace any weak references instead of allocating more memory
92
115
_dirty_descendants : AtomicIsize ,
93
116
// synchronization for merging nodes while allowing for parallel access
94
- // ask to defer drop
117
+ // this field will be used to request deferred drops, when it is non-zero, merges will not
118
+ // drop but instead defer the drop to the setter
119
+ // any thread that requires a deferred drop is responsible for decrementing this field before
120
+ // dropping the reference (safe handling is ensured by the `require_deferred_drop` attribute)
95
121
_defer_drop : Arc < Mutex < usize > > ,
96
- // request deferred drop
122
+ // this field will be used to create a self-reference to the haplotype after it has been
123
+ // removed from the tree, this allows us to safely drop the haplotype when it is no longer used
124
+ // by consuming the reference when it is no longer used
97
125
#[ derivative( Debug = "ignore" ) ]
98
126
_drop : Cell < Option < HaplotypeRef < S > > > ,
99
127
}
100
128
101
129
#[ derive( Debug ) ]
102
130
pub struct Recombinant < S : Symbol > {
131
+ // head
103
132
reference : HaplotypeWeak < S > ,
104
133
wildtype : HaplotypeWeak < S > ,
105
134
left_ancestor : HaplotypeRef < S > ,
106
135
right_ancestor : HaplotypeRef < S > ,
136
+ descendants : DescendantsCell < S > ,
137
+
138
+ // body
107
139
left_position : usize ,
108
140
right_position : usize ,
109
141
generation : usize ,
110
142
fitness : Vec < OnceLock < f64 > > ,
111
- descendants : DescendantsCell < S > ,
143
+
144
+ // sync
112
145
// number of descendants that have died, we can replace their weak references
113
146
_dirty_descendants : AtomicIsize ,
114
147
}
@@ -148,18 +181,15 @@ impl<S: Symbol> Haplotype<S> {
148
181
let ancestor = self . get_reference ( ) ;
149
182
let wildtype = self . get_wildtype ( ) ;
150
183
151
- let changes = HashMap :: from_iter (
152
- positions
153
- . iter ( )
154
- . zip ( changes. iter ( ) )
155
- . map ( |( pos, sym) | ( * pos, ( self . get_base ( pos) , * sym) ) ) ,
156
- ) ;
157
-
158
- changes. iter ( ) . for_each ( |( pos, change) | {
159
- if change. 0 == change. 1 {
160
- dbg ! ( pos, change) ;
161
- }
162
- } ) ;
184
+ let changes = positions
185
+ . iter ( )
186
+ . zip ( changes. iter ( ) )
187
+ . map ( |( & position, & to) | {
188
+ let from = self . get_base ( & position) ;
189
+ assert_ne ! ( from, to) ;
190
+ Change { position, from, to }
191
+ } )
192
+ . collect ( ) ;
163
193
164
194
let descendant = Mutant :: new ( ancestor, wildtype. get_weak ( ) , changes, generation) ;
165
195
@@ -207,7 +237,7 @@ impl<S: Symbol> Haplotype<S> {
207
237
}
208
238
209
239
/// Returns a reference to the changes that are present in the haplotype if the type allows it.
210
- pub fn try_get_changes ( & self ) -> Option < & HashMap < usize , ( S , S ) > > {
240
+ pub fn try_get_changes ( & self ) -> Option < & SmallVec < Changes < S > > > {
211
241
match self {
212
242
Haplotype :: Mutant ( ht) => Some ( & ht. changes ) ,
213
243
_ => None ,
@@ -452,15 +482,15 @@ impl<S: Symbol> Mutant<S> {
452
482
pub fn new (
453
483
ancestor : HaplotypeRef < S > ,
454
484
wildtype : HaplotypeWeak < S > ,
455
- changes : HashMap < usize , ( S , S ) > ,
485
+ changes : SmallVec < Changes < S > > ,
456
486
generation : usize ,
457
487
) -> HaplotypeRef < S > {
458
488
HaplotypeRef :: new_cyclic ( |reference| {
459
489
Haplotype :: Mutant ( Self {
460
490
reference : reference. clone ( ) ,
461
491
wildtype : wildtype. clone ( ) ,
462
492
ancestor : ancestor. clone ( ) ,
463
- changes : changes . clone ( ) ,
493
+ changes,
464
494
generation,
465
495
fitness : make_fitness_cache ( ) ,
466
496
descendants : DescendantsCell :: new ( ) ,
@@ -478,7 +508,7 @@ impl<S: Symbol> Mutant<S> {
478
508
last : & Self ,
479
509
ancestor : HaplotypeRef < S > ,
480
510
wildtype : HaplotypeWeak < S > ,
481
- changes : HashMap < usize , ( S , S ) > ,
511
+ changes : SmallVec < Changes < S > > ,
482
512
generation : usize ,
483
513
) {
484
514
// collect all descendants that are still alive
@@ -525,13 +555,13 @@ impl<S: Symbol> Mutant<S> {
525
555
526
556
#[ require_deferred_drop]
527
557
pub fn get_base ( & self , position : & usize ) -> S {
528
- match self . changes . get ( position) {
529
- Some ( ( _from , to ) ) => * to,
558
+ match self . changes . iter ( ) . find ( |x| x . position == * position) {
559
+ Some ( change ) => change . to ,
530
560
None => self . ancestor . get_base ( position) ,
531
561
}
532
562
}
533
563
534
- pub fn iter_changes ( & self ) -> impl Iterator < Item = ( & usize , & ( S , S ) ) > + ' _ {
564
+ pub fn iter_changes ( & self ) -> impl Iterator < Item = & Change < S > > + ' _ {
535
565
self . changes . iter ( )
536
566
}
537
567
@@ -547,7 +577,7 @@ impl<S: Symbol> Mutant<S> {
547
577
self . ancestor . clone ( )
548
578
}
549
579
550
- pub fn get_changes ( & self ) -> & HashMap < usize , ( S , S ) > {
580
+ pub fn get_changes ( & self ) -> & SmallVec < Changes < S > > {
551
581
& self . changes
552
582
}
553
583
@@ -561,12 +591,12 @@ impl<S: Symbol> Mutant<S> {
561
591
562
592
assert_eq ! ( & self . changes as * const _, key) ;
563
593
564
- self . changes . iter ( ) . for_each ( |( position , ( _ , new ) ) | {
565
- let wt_base = wt_ref. get_base ( position) ;
566
- if * new == wt_base {
567
- mutations. remove ( position) ;
594
+ self . changes . iter ( ) . for_each ( |change | {
595
+ let wt_base = wt_ref. get_base ( & change . position ) ;
596
+ if change . to == wt_base {
597
+ mutations. remove ( & change . position ) ;
568
598
} else {
569
- mutations. insert ( * position, * new . index ( ) as u8 ) ;
599
+ mutations. insert ( change . position , * change . to . index ( ) as u8 ) ;
570
600
}
571
601
} ) ;
572
602
@@ -636,11 +666,11 @@ impl<S: Symbol> Mutant<S> {
636
666
let merger: [ & Mutant < S > ; 2 ] = [ descendant_inner, self ] ;
637
667
638
668
// aggregate changes
639
- let changes: HashMap < usize , ( S , S ) > = merger
669
+ let changes: SmallVec < Changes < S > > = merger
640
670
. iter ( )
641
671
. rev ( )
642
672
. flat_map ( |x| x. changes . iter ( ) )
643
- . map ( | ( position , change ) | ( * position , * change ) )
673
+ . cloned ( )
644
674
. collect ( ) ;
645
675
646
676
// determine generation
@@ -852,11 +882,20 @@ mod tests {
852
882
let symbols = vec ! [ Nt :: A , Nt :: T , Nt :: C , Nt :: G ] ;
853
883
let wt = Wildtype :: new ( symbols. clone ( ) ) ;
854
884
let _hts: Vec < HaplotypeRef < Nt > > = ( 0 ..100 )
855
- . map ( |i| wt. create_descendant ( vec ! [ 0 ] , vec ! [ Nt :: decode( & ( ( i % Nt :: SIZE ) as u8 ) ) ] , 0 ) )
885
+ . map ( |i| {
886
+ wt. create_descendant (
887
+ vec ! [ 0 ] ,
888
+ vec ! [ Nt :: decode( & ( ( ( i % ( Nt :: SIZE - 1 ) ) + 1 ) as u8 ) ) ] ,
889
+ 0 ,
890
+ )
891
+ } )
856
892
. collect ( ) ;
857
- for ( position , descendant) in wt. get_descendants ( ) . lock ( ) . iter ( ) . enumerate ( ) {
893
+ for ( i , descendant) in wt. get_descendants ( ) . lock ( ) . iter ( ) . enumerate ( ) {
858
894
if let Some ( d) = descendant. upgrade ( ) {
859
- assert_eq ! ( d. get_base( & 0 ) , Nt :: decode( & ( ( position % Nt :: SIZE ) as u8 ) ) ) ;
895
+ assert_eq ! (
896
+ d. get_base( & 0 ) ,
897
+ Nt :: decode( & ( ( ( i % ( Nt :: SIZE - 1 ) ) + 1 ) as u8 ) )
898
+ ) ;
860
899
} else {
861
900
panic ! ( ) ;
862
901
}
0 commit comments