Skip to content

Commit 2db31b8

Browse files
authored
Test and handle some boundary conditions (#3)
* Test and handle some boundary conditions * Apply rustfmt to code Co-authored-by: Alex Merry <alex.merry@nanoporetech.com>
1 parent df32a9c commit 2db31b8

File tree

3 files changed

+78
-18
lines changed

3 files changed

+78
-18
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ dist/
55
*.egg-info/
66
*.so
77
target/
8-
8+
.*.swp

src/lib.rs

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
#[macro_use(s)]
44
extern crate ndarray;
55

6-
use numpy::PyArray2;
76
use ndarray::{ArrayBase, Data, Ix2};
7+
use numpy::PyArray2;
88

9+
use pyo3::exceptions::ValueError;
910
use pyo3::prelude::*;
1011
use pyo3::wrap_pyfunction;
1112

@@ -15,8 +16,21 @@ fn beam_search(
1516
alphabet: String,
1617
beam_size: usize,
1718
beam_cut_threshold: f32,
18-
) -> (String, Vec<usize>) {
19-
beam_search_(&result.as_array(), alphabet, beam_size, beam_cut_threshold)
19+
) -> PyResult<(String, Vec<usize>)> {
20+
if alphabet.len() != result.shape()[1] {
21+
Err(ValueError::py_err(
22+
"alphabet size does not match probability matrix dimensions",
23+
))
24+
} else if beam_size == 0 {
25+
Err(ValueError::py_err("beam_size cannot be 0"))
26+
} else {
27+
Ok(beam_search_(
28+
&result.as_array(),
29+
alphabet,
30+
beam_size,
31+
beam_cut_threshold,
32+
))
33+
}
2034
}
2135

2236
#[pymodule]
@@ -31,12 +45,11 @@ fn beam_search_<D: Data<Elem = f32>>(
3145
beam_size: usize,
3246
beam_cut_threshold: f32,
3347
) -> (String, Vec<usize>) {
34-
3548
let alphabet: Vec<char> = alphabet.chars().collect();
3649

3750
// alphabet_size minus the blank label
3851
let alphabet_size = alphabet.len() - 1;
39-
let duration = result.len() / alphabet.len();
52+
let duration = result.nrows();
4053

4154
// (base, what, idx)
4255
let mut beam_prevs = vec![(0, 0, 0)];
@@ -56,12 +69,12 @@ fn beam_search_<D: Data<Elem = f32>>(
5669
new_probs.push((beam, 0.0, (n_prob + base_prob) * pr[0]));
5770
}
5871

59-
for b in 1..alphabet_size + 1 {
60-
if pr[b] < beam_cut_threshold {
72+
for (b, &pr_b) in (1..=alphabet_size).zip(pr.iter().skip(1)) {
73+
if pr_b < beam_cut_threshold {
6174
continue;
6275
}
6376
if b == beam_prevs[beam as usize].0 {
64-
new_probs.push((beam, base_prob * pr[b], 0.0));
77+
new_probs.push((beam, base_prob * pr_b, 0.0));
6578
let mut new_beam = beam_forward[beam as usize][b - 1];
6679
if new_beam == -1 && n_prob > 0.0 {
6780
new_beam = beam_prevs.len() as i32;
@@ -70,7 +83,7 @@ fn beam_search_<D: Data<Elem = f32>>(
7083
beam_forward.push(vec![-1; alphabet_size]);
7184
}
7285

73-
new_probs.push((new_beam, n_prob * pr[b], 0.0));
86+
new_probs.push((new_beam, n_prob * pr_b, 0.0));
7487
} else {
7588
let mut new_beam = beam_forward[beam as usize][b - 1];
7689
if new_beam == -1 {
@@ -80,7 +93,7 @@ fn beam_search_<D: Data<Elem = f32>>(
8093
beam_forward.push(vec![-1; alphabet_size]);
8194
}
8295

83-
new_probs.push((new_beam, (base_prob + n_prob) * pr[b], 0.0));
96+
new_probs.push((new_beam, (base_prob + n_prob) * pr_b, 0.0));
8497
}
8598
}
8699
}
@@ -90,19 +103,24 @@ fn beam_search_<D: Data<Elem = f32>>(
90103
let mut last_key: i32 = -1;
91104
let mut last_key_pos = 0;
92105
for i in 0..cur_probs.len() {
93-
if cur_probs[i].0 == last_key {
94-
cur_probs[last_key_pos].1 = cur_probs[last_key_pos].1 + cur_probs[i].1;
95-
cur_probs[last_key_pos].2 = cur_probs[last_key_pos].2 + cur_probs[i].2;
106+
let cur_prob = cur_probs[i];
107+
if cur_prob.0 == last_key {
108+
cur_probs[last_key_pos].1 += cur_prob.1;
109+
cur_probs[last_key_pos].2 += cur_prob.2;
96110
cur_probs[i].0 = -1;
97111
} else {
98112
last_key_pos = i;
99-
last_key = cur_probs[i].0;
113+
last_key = cur_prob.0;
100114
}
101115
}
102116

103117
cur_probs.retain(|x| x.0 != -1);
104118
cur_probs.sort_by(|a, b| (b.1 + b.2).partial_cmp(&(a.1 + a.2)).unwrap());
105119
cur_probs.truncate(beam_size);
120+
if cur_probs.is_empty() {
121+
// we've run out of beam (probably the threshold is too high)
122+
return (String::new(), Vec::new());
123+
}
106124
let top = cur_probs[0].1 + cur_probs[0].2;
107125
for mut x in &mut cur_probs {
108126
x.1 /= top;

tests/test_decode.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,54 @@ def test_beam_search_alphabet(self):
2828
self.assertEqual(len(seq), len(path))
2929
self.assertEqual(len(set(seq)), len(self.alphabet) - 1)
3030

31+
def test_zero_beam_size(self):
32+
""" simple beam search test with zero beam size"""
33+
with self.assertRaises(ValueError):
34+
beam_search(self.probs, self.alphabet, 0, self.beam_cut_threshold)
35+
36+
def test_zero_beam_cut_threshold(self):
37+
""" simple beam search test with beam cut threshold of 0.0"""
38+
seq, path = beam_search(self.probs, self.alphabet, self.beam_size, 0.0)
39+
self.assertEqual(len(seq), len(path))
40+
self.assertEqual(len(set(seq)), len(self.alphabet) - 1)
41+
42+
def test_negative_beam_cut_threshold(self):
43+
""" simple beam search test with beam cut threshold below 0.0"""
44+
seq, path = beam_search(self.probs, self.alphabet, self.beam_size, -0.1)
45+
self.assertEqual(len(seq), len(path))
46+
self.assertEqual(len(set(seq)), len(self.alphabet) - 1)
47+
48+
def test_max_beam_cut_threshold(self):
49+
""" simple beam search test with beam cut threshold of 1.0"""
50+
seq, path = beam_search(self.probs, self.alphabet, self.beam_size, 1.0)
51+
self.assertEqual(len(seq), len(path))
52+
self.assertEqual(len(seq), 0) # with a threshold that high, we won't find anything
53+
54+
def test_high_beam_cut_threshold(self):
55+
""" simple beam search test with beam cut threshold above 1.0"""
56+
seq, path = beam_search(self.probs, self.alphabet, self.beam_size, 1.1)
57+
self.assertEqual(len(seq), len(path))
58+
self.assertEqual(len(seq), 0) # with a threshold that high, we won't find anything
59+
60+
def test_beam_search_mismatched_alphabet_short(self):
61+
""" simple beam search test with too few alphabet chars"""
62+
alphabet = "NAGC"
63+
with self.assertRaises(ValueError):
64+
beam_search(self.probs, alphabet, self.beam_size, self.beam_cut_threshold)
65+
66+
def test_beam_search_mismatched_alphabet_long(self):
67+
""" simple beam search test with too many alphabet chars"""
68+
alphabet = "NAGCTX"
69+
with self.assertRaises(ValueError):
70+
beam_search(self.probs, alphabet, self.beam_size, self.beam_cut_threshold)
71+
3172
def test_beam_search_short_alphabet(self):
3273
""" simple beam search test with short alphabet"""
33-
alphabet = "NAG"
34-
seq, path = beam_search(self.probs, alphabet, self.beam_size, self.beam_cut_threshold)
74+
self.alphabet = "NAG"
75+
self.probs = self.get_random_data()
76+
seq, path = beam_search(self.probs, self.alphabet, self.beam_size, self.beam_cut_threshold)
3577
self.assertEqual(len(seq), len(path))
36-
self.assertEqual(len(set(seq)), len(alphabet) - 1)
78+
self.assertEqual(len(set(seq)), len(self.alphabet) - 1)
3779

3880
def test_beam_search_long_alphabet(self):
3981
""" simple beam search test with long alphabet"""

0 commit comments

Comments
 (0)