Skip to content

Commit 485085c

Browse files
committed
Use FitnessFunction trait in FitnessProvider
1 parent 7f06d46 commit 485085c

File tree

1 file changed

+134
-48
lines changed

1 file changed

+134
-48
lines changed

src/providers/fitness.rs

Lines changed: 134 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,139 @@ use crate::references::HaplotypeRef;
1717
#[derive(Clone, Debug)]
1818
pub struct FitnessProvider<S: Symbol> {
1919
name: &'static str,
20-
function: FitnessFunction<S>,
20+
function: Box<dyn FitnessFunction<S>>,
2121
utility: UtilityFunction,
2222
}
2323

24+
/// A trait for fitness functions.
25+
pub trait FitnessFunction<S: Symbol>: Send + Sync + std::fmt::Debug {
26+
fn update_fitness(&self, haplotype: &HaplotypeRef<S>) -> f64;
27+
fn compute_fitness(&self, haplotype: &HaplotypeRef<S>) -> f64;
28+
fn write(&self, path: &Path, name: &'static str) -> Result<(), VirolutionError>;
29+
fn clone_box(&self) -> Box<dyn FitnessFunction<S>>;
30+
}
31+
32+
impl<S: Symbol> Clone for Box<dyn FitnessFunction<S>> {
33+
fn clone(&self) -> Self {
34+
self.clone_box()
35+
}
36+
}
37+
38+
#[derive(Clone, Debug)]
39+
pub struct NonEpistatic<S: Symbol> {
40+
table: FitnessTable,
41+
phantom: std::marker::PhantomData<S>,
42+
}
43+
44+
impl<S: Symbol> NonEpistatic<S> {
45+
pub fn new(table: FitnessTable) -> Self {
46+
Self {
47+
table,
48+
phantom: std::marker::PhantomData,
49+
}
50+
}
51+
}
52+
53+
impl<S: Symbol> FitnessFunction<S> for NonEpistatic<S> {
54+
fn update_fitness(&self, haplotype: &HaplotypeRef<S>) -> f64 {
55+
self.table.update_fitness(haplotype)
56+
}
57+
58+
fn compute_fitness(&self, haplotype: &HaplotypeRef<S>) -> f64 {
59+
self.table.compute_fitness(haplotype)
60+
}
61+
62+
fn write(&self, path: &Path, name: &'static str) -> Result<(), VirolutionError> {
63+
let table_name = format!("{}_table.npy", name);
64+
let mut table_file = io::BufWriter::new(fs::File::create(path.join(table_name)).unwrap());
65+
self.table.write(&mut table_file)?;
66+
Ok(())
67+
}
68+
69+
fn clone_box(&self) -> Box<dyn FitnessFunction<S>> {
70+
Box::new(self.clone())
71+
}
72+
}
73+
2474
#[derive(Clone, Debug)]
25-
pub enum FitnessFunction<S: Symbol> {
26-
NonEpistatic(FitnessTable),
27-
SimpleEpistatic(FitnessTable, EpistasisTable<S>),
75+
pub struct SimpleEpistatic<S: Symbol> {
76+
table: FitnessTable,
77+
epistasis: EpistasisTable<S>,
78+
}
79+
80+
impl<S: Symbol> SimpleEpistatic<S> {
81+
pub fn new(table: FitnessTable, epistasis: EpistasisTable<S>) -> Self {
82+
Self { table, epistasis }
83+
}
84+
}
85+
86+
impl<S: Symbol> FitnessFunction<S> for SimpleEpistatic<S> {
87+
fn update_fitness(&self, haplotype: &HaplotypeRef<S>) -> f64 {
88+
self.table.update_fitness(haplotype) * self.epistasis.update_fitness(haplotype)
89+
}
90+
91+
fn compute_fitness(&self, haplotype: &HaplotypeRef<S>) -> f64 {
92+
self.table.compute_fitness(haplotype) * self.epistasis.compute_fitness(haplotype)
93+
}
94+
95+
fn write(&self, path: &Path, name: &'static str) -> Result<(), VirolutionError> {
96+
let table_name = format!("{}_table.npy", name);
97+
let epistasis_name = format!("{}_epistasis_table.npy", name);
98+
99+
let table_file = fs::File::create(path.join(table_name)).unwrap();
100+
let epistasis_path = fs::File::create(path.join(epistasis_name)).unwrap();
101+
102+
let mut table_writer = io::BufWriter::new(table_file);
103+
let mut epistasis_writer = io::BufWriter::new(epistasis_path);
104+
105+
self.table.write(&mut table_writer)?;
106+
self.epistasis.write(&mut epistasis_writer)?;
107+
Ok(())
108+
}
109+
110+
fn clone_box(&self) -> Box<dyn FitnessFunction<S>> {
111+
Box::new(self.clone())
112+
}
113+
}
114+
115+
#[derive(Clone, Debug)]
116+
pub struct Neutral<S: Symbol> {
117+
marker: std::marker::PhantomData<S>,
118+
}
119+
120+
impl<S: Symbol> Neutral<S> {
121+
pub fn new() -> Self {
122+
Self {
123+
marker: std::marker::PhantomData,
124+
}
125+
}
126+
}
127+
128+
impl<S: Symbol> FitnessFunction<S> for Neutral<S> {
129+
fn update_fitness(&self, _haplotype: &HaplotypeRef<S>) -> f64 {
130+
1.0
131+
}
132+
133+
fn compute_fitness(&self, _haplotype: &HaplotypeRef<S>) -> f64 {
134+
1.0
135+
}
136+
137+
fn write(&self, _path: &Path, _name: &'static str) -> Result<(), VirolutionError> {
138+
Ok(())
139+
}
140+
141+
fn clone_box(&self) -> Box<dyn FitnessFunction<S>> {
142+
Box::new(self.clone())
143+
}
28144
}
29145

30146
impl<S: Symbol> FitnessProvider<S> {
31147
/// Create a new fitness provider.
32-
pub fn new(name: &'static str, function: FitnessFunction<S>, utility: UtilityFunction) -> Self {
148+
pub fn new(
149+
name: &'static str,
150+
function: Box<dyn FitnessFunction<S>>,
151+
utility: UtilityFunction,
152+
) -> Self {
33153
Self {
34154
name,
35155
function,
@@ -46,27 +166,24 @@ impl<S: Symbol> FitnessProvider<S> {
46166
sequence: &[S],
47167
model: &FitnessModel,
48168
) -> Result<Self, VirolutionError> {
49-
let function = match model.distribution {
50-
FitnessDistribution::Neutral => {
51-
let table = FitnessTable::from_model(sequence, model)?;
52-
FitnessFunction::NonEpistatic(table)
53-
}
169+
let function: Box<dyn FitnessFunction<S>> = match model.distribution {
170+
FitnessDistribution::Neutral => Box::new(Neutral::new()),
54171
FitnessDistribution::Exponential(_) => {
55172
let table = FitnessTable::from_model(sequence, model)?;
56-
FitnessFunction::NonEpistatic(table)
173+
Box::new(NonEpistatic::new(table))
57174
}
58175
FitnessDistribution::Lognormal(_) => {
59176
let table = FitnessTable::from_model(sequence, model)?;
60-
FitnessFunction::NonEpistatic(table)
177+
Box::new(NonEpistatic::new(table))
61178
}
62179
FitnessDistribution::File(_) => {
63180
let table = FitnessTable::from_model(sequence, model)?;
64-
FitnessFunction::NonEpistatic(table)
181+
Box::new(NonEpistatic::new(table))
65182
}
66183
FitnessDistribution::Epistatic(_) => {
67184
let table = FitnessTable::from_model(sequence, model)?;
68185
let epistasis = EpistasisTable::from_model(model)?;
69-
FitnessFunction::SimpleEpistatic(table, epistasis)
186+
Box::new(SimpleEpistatic::new(table, epistasis))
70187
}
71188
};
72189
Ok(Self {
@@ -99,21 +216,11 @@ impl<S: Symbol> FitnessProvider<S> {
99216
}
100217

101218
fn update_fitness(&self, haplotype: &HaplotypeRef<S>) -> f64 {
102-
match &self.function {
103-
FitnessFunction::NonEpistatic(table) => table.update_fitness(haplotype),
104-
FitnessFunction::SimpleEpistatic(table, epistasis) => {
105-
table.update_fitness(haplotype) * epistasis.update_fitness(haplotype)
106-
}
107-
}
219+
self.function.update_fitness(haplotype)
108220
}
109221

110222
fn compute_fitness(&self, haplotype: &HaplotypeRef<S>) -> f64 {
111-
match &self.function {
112-
FitnessFunction::NonEpistatic(table) => table.compute_fitness(haplotype),
113-
FitnessFunction::SimpleEpistatic(table, epistasis) => {
114-
table.compute_fitness(haplotype) * epistasis.compute_fitness(haplotype)
115-
}
116-
}
223+
self.function.compute_fitness(haplotype)
117224
}
118225
}
119226

@@ -141,27 +248,6 @@ impl<S: Symbol> AttributeProvider<S> for FitnessProvider<S> {
141248
}
142249

143250
fn write(&self, path: &Path) -> Result<(), VirolutionError> {
144-
match &self.function {
145-
FitnessFunction::NonEpistatic(table) => {
146-
let table_name = format!("{}_table.npy", self.name);
147-
let mut table_file =
148-
io::BufWriter::new(fs::File::create(path.join(table_name)).unwrap());
149-
table.write(&mut table_file)?;
150-
}
151-
FitnessFunction::SimpleEpistatic(table, epistasis) => {
152-
let table_name = format!("{}_table.npy", self.name);
153-
let epistasis_name = format!("{}_epistasis_table.npy", self.name);
154-
155-
let table_file = fs::File::create(path.join(table_name)).unwrap();
156-
let epistasis_path = fs::File::create(path.join(epistasis_name)).unwrap();
157-
158-
let mut table_writer = io::BufWriter::new(table_file);
159-
let mut epistasis_writer = io::BufWriter::new(epistasis_path);
160-
161-
table.write(&mut table_writer)?;
162-
epistasis.write(&mut epistasis_writer)?;
163-
}
164-
}
165-
Ok(())
251+
self.function.write(path, self.name)
166252
}
167253
}

0 commit comments

Comments
 (0)