Skip to content

Commit 09a50f0

Browse files
committed
Add ability to load epistatic effects
1 parent e986e3e commit 09a50f0

File tree

4 files changed

+43
-11
lines changed

4 files changed

+43
-11
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ evalexpr = "8.2.0"
2727
indicatif = "0.17.5"
2828
itertools = "0.10.5"
2929
log = "0.4.19"
30-
npyz = "0.7.4"
30+
npyz = { version = "0.8.3", features = ["derive"] }
3131
parking_lot = "0.12.1"
3232
phf = { version = "0.11.2", features = ["macros"] }
3333
rand = { version = "0.8.5", features = ["alloc"] }

src/core/fitness/epistasis.rs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,51 @@ use crate::core::haplotype::Symbol;
44
use crate::errors::VirolutionError;
55
use crate::references::HaplotypeRef;
66

7-
use super::init::FitnessModel;
7+
use super::init::{FitnessDistribution, FitnessModel};
88

99
type EpistasisTableKey = (usize, Symbol);
1010

11+
#[derive(npyz::Deserialize)]
12+
pub struct EpiEntry {
13+
pos1: u64,
14+
base1: u8,
15+
pos2: u64,
16+
base2: u8,
17+
value: f64,
18+
}
19+
1120
#[derive(Clone, Debug)]
1221
pub struct EpistasisTable {
1322
table: HashMap<EpistasisTableKey, HashMap<EpistasisTableKey, f64>>,
1423
}
1524

1625
impl EpistasisTable {
1726
pub fn from_model(model: &FitnessModel) -> Result<Self, VirolutionError> {
18-
todo!()
27+
let table = match &model.distribution {
28+
FitnessDistribution::Epistatic(_, epi_params) => {
29+
let entries = epi_params.load_table();
30+
Self::from_vec(entries)
31+
}
32+
_ => {
33+
return Err(VirolutionError::ImplementationError(
34+
"Model is not epistatic".to_string(),
35+
));
36+
}
37+
};
38+
Ok(table)
1939
}
2040

21-
pub fn from_vec(entries: Vec<(usize, Symbol, usize, Symbol, f64)>) -> Self {
41+
pub fn from_vec(entries: Vec<EpiEntry>) -> Self {
2242
let mut table: HashMap<EpistasisTableKey, HashMap<EpistasisTableKey, f64>> = HashMap::new();
23-
for (pos1, base1, pos2, base2, value) in entries.iter() {
43+
for entry in entries.iter() {
2444
table
25-
.entry((*pos1, *base1))
45+
.entry((entry.pos1 as usize, Some(entry.base1)))
2646
.or_default()
27-
.insert((*pos2, *base2), *value);
47+
.insert((entry.pos2 as usize, Some(entry.base2)), entry.value);
2848
table
29-
.entry((*pos2, *base2))
49+
.entry((entry.pos2 as usize, Some(entry.base2)))
3050
.or_default()
31-
.insert((*pos1, *base1), *value);
51+
.insert((entry.pos1 as usize, Some(entry.base1)), entry.value);
3252
}
3353
Self { table }
3454
}

src/core/fitness/init.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use serde::{Deserialize, Serialize};
55

66
use crate::core::haplotype::Symbol;
77

8+
use super::epistasis::EpiEntry;
89
use super::utility::UtilityFunction;
910

1011
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
@@ -167,3 +168,14 @@ impl FileParameters {
167168
.collect()
168169
}
169170
}
171+
172+
impl EpiFileParameters {
173+
pub fn load_table(&self) -> Vec<EpiEntry> {
174+
let reader = NpyFile::new(std::fs::File::open(&self.path).unwrap()).unwrap();
175+
reader
176+
.data::<EpiEntry>()
177+
.unwrap()
178+
.map(|entry| entry.unwrap())
179+
.collect()
180+
}
181+
}

src/core/fitness/table.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ impl FitnessTable {
2929
) -> Result<Self, VirolutionError> {
3030
let n_sites = sequence.len();
3131
let table = match fitness_model.distribution {
32+
FitnessDistribution::Neutral => vec![1.; n_sites * n_symbols],
3233
FitnessDistribution::Exponential(ref params) => {
3334
params.create_table(n_symbols, sequence)
3435
}
3536
FitnessDistribution::Lognormal(ref params) => params.create_table(n_symbols, sequence),
3637
FitnessDistribution::File(ref params) => params.load_table(),
37-
FitnessDistribution::Neutral => vec![1.; n_sites * n_symbols],
38-
FitnessDistribution::Epistatic(_, _) => todo!(),
38+
FitnessDistribution::Epistatic(ref params, _) => params.load_table(),
3939
};
4040

4141
if table.len() != n_sites * n_symbols {

0 commit comments

Comments
 (0)