Skip to content

Commit 62c12d1

Browse files
committed
Fix attribute and fitness computations
1 parent 1c774a8 commit 62c12d1

File tree

8 files changed

+100
-50
lines changed

8 files changed

+100
-50
lines changed

flake.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/core/attributes.rs

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ pub trait AttributeProvider<S: Symbol>: Sync + Send + std::fmt::Debug {
7676
///
7777
/// The attribute set definition contains the providers that compute the attributes of the set.
7878
/// It can be used to register new providers and to create an attribute set.
79-
#[derive(Clone)]
79+
#[derive(Clone, Debug)]
8080
pub struct AttributeSetDefinition<S: Symbol> {
8181
providers: HashMap<&'static str, Arc<dyn AttributeProvider<S> + Send + Sync>>,
8282
eager: Vec<&'static str>,
@@ -192,6 +192,30 @@ impl<S: Symbol> AttributeSet<S> {
192192
Self::new(self.definition.clone(), haplotype)
193193
}
194194

195+
/// Get or compute the value of an attribute.
196+
pub fn get_or_compute_raw(&self, id: &'static str) -> Result<AttributeValue> {
197+
let provider = self.definition.providers.get(&id).ok_or_else(|| {
198+
VirolutionError::ImplementationError(format!("No provider found for attribute {}", id))
199+
})?;
200+
201+
// First, try to read the value
202+
{
203+
let values = self.values.read().unwrap();
204+
if let Some(value) = values.get(&id) {
205+
return Ok(value.clone());
206+
}
207+
}
208+
209+
// Compute the attribute using the provider
210+
let value = provider.compute(&self.haplotype.upgrade());
211+
212+
// Write the computed value
213+
let mut values = self.values.write().unwrap();
214+
values.insert(id, value.clone());
215+
216+
Ok(value)
217+
}
218+
195219
/// Get or compute the value of an attribute.
196220
pub fn get_or_compute(&self, id: &'static str) -> Result<AttributeValue> {
197221
let provider = self.definition.providers.get(&id).ok_or_else(|| {
@@ -207,20 +231,13 @@ impl<S: Symbol> AttributeSet<S> {
207231
}
208232

209233
// Compute the attribute using the provider
210-
if let Some(provider) = self.definition.providers.get(&id) {
211-
let value = provider.compute(&self.haplotype.upgrade());
212-
213-
// Write the computed value
214-
let mut values = self.values.write().unwrap();
215-
values.insert(id, value.clone());
216-
217-
Ok(provider.map(value))
218-
} else {
219-
// If there is no provider, return None
220-
Err(VirolutionError::ImplementationError(
221-
"No provider found for attribute".to_string(),
222-
))
223-
}
234+
let value = provider.compute(&self.haplotype.upgrade());
235+
236+
// Write the computed value
237+
let mut values = self.values.write().unwrap();
238+
values.insert(id, value.clone());
239+
240+
Ok(provider.map(value))
224241
}
225242

226243
/// Get the value of an already computed attribute.
@@ -260,7 +277,9 @@ impl<S: Symbol> Clone for AttributeSet<S> {
260277
impl<S: Symbol> std::fmt::Debug for AttributeSet<S> {
261278
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262279
let values = self.values.read().unwrap();
280+
let haplotype = self.haplotype.upgrade();
263281
f.debug_struct("AttributeSet")
282+
.field("haplotype", &haplotype)
264283
.field("values", &values)
265284
.finish()
266285
}

src/core/haplotype.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,10 +287,21 @@ impl<S: Symbol> Haplotype<S> {
287287
recombinant
288288
}
289289

290+
#[allow(dead_code)]
291+
pub(crate) fn is_wildtype(&self) -> bool {
292+
matches!(self, Haplotype::Wildtype(_))
293+
}
294+
295+
#[allow(dead_code)]
290296
pub(crate) fn is_mutant(&self) -> bool {
291297
matches!(self, Haplotype::Mutant(_))
292298
}
293299

300+
#[allow(dead_code)]
301+
pub(crate) fn is_recombinant(&self) -> bool {
302+
matches!(self, Haplotype::Recombinant(_))
303+
}
304+
294305
/// Unwraps the haplotype into a mutant.
295306
///
296307
/// This function will panic if the haplotype is not a mutant and is only intended for internal
@@ -466,6 +477,10 @@ impl<S: Symbol> Haplotype<S> {
466477
self.get_attributes().get(id)
467478
}
468479

480+
pub fn get_or_compute_attribute_raw(&self, id: &'static str) -> Result<AttributeValue> {
481+
self.get_attributes().get_or_compute_raw(id)
482+
}
483+
469484
pub fn get_or_compute_attribute(&self, id: &'static str) -> Result<AttributeValue> {
470485
self.get_attributes().get_or_compute(id)
471486
}

src/core/hosts.rs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,13 @@ impl<S: Symbol> HostSpecs<S> {
6060
}
6161

6262
pub fn try_get_spec_from_index(&self, index: usize) -> Option<&HostSpec<S>> {
63-
for spec in self.0.iter() {
64-
if spec.range.contains(&index) {
65-
return Some(spec);
66-
}
67-
}
68-
None
63+
self.0.iter().find(|&spec| spec.range.contains(&index))
64+
}
65+
}
66+
67+
impl<S: Symbol> Default for HostSpecs<S> {
68+
fn default() -> Self {
69+
Self::new()
6970
}
7071
}
7172

@@ -111,6 +112,19 @@ pub struct HostMapBuffer {
111112
hosts: Vec<usize>,
112113
}
113114

115+
impl std::fmt::Debug for HostMapBuffer {
116+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117+
let n_infections = self.infections.iter().filter(|x| x.is_some()).count();
118+
let infectivity = n_infections as f64 / self.n_infectants as f64;
119+
f.debug_struct("HostMapBuffer")
120+
.field("n_hosts", &self.n_hosts)
121+
.field("n_infectants", &self.n_infectants)
122+
.field("n_infections", &n_infections)
123+
.field("infectivity", &infectivity)
124+
.finish()
125+
}
126+
}
127+
114128
/// A map from host index to infectant indices.
115129
///
116130
/// The map is constructed from a list of infections between infectant and host indices. It uses
@@ -149,11 +163,11 @@ impl HostMapBuffer {
149163
}
150164
}
151165

152-
pub fn build<F: FnMut(&mut Option<usize>)>(&mut self, builder: F) {
166+
pub fn build<F: FnMut((usize, &mut Option<usize>))>(&mut self, builder: F) {
153167
let infection_slice = &mut self.infections[0..self.n_infectants];
154168

155169
// set new infections
156-
infection_slice.iter_mut().for_each(builder);
170+
infection_slice.iter_mut().enumerate().for_each(builder);
157171

158172
// reset offsets
159173
self.offsets[0..=self.n_hosts].fill(0);

src/init/fitness.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,11 @@ impl LognormalParameters {
180180

181181
impl FileParameters {
182182
pub fn load_table(&self) -> Vec<f64> {
183-
let reader = NpyFile::new(std::fs::File::open(&self.path).unwrap_or_else(
184-
|_e| panic!("Could not open file: {}", self.path),
185-
)).unwrap();
183+
let reader = NpyFile::new(
184+
std::fs::File::open(&self.path)
185+
.unwrap_or_else(|_e| panic!("Could not open file: {}", self.path)),
186+
)
187+
.unwrap();
186188
reader
187189
.data::<f64>()
188190
.unwrap()

src/providers/fitness.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ impl<S: Symbol> FitnessProvider<S> {
216216
haplotype
217217
.try_get_ancestor()
218218
.expect("Could not find ancestor during fitness update...")
219-
.get_or_compute_attribute(self.name)
220-
.unwrap(),
219+
.get_or_compute_attribute_raw(self.name)
220+
.unwrap()
221221
)
222222
.unwrap();
223223
let update = self.update_fitness(haplotype);

src/runner.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ impl Runner {
7373

7474
let wildtype = Haplotype::load_wildtype(sequence, &attribute_definitions);
7575

76-
// perform sanity checks
76+
// perform sanity checks
7777
if !settings
7878
.schedule
7979
.check_transfer_table_sizes(args.n_compartments)
@@ -334,7 +334,7 @@ impl Runner {
334334

335335
for generation in 0..=self.args.generations {
336336
// logging
337-
log::debug!("Generate logging message for generation {generation}...");
337+
log::debug!("Running generation {generation}...");
338338
let population_sizes: Vec<usize> = self
339339
.simulations
340340
.iter()

src/simulation.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -366,20 +366,20 @@ impl<S: Symbol> Simulation<S> for BasicSimulation<S> {
366366
let host_sampler = Uniform::new(0, self.parameters.host_population_size)
367367
.expect("Invalid host population size");
368368
let mut rng = rand::rng();
369-
self.host_map_buffer.build(|ref mut infectant| {
370-
let host_candidate = host_sampler.sample(&mut rng);
371-
**infectant = self
372-
.host_specs
373-
.try_get_spec_from_index(host_candidate)
374-
.map(|spec| {
375-
if spec.host.infect(&self.wildtype, &mut rng) {
376-
Some(host_candidate)
377-
} else {
378-
None
379-
}
380-
})
381-
.flatten();
382-
});
369+
self.host_map_buffer
370+
.build(|(infectant, ref mut infection)| {
371+
let host_candidate = host_sampler.sample(&mut rng);
372+
**infection = self
373+
.host_specs
374+
.try_get_spec_from_index(host_candidate)
375+
.and_then(|spec| {
376+
if spec.host.infect(&self.population.get(&infectant), &mut rng) {
377+
Some(host_candidate)
378+
} else {
379+
None
380+
}
381+
})
382+
});
383383
}
384384

385385
#[cfg(feature = "parallel")]

0 commit comments

Comments
 (0)