@@ -17,19 +17,139 @@ use crate::references::HaplotypeRef;
17
17
#[ derive( Clone , Debug ) ]
18
18
pub struct FitnessProvider < S : Symbol > {
19
19
name : & ' static str ,
20
- function : FitnessFunction < S > ,
20
+ function : Box < dyn FitnessFunction < S > > ,
21
21
utility : UtilityFunction ,
22
22
}
23
23
24
+ /// A trait for fitness functions.
25
+ pub trait FitnessFunction < S : Symbol > : Send + Sync + std:: fmt:: Debug {
26
+ fn update_fitness ( & self , haplotype : & HaplotypeRef < S > ) -> f64 ;
27
+ fn compute_fitness ( & self , haplotype : & HaplotypeRef < S > ) -> f64 ;
28
+ fn write ( & self , path : & Path , name : & ' static str ) -> Result < ( ) , VirolutionError > ;
29
+ fn clone_box ( & self ) -> Box < dyn FitnessFunction < S > > ;
30
+ }
31
+
32
+ impl < S : Symbol > Clone for Box < dyn FitnessFunction < S > > {
33
+ fn clone ( & self ) -> Self {
34
+ self . clone_box ( )
35
+ }
36
+ }
37
+
38
+ #[ derive( Clone , Debug ) ]
39
+ pub struct NonEpistatic < S : Symbol > {
40
+ table : FitnessTable ,
41
+ phantom : std:: marker:: PhantomData < S > ,
42
+ }
43
+
44
+ impl < S : Symbol > NonEpistatic < S > {
45
+ pub fn new ( table : FitnessTable ) -> Self {
46
+ Self {
47
+ table,
48
+ phantom : std:: marker:: PhantomData ,
49
+ }
50
+ }
51
+ }
52
+
53
+ impl < S : Symbol > FitnessFunction < S > for NonEpistatic < S > {
54
+ fn update_fitness ( & self , haplotype : & HaplotypeRef < S > ) -> f64 {
55
+ self . table . update_fitness ( haplotype)
56
+ }
57
+
58
+ fn compute_fitness ( & self , haplotype : & HaplotypeRef < S > ) -> f64 {
59
+ self . table . compute_fitness ( haplotype)
60
+ }
61
+
62
+ fn write ( & self , path : & Path , name : & ' static str ) -> Result < ( ) , VirolutionError > {
63
+ let table_name = format ! ( "{}_table.npy" , name) ;
64
+ let mut table_file = io:: BufWriter :: new ( fs:: File :: create ( path. join ( table_name) ) . unwrap ( ) ) ;
65
+ self . table . write ( & mut table_file) ?;
66
+ Ok ( ( ) )
67
+ }
68
+
69
+ fn clone_box ( & self ) -> Box < dyn FitnessFunction < S > > {
70
+ Box :: new ( self . clone ( ) )
71
+ }
72
+ }
73
+
24
74
#[ derive( Clone , Debug ) ]
25
- pub enum FitnessFunction < S : Symbol > {
26
- NonEpistatic ( FitnessTable ) ,
27
- SimpleEpistatic ( FitnessTable , EpistasisTable < S > ) ,
75
+ pub struct SimpleEpistatic < S : Symbol > {
76
+ table : FitnessTable ,
77
+ epistasis : EpistasisTable < S > ,
78
+ }
79
+
80
+ impl < S : Symbol > SimpleEpistatic < S > {
81
+ pub fn new ( table : FitnessTable , epistasis : EpistasisTable < S > ) -> Self {
82
+ Self { table, epistasis }
83
+ }
84
+ }
85
+
86
+ impl < S : Symbol > FitnessFunction < S > for SimpleEpistatic < S > {
87
+ fn update_fitness ( & self , haplotype : & HaplotypeRef < S > ) -> f64 {
88
+ self . table . update_fitness ( haplotype) * self . epistasis . update_fitness ( haplotype)
89
+ }
90
+
91
+ fn compute_fitness ( & self , haplotype : & HaplotypeRef < S > ) -> f64 {
92
+ self . table . compute_fitness ( haplotype) * self . epistasis . compute_fitness ( haplotype)
93
+ }
94
+
95
+ fn write ( & self , path : & Path , name : & ' static str ) -> Result < ( ) , VirolutionError > {
96
+ let table_name = format ! ( "{}_table.npy" , name) ;
97
+ let epistasis_name = format ! ( "{}_epistasis_table.npy" , name) ;
98
+
99
+ let table_file = fs:: File :: create ( path. join ( table_name) ) . unwrap ( ) ;
100
+ let epistasis_path = fs:: File :: create ( path. join ( epistasis_name) ) . unwrap ( ) ;
101
+
102
+ let mut table_writer = io:: BufWriter :: new ( table_file) ;
103
+ let mut epistasis_writer = io:: BufWriter :: new ( epistasis_path) ;
104
+
105
+ self . table . write ( & mut table_writer) ?;
106
+ self . epistasis . write ( & mut epistasis_writer) ?;
107
+ Ok ( ( ) )
108
+ }
109
+
110
+ fn clone_box ( & self ) -> Box < dyn FitnessFunction < S > > {
111
+ Box :: new ( self . clone ( ) )
112
+ }
113
+ }
114
+
115
+ #[ derive( Clone , Debug ) ]
116
+ pub struct Neutral < S : Symbol > {
117
+ marker : std:: marker:: PhantomData < S > ,
118
+ }
119
+
120
+ impl < S : Symbol > Neutral < S > {
121
+ pub fn new ( ) -> Self {
122
+ Self {
123
+ marker : std:: marker:: PhantomData ,
124
+ }
125
+ }
126
+ }
127
+
128
+ impl < S : Symbol > FitnessFunction < S > for Neutral < S > {
129
+ fn update_fitness ( & self , _haplotype : & HaplotypeRef < S > ) -> f64 {
130
+ 1.0
131
+ }
132
+
133
+ fn compute_fitness ( & self , _haplotype : & HaplotypeRef < S > ) -> f64 {
134
+ 1.0
135
+ }
136
+
137
+ fn write ( & self , _path : & Path , _name : & ' static str ) -> Result < ( ) , VirolutionError > {
138
+ Ok ( ( ) )
139
+ }
140
+
141
+ fn clone_box ( & self ) -> Box < dyn FitnessFunction < S > > {
142
+ Box :: new ( self . clone ( ) )
143
+ }
28
144
}
29
145
30
146
impl < S : Symbol > FitnessProvider < S > {
31
147
/// Create a new fitness provider.
32
- pub fn new ( name : & ' static str , function : FitnessFunction < S > , utility : UtilityFunction ) -> Self {
148
+ pub fn new (
149
+ name : & ' static str ,
150
+ function : Box < dyn FitnessFunction < S > > ,
151
+ utility : UtilityFunction ,
152
+ ) -> Self {
33
153
Self {
34
154
name,
35
155
function,
@@ -46,27 +166,24 @@ impl<S: Symbol> FitnessProvider<S> {
46
166
sequence : & [ S ] ,
47
167
model : & FitnessModel ,
48
168
) -> Result < Self , VirolutionError > {
49
- let function = match model. distribution {
50
- FitnessDistribution :: Neutral => {
51
- let table = FitnessTable :: from_model ( sequence, model) ?;
52
- FitnessFunction :: NonEpistatic ( table)
53
- }
169
+ let function: Box < dyn FitnessFunction < S > > = match model. distribution {
170
+ FitnessDistribution :: Neutral => Box :: new ( Neutral :: new ( ) ) ,
54
171
FitnessDistribution :: Exponential ( _) => {
55
172
let table = FitnessTable :: from_model ( sequence, model) ?;
56
- FitnessFunction :: NonEpistatic ( table)
173
+ Box :: new ( NonEpistatic :: new ( table) )
57
174
}
58
175
FitnessDistribution :: Lognormal ( _) => {
59
176
let table = FitnessTable :: from_model ( sequence, model) ?;
60
- FitnessFunction :: NonEpistatic ( table)
177
+ Box :: new ( NonEpistatic :: new ( table) )
61
178
}
62
179
FitnessDistribution :: File ( _) => {
63
180
let table = FitnessTable :: from_model ( sequence, model) ?;
64
- FitnessFunction :: NonEpistatic ( table)
181
+ Box :: new ( NonEpistatic :: new ( table) )
65
182
}
66
183
FitnessDistribution :: Epistatic ( _) => {
67
184
let table = FitnessTable :: from_model ( sequence, model) ?;
68
185
let epistasis = EpistasisTable :: from_model ( model) ?;
69
- FitnessFunction :: SimpleEpistatic ( table, epistasis)
186
+ Box :: new ( SimpleEpistatic :: new ( table, epistasis) )
70
187
}
71
188
} ;
72
189
Ok ( Self {
@@ -99,21 +216,11 @@ impl<S: Symbol> FitnessProvider<S> {
99
216
}
100
217
101
218
fn update_fitness ( & self , haplotype : & HaplotypeRef < S > ) -> f64 {
102
- match & self . function {
103
- FitnessFunction :: NonEpistatic ( table) => table. update_fitness ( haplotype) ,
104
- FitnessFunction :: SimpleEpistatic ( table, epistasis) => {
105
- table. update_fitness ( haplotype) * epistasis. update_fitness ( haplotype)
106
- }
107
- }
219
+ self . function . update_fitness ( haplotype)
108
220
}
109
221
110
222
fn compute_fitness ( & self , haplotype : & HaplotypeRef < S > ) -> f64 {
111
- match & self . function {
112
- FitnessFunction :: NonEpistatic ( table) => table. compute_fitness ( haplotype) ,
113
- FitnessFunction :: SimpleEpistatic ( table, epistasis) => {
114
- table. compute_fitness ( haplotype) * epistasis. compute_fitness ( haplotype)
115
- }
116
- }
223
+ self . function . compute_fitness ( haplotype)
117
224
}
118
225
}
119
226
@@ -141,27 +248,6 @@ impl<S: Symbol> AttributeProvider<S> for FitnessProvider<S> {
141
248
}
142
249
143
250
fn write ( & self , path : & Path ) -> Result < ( ) , VirolutionError > {
144
- match & self . function {
145
- FitnessFunction :: NonEpistatic ( table) => {
146
- let table_name = format ! ( "{}_table.npy" , self . name) ;
147
- let mut table_file =
148
- io:: BufWriter :: new ( fs:: File :: create ( path. join ( table_name) ) . unwrap ( ) ) ;
149
- table. write ( & mut table_file) ?;
150
- }
151
- FitnessFunction :: SimpleEpistatic ( table, epistasis) => {
152
- let table_name = format ! ( "{}_table.npy" , self . name) ;
153
- let epistasis_name = format ! ( "{}_epistasis_table.npy" , self . name) ;
154
-
155
- let table_file = fs:: File :: create ( path. join ( table_name) ) . unwrap ( ) ;
156
- let epistasis_path = fs:: File :: create ( path. join ( epistasis_name) ) . unwrap ( ) ;
157
-
158
- let mut table_writer = io:: BufWriter :: new ( table_file) ;
159
- let mut epistasis_writer = io:: BufWriter :: new ( epistasis_path) ;
160
-
161
- table. write ( & mut table_writer) ?;
162
- epistasis. write ( & mut epistasis_writer) ?;
163
- }
164
- }
165
- Ok ( ( ) )
251
+ self . function . write ( path, self . name )
166
252
}
167
253
}
0 commit comments