Skip to content

Commit d4cee45

Browse files
committed
Some refactoring in readwrite module
1 parent 48263fa commit d4cee45

File tree

3 files changed

+85
-37
lines changed

3 files changed

+85
-37
lines changed

src/errors.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
33
use std::fmt;
44

5+
pub type Result<T> = std::result::Result<T, VirolutionError>;
6+
57
#[derive(Clone, Debug)]
68
pub enum VirolutionError {
79
ImplementationError(String),
810
InitializationError(String),
11+
ReadError(String),
912
}
1013

1114
impl fmt::Display for VirolutionError {
@@ -17,6 +20,9 @@ impl fmt::Display for VirolutionError {
1720
VirolutionError::InitializationError(message) => {
1821
write!(f, "InitializationError: {}", message)
1922
}
23+
VirolutionError::ReadError(message) => {
24+
write!(f, "ReadError: {}", message)
25+
}
2026
}
2127
}
2228
}

src/readwrite/haplotype.rs

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,38 @@
1+
use seq_io::fasta;
2+
use seq_io::fasta::Record;
3+
14
use crate::core::haplotype::{Haplotype, Symbol, Wildtype};
25
use crate::encoding::STRICT_DECODE;
6+
use crate::errors::{Result, VirolutionError};
37
use crate::references::HaplotypeRef;
4-
use seq_io::fasta;
5-
use seq_io::fasta::Record;
68

79
pub trait HaplotypeIO {
8-
fn load_wildtype(path: &str) -> Result<HaplotypeRef, fasta::Error>;
10+
fn load_wildtype(path: &str) -> Result<HaplotypeRef>;
911
}
1012

1113
impl HaplotypeIO for Haplotype {
12-
fn load_wildtype(path: &str) -> Result<HaplotypeRef, fasta::Error> {
13-
let mut reader = fasta::Reader::from_path(path)?;
14-
let sequence: Vec<Symbol> = reader
15-
.next()
16-
.unwrap()
17-
.expect("Unable to read sequence.")
18-
.seq()
19-
.iter()
20-
.filter(|&&enc| enc != 0x0au8 && enc != 0x0du8)
21-
.map(|enc| match STRICT_DECODE.get(enc) {
22-
Some(result) => Some(*result),
23-
None => panic!("Unable to decode literal {enc}."),
24-
})
25-
.collect();
26-
Ok(Wildtype::new(sequence))
14+
fn load_wildtype(path: &str) -> Result<HaplotypeRef> {
15+
let mut reader = fasta::Reader::from_path(path).map_err(|_| {
16+
VirolutionError::InitializationError(format!(
17+
"Unable create file reader for fasta file: {path}"
18+
))
19+
})?;
20+
21+
match reader.next() {
22+
None => Err(VirolutionError::InitializationError(format!(
23+
"No sequence found in fasta file: {path}"
24+
))),
25+
Some(Err(_)) => Err(VirolutionError::InitializationError(format!(
26+
"Unable to read sequence from fasta file: {path}"
27+
))),
28+
Some(Ok(sequence_record)) => {
29+
let sequence: Vec<Symbol> = sequence_record
30+
.seq()
31+
.iter()
32+
.filter_map(|enc| STRICT_DECODE.get(enc).copied().map(Some))
33+
.collect();
34+
Ok(Wildtype::new(sequence))
35+
}
36+
}
2737
}
2838
}

src/readwrite/population.rs

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
use serde::Deserialize;
2+
13
use crate::core::haplotype::Symbol;
24
use crate::core::population::Population;
35
use crate::encoding::DECODE;
6+
use crate::errors::{Result, VirolutionError};
47
use crate::references::HaplotypeRef;
5-
use serde::Deserialize;
68

79
pub trait PopulationIO {
8-
fn read(path: &str, wildtype: HaplotypeRef) -> Result<Population, csv::Error>;
10+
fn read(path: &str, wildtype: HaplotypeRef) -> Result<Population>;
911
fn write(&self, path: &str);
1012
}
1113

@@ -20,30 +22,26 @@ impl PopulationIO for Population {
2022
///
2123
/// Warning: This function will create multiple instances of the same
2224
/// haplotype if it is present multiple times in the CSV file.
23-
fn read(path: &str, wildtype: HaplotypeRef) -> Result<Population, csv::Error> {
24-
let mut reader = csv::Reader::from_path(path)?;
25+
fn read(path: &str, wildtype: HaplotypeRef) -> Result<Population> {
26+
let mut reader = csv::Reader::from_path(path)
27+
.map_err(|_err| VirolutionError::ReadError(format!("Failed to read from {path}")))?;
2528
let mut populations: Vec<Population> = Vec::new();
2629

2730
for record in reader.deserialize() {
28-
let record: HaplotypeRecord = record?;
31+
let record: HaplotypeRecord = record.map_err(|_err| {
32+
VirolutionError::ReadError(format!("Failed to parse record in {path}"))
33+
})?;
34+
35+
// skip wildtype
2936
if record.haplotype == "wt" {
3037
continue;
3138
}
32-
let mutations = record.haplotype.split(';');
33-
let mut positions: Vec<usize> = Vec::new();
34-
let mut changes: Vec<Symbol> = Vec::new();
35-
mutations.for_each(|mutation| {
36-
let mut mutation = mutation.split(':');
37-
let position = mutation.next().unwrap();
38-
39-
let mut change = mutation.next().unwrap().split("->");
40-
let _origin = change.next();
41-
let target = change.next().unwrap().chars().next().unwrap() as u8;
42-
43-
positions.push(position.parse::<usize>().unwrap());
44-
changes.push(DECODE.get(&target).copied());
45-
});
39+
40+
// parse haplotype
41+
let (positions, changes) = parse_haplotype(&record.haplotype)?;
4642
let haplotype = wildtype.create_descendant(positions, changes, 0);
43+
44+
// create population and add for merging
4745
let population = Population::from_haplotype(haplotype, record.count);
4846
populations.push(population);
4947
}
@@ -55,3 +53,37 @@ impl PopulationIO for Population {
5553
unimplemented!()
5654
}
5755
}
56+
57+
fn parse_haplotype(haplotype: &str) -> Result<(Vec<usize>, Vec<Symbol>)> {
58+
let mutations = haplotype.split(';');
59+
let parsed_mutations: Result<Vec<(usize, Symbol)>> = mutations.map(parse_mutation).collect();
60+
Ok(parsed_mutations?.into_iter().unzip())
61+
}
62+
63+
fn parse_mutation(mutation: &str) -> Result<(usize, Symbol)> {
64+
let mutation_split: Vec<&str> = mutation.split(':').collect();
65+
match mutation_split.as_slice() {
66+
[position, change] => Ok((
67+
position.parse::<usize>().map_err(|_| {
68+
VirolutionError::ReadError(format!("Failed to parse position: {position}"))
69+
})?,
70+
parse_change(change)?,
71+
)),
72+
_ => Err(VirolutionError::ReadError(format!(
73+
"Invalid mutation format: {mutation}"
74+
))),
75+
}
76+
}
77+
78+
fn parse_change(change: &str) -> Result<Symbol> {
79+
let change_split: Vec<&str> = change.split(':').collect();
80+
match change_split.as_slice() {
81+
[_origin, target] => {
82+
let target = target.chars().next().unwrap() as u8;
83+
Ok(DECODE.get(&target).copied())
84+
}
85+
_ => Err(VirolutionError::ReadError(format!(
86+
"Invalid change format: {change}"
87+
))),
88+
}
89+
}

0 commit comments

Comments
 (0)