Skip to content

Commit fbed3c1

Browse files
committed
Added SmallVector crate for small vector optimization
1 parent fdadd61 commit fbed3c1

File tree

6 files changed

+95
-50
lines changed

6 files changed

+95
-50
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ codegen-units = 1
2424
members = ["macros"]
2525

2626
[dependencies]
27+
anyhow = "1.0"
2728
block-id = "0.1.2"
2829
cached = "0.53.1"
2930
clap = { version = "4.3.11", features = ["derive"] }
@@ -44,7 +45,7 @@ serde = { version = "1.0.171", features = ["derive"] }
4445
serde_yaml = "0.9.22"
4546
seq_io = "0.3.1"
4647
simple-logging = "2.0.2"
47-
anyhow = "1.0"
48+
smallvec = "1.13.2"
4849
rayon = { version = "1.7.0", optional = true }
4950
macros = { path = "macros" }
5051

src/core/fitness/epistasis.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use npyz::WriterBuilder;
2+
use smallvec::SmallVec;
23
use std::collections::HashMap;
34

5+
use crate::core::haplotype::Changes;
46
use crate::encoding::Symbol;
57
use crate::errors::VirolutionError;
68
use crate::references::HaplotypeRef;
@@ -118,11 +120,12 @@ impl<S: Symbol> EpistasisTable<S> {
118120
// have to call `get_mutations` on the haplotype which can be expensive.
119121
let candidate_changes = changes
120122
.iter()
121-
.filter(|(position, (old, new))| {
122-
self.table.contains_key(&(**position, *old))
123-
|| self.table.contains_key(&(**position, *new))
123+
.filter(|changes| {
124+
self.table.contains_key(&(changes.position, changes.to))
125+
|| self.table.contains_key(&(changes.position, changes.to))
124126
})
125-
.collect::<HashMap<_, _>>();
127+
.cloned()
128+
.collect::<SmallVec<Changes<S>>>();
126129

127130
// If no changes are present in the table, return the fitness
128131
if candidate_changes.is_empty() {
@@ -134,21 +137,21 @@ impl<S: Symbol> EpistasisTable<S> {
134137
let mut fitness = 1.;
135138

136139
// Add any epistatic effects
137-
for (position, (old, new)) in &candidate_changes {
140+
for change in &candidate_changes {
138141
// Get any interactions of the current mutation
139-
let interactions_add = self.table.get(&(**position, *new));
140-
let interactions_remove = self.table.get(&(**position, *old));
142+
let interactions_add = self.table.get(&(change.position, change.to));
143+
let interactions_remove = self.table.get(&(change.position, change.from));
141144

142145
// Apply any interactions
143146
for (pos, current) in mutations.iter() {
144147
// Deal with a rare case where two mutations have interactions with each
145148
// other. In this case, we enforce that the interaction is only applied
146149
// once, by enforcing an order when reading epistatic interactions.
147-
if changes.contains_key(pos) && pos <= position {
150+
if pos <= &change.position && changes.iter().any(|c| c.position == *pos) {
148151
continue;
149152
}
150153

151-
if pos != *position {
154+
if pos != &change.position {
152155
if let Some(interaction) = interactions_add {
153156
// If there is an interaction, multiply the fitness
154157
if let Some(v) = interaction.get(&(*pos, S::decode(current))) {

src/core/fitness/table.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,9 @@ impl FitnessTable {
6262
};
6363

6464
// Compute the fitness update
65-
changes.iter().fold(1., |acc, (position, (old, new))| {
66-
acc * self.get_value(position, new) / self.get_value(position, old)
65+
changes.iter().fold(1., |acc, change| {
66+
acc * self.get_value(&change.position, &change.to)
67+
/ self.get_value(&change.position, &change.from)
6768
})
6869
}
6970

src/core/haplotype.rs

Lines changed: 76 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
2121
use derivative::Derivative;
2222
use seq_io::fasta::OwnedRecord;
23+
use smallvec::SmallVec;
2324
use std::cell::Cell;
2425
use std::collections::HashMap;
2526
use std::fmt;
@@ -62,6 +63,17 @@ fn make_fitness_cache() -> Vec<OnceLock<f64>> {
6263
// #[derive(Clone, Debug, Deref)]
6364
// pub type Symbol = Option<u8>;
6465

66+
const SMALL_VEC_SIZE: usize = 1;
67+
68+
#[derive(Debug, Clone)]
69+
pub struct Change<S: Symbol> {
70+
pub position: usize,
71+
pub from: S,
72+
pub to: S,
73+
}
74+
75+
pub type Changes<S> = [Change<S>; SMALL_VEC_SIZE];
76+
6577
#[derive(Debug)]
6678
pub enum Haplotype<S: Symbol> {
6779
Wildtype(Wildtype<S>),
@@ -71,44 +83,65 @@ pub enum Haplotype<S: Symbol> {
7183

7284
#[derive(Debug)]
7385
pub struct Wildtype<S: Symbol> {
86+
// head
7487
reference: HaplotypeWeak<S>,
75-
sequence: Vec<S>,
7688
descendants: DescendantsCell<S>,
89+
90+
// body
91+
sequence: Vec<S>,
92+
93+
// sync
7794
// number of descendants that have died, we can replace their weak references
7895
_dirty_descendants: AtomicIsize,
7996
}
8097

8198
#[derive(Derivative)]
8299
#[derivative(Debug)]
83100
pub struct Mutant<S: Symbol> {
101+
// head
84102
reference: HaplotypeWeak<S>,
85103
wildtype: HaplotypeWeak<S>,
86104
ancestor: HaplotypeRef<S>,
87-
changes: HashMap<usize, (S, S)>,
105+
descendants: DescendantsCell<S>,
106+
107+
// body
88108
generation: usize,
109+
changes: SmallVec<Changes<S>>,
89110
fitness: Vec<OnceLock<f64>>,
90-
descendants: DescendantsCell<S>,
91-
// number of descendants that have died, we can replace their weak references
111+
112+
// sync
113+
// stores the number of descendants that have died
114+
// this allows us to replace any weak references instead of allocating more memory
92115
_dirty_descendants: AtomicIsize,
93116
// synchronization for merging nodes while allowing for parallel access
94-
// ask to defer drop
117+
// this field will be used to request deferred drops, when it is non-zero, merges will not
118+
// drop but instead defer the drop to the setter
119+
// any thread that requires a deferred drop is responsible for decrementing this field before
120+
// dropping the reference (safe handling is ensured by the `require_deferred_drop` attribute)
95121
_defer_drop: Arc<Mutex<usize>>,
96-
// request deferred drop
122+
// this field will be used to create a self-reference to the haplotype after it has been
123+
// removed from the tree, this allows us to safely drop the haplotype when it is no longer used
124+
// by consuming the reference when it is no longer used
97125
#[derivative(Debug = "ignore")]
98126
_drop: Cell<Option<HaplotypeRef<S>>>,
99127
}
100128

101129
#[derive(Debug)]
102130
pub struct Recombinant<S: Symbol> {
131+
// head
103132
reference: HaplotypeWeak<S>,
104133
wildtype: HaplotypeWeak<S>,
105134
left_ancestor: HaplotypeRef<S>,
106135
right_ancestor: HaplotypeRef<S>,
136+
descendants: DescendantsCell<S>,
137+
138+
// body
107139
left_position: usize,
108140
right_position: usize,
109141
generation: usize,
110142
fitness: Vec<OnceLock<f64>>,
111-
descendants: DescendantsCell<S>,
143+
144+
// sync
112145
// number of descendants that have died, we can replace their weak references
113146
_dirty_descendants: AtomicIsize,
114147
}
@@ -148,18 +181,15 @@ impl<S: Symbol> Haplotype<S> {
148181
let ancestor = self.get_reference();
149182
let wildtype = self.get_wildtype();
150183

151-
let changes = HashMap::from_iter(
152-
positions
153-
.iter()
154-
.zip(changes.iter())
155-
.map(|(pos, sym)| (*pos, (self.get_base(pos), *sym))),
156-
);
157-
158-
changes.iter().for_each(|(pos, change)| {
159-
if change.0 == change.1 {
160-
dbg!(pos, change);
161-
}
162-
});
184+
let changes = positions
185+
.iter()
186+
.zip(changes.iter())
187+
.map(|(&position, &to)| {
188+
let from = self.get_base(&position);
189+
assert_ne!(from, to);
190+
Change { position, from, to }
191+
})
192+
.collect();
163193

164194
let descendant = Mutant::new(ancestor, wildtype.get_weak(), changes, generation);
165195

@@ -207,7 +237,7 @@ impl<S: Symbol> Haplotype<S> {
207237
}
208238

209239
/// Returns a reference to the changes that are present in the haplotype if the type allows it.
210-
pub fn try_get_changes(&self) -> Option<&HashMap<usize, (S, S)>> {
240+
pub fn try_get_changes(&self) -> Option<&SmallVec<Changes<S>>> {
211241
match self {
212242
Haplotype::Mutant(ht) => Some(&ht.changes),
213243
_ => None,
@@ -452,15 +482,15 @@ impl<S: Symbol> Mutant<S> {
452482
pub fn new(
453483
ancestor: HaplotypeRef<S>,
454484
wildtype: HaplotypeWeak<S>,
455-
changes: HashMap<usize, (S, S)>,
485+
changes: SmallVec<Changes<S>>,
456486
generation: usize,
457487
) -> HaplotypeRef<S> {
458488
HaplotypeRef::new_cyclic(|reference| {
459489
Haplotype::Mutant(Self {
460490
reference: reference.clone(),
461491
wildtype: wildtype.clone(),
462492
ancestor: ancestor.clone(),
463-
changes: changes.clone(),
493+
changes,
464494
generation,
465495
fitness: make_fitness_cache(),
466496
descendants: DescendantsCell::new(),
@@ -478,7 +508,7 @@ impl<S: Symbol> Mutant<S> {
478508
last: &Self,
479509
ancestor: HaplotypeRef<S>,
480510
wildtype: HaplotypeWeak<S>,
481-
changes: HashMap<usize, (S, S)>,
511+
changes: SmallVec<Changes<S>>,
482512
generation: usize,
483513
) {
484514
// collect all descendants that are still alive
@@ -525,13 +555,13 @@ impl<S: Symbol> Mutant<S> {
525555

526556
#[require_deferred_drop]
527557
pub fn get_base(&self, position: &usize) -> S {
528-
match self.changes.get(position) {
529-
Some((_from, to)) => *to,
558+
match self.changes.iter().find(|x| x.position == *position) {
559+
Some(change) => change.to,
530560
None => self.ancestor.get_base(position),
531561
}
532562
}
533563

534-
pub fn iter_changes(&self) -> impl Iterator<Item = (&usize, &(S, S))> + '_ {
564+
pub fn iter_changes(&self) -> impl Iterator<Item = &Change<S>> + '_ {
535565
self.changes.iter()
536566
}
537567

@@ -547,7 +577,7 @@ impl<S: Symbol> Mutant<S> {
547577
self.ancestor.clone()
548578
}
549579

550-
pub fn get_changes(&self) -> &HashMap<usize, (S, S)> {
580+
pub fn get_changes(&self) -> &SmallVec<Changes<S>> {
551581
&self.changes
552582
}
553583

@@ -561,12 +591,12 @@ impl<S: Symbol> Mutant<S> {
561591

562592
assert_eq!(&self.changes as *const _, key);
563593

564-
self.changes.iter().for_each(|(position, (_, new))| {
565-
let wt_base = wt_ref.get_base(position);
566-
if *new == wt_base {
567-
mutations.remove(position);
594+
self.changes.iter().for_each(|change| {
595+
let wt_base = wt_ref.get_base(&change.position);
596+
if change.to == wt_base {
597+
mutations.remove(&change.position);
568598
} else {
569-
mutations.insert(*position, *new.index() as u8);
599+
mutations.insert(change.position, *change.to.index() as u8);
570600
}
571601
});
572602

@@ -636,11 +666,11 @@ impl<S: Symbol> Mutant<S> {
636666
let merger: [&Mutant<S>; 2] = [descendant_inner, self];
637667

638668
// aggregate changes
639-
let changes: HashMap<usize, (S, S)> = merger
669+
let changes: SmallVec<Changes<S>> = merger
640670
.iter()
641671
.rev()
642672
.flat_map(|x| x.changes.iter())
643-
.map(|(position, change)| (*position, *change))
673+
.cloned()
644674
.collect();
645675

646676
// determine generation
@@ -852,11 +882,20 @@ mod tests {
852882
let symbols = vec![Nt::A, Nt::T, Nt::C, Nt::G];
853883
let wt = Wildtype::new(symbols.clone());
854884
let _hts: Vec<HaplotypeRef<Nt>> = (0..100)
855-
.map(|i| wt.create_descendant(vec![0], vec![Nt::decode(&((i % Nt::SIZE) as u8))], 0))
885+
.map(|i| {
886+
wt.create_descendant(
887+
vec![0],
888+
vec![Nt::decode(&(((i % (Nt::SIZE - 1)) + 1) as u8))],
889+
0,
890+
)
891+
})
856892
.collect();
857-
for (position, descendant) in wt.get_descendants().lock().iter().enumerate() {
893+
for (i, descendant) in wt.get_descendants().lock().iter().enumerate() {
858894
if let Some(d) = descendant.upgrade() {
859-
assert_eq!(d.get_base(&0), Nt::decode(&((position % Nt::SIZE) as u8)));
895+
assert_eq!(
896+
d.get_base(&0),
897+
Nt::decode(&(((i % (Nt::SIZE - 1)) + 1) as u8))
898+
);
860899
} else {
861900
panic!();
862901
}

src/references/cell.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ impl<S: Symbol> HaplotypeRef<S> {
2525

2626
pub fn new_cyclic<F>(data_fn: F) -> Self
2727
where
28-
F: std::ops::Fn(&HaplotypeWeak<S>) -> Haplotype<S>,
28+
F: std::ops::FnOnce(&HaplotypeWeak<S>) -> Haplotype<S>,
2929
{
3030
Self(Rc::new_cyclic(|weak| data_fn(&HaplotypeWeak(weak.clone()))))
3131
}

0 commit comments

Comments
 (0)