3
3
#[ macro_use( s) ]
4
4
extern crate ndarray;
5
5
6
- use numpy:: PyArray2 ;
7
6
use ndarray:: { ArrayBase , Data , Ix2 } ;
7
+ use numpy:: PyArray2 ;
8
8
9
+ use pyo3:: exceptions:: ValueError ;
9
10
use pyo3:: prelude:: * ;
10
11
use pyo3:: wrap_pyfunction;
11
12
@@ -15,8 +16,21 @@ fn beam_search(
15
16
alphabet : String ,
16
17
beam_size : usize ,
17
18
beam_cut_threshold : f32 ,
18
- ) -> ( String , Vec < usize > ) {
19
- beam_search_ ( & result. as_array ( ) , alphabet, beam_size, beam_cut_threshold)
19
+ ) -> PyResult < ( String , Vec < usize > ) > {
20
+ if alphabet. len ( ) != result. shape ( ) [ 1 ] {
21
+ Err ( ValueError :: py_err (
22
+ "alphabet size does not match probability matrix dimensions" ,
23
+ ) )
24
+ } else if beam_size == 0 {
25
+ Err ( ValueError :: py_err ( "beam_size cannot be 0" ) )
26
+ } else {
27
+ Ok ( beam_search_ (
28
+ & result. as_array ( ) ,
29
+ alphabet,
30
+ beam_size,
31
+ beam_cut_threshold,
32
+ ) )
33
+ }
20
34
}
21
35
22
36
#[ pymodule]
@@ -31,12 +45,11 @@ fn beam_search_<D: Data<Elem = f32>>(
31
45
beam_size : usize ,
32
46
beam_cut_threshold : f32 ,
33
47
) -> ( String , Vec < usize > ) {
34
-
35
48
let alphabet: Vec < char > = alphabet. chars ( ) . collect ( ) ;
36
49
37
50
// number of labels in the alphabet, excluding the blank label
38
51
let alphabet_size = alphabet. len ( ) - 1 ;
39
- let duration = result. len ( ) / alphabet . len ( ) ;
52
+ let duration = result. nrows ( ) ;
40
53
41
54
// (base, what, idx)
42
55
let mut beam_prevs = vec ! [ ( 0 , 0 , 0 ) ] ;
@@ -56,12 +69,12 @@ fn beam_search_<D: Data<Elem = f32>>(
56
69
new_probs. push ( ( beam, 0.0 , ( n_prob + base_prob) * pr[ 0 ] ) ) ;
57
70
}
58
71
59
- for b in 1 ..alphabet_size + 1 {
60
- if pr [ b ] < beam_cut_threshold {
72
+ for ( b , & pr_b ) in ( 1 ..= alphabet_size) . zip ( pr . iter ( ) . skip ( 1 ) ) {
73
+ if pr_b < beam_cut_threshold {
61
74
continue ;
62
75
}
63
76
if b == beam_prevs[ beam as usize ] . 0 {
64
- new_probs. push ( ( beam, base_prob * pr [ b ] , 0.0 ) ) ;
77
+ new_probs. push ( ( beam, base_prob * pr_b , 0.0 ) ) ;
65
78
let mut new_beam = beam_forward[ beam as usize ] [ b - 1 ] ;
66
79
if new_beam == -1 && n_prob > 0.0 {
67
80
new_beam = beam_prevs. len ( ) as i32 ;
@@ -70,7 +83,7 @@ fn beam_search_<D: Data<Elem = f32>>(
70
83
beam_forward. push ( vec ! [ -1 ; alphabet_size] ) ;
71
84
}
72
85
73
- new_probs. push ( ( new_beam, n_prob * pr [ b ] , 0.0 ) ) ;
86
+ new_probs. push ( ( new_beam, n_prob * pr_b , 0.0 ) ) ;
74
87
} else {
75
88
let mut new_beam = beam_forward[ beam as usize ] [ b - 1 ] ;
76
89
if new_beam == -1 {
@@ -80,7 +93,7 @@ fn beam_search_<D: Data<Elem = f32>>(
80
93
beam_forward. push ( vec ! [ -1 ; alphabet_size] ) ;
81
94
}
82
95
83
- new_probs. push ( ( new_beam, ( base_prob + n_prob) * pr [ b ] , 0.0 ) ) ;
96
+ new_probs. push ( ( new_beam, ( base_prob + n_prob) * pr_b , 0.0 ) ) ;
84
97
}
85
98
}
86
99
}
@@ -90,19 +103,24 @@ fn beam_search_<D: Data<Elem = f32>>(
90
103
let mut last_key: i32 = -1 ;
91
104
let mut last_key_pos = 0 ;
92
105
for i in 0 ..cur_probs. len ( ) {
93
- if cur_probs[ i] . 0 == last_key {
94
- cur_probs[ last_key_pos] . 1 = cur_probs[ last_key_pos] . 1 + cur_probs[ i] . 1 ;
95
- cur_probs[ last_key_pos] . 2 = cur_probs[ last_key_pos] . 2 + cur_probs[ i] . 2 ;
106
+ let cur_prob = cur_probs[ i] ;
107
+ if cur_prob. 0 == last_key {
108
+ cur_probs[ last_key_pos] . 1 += cur_prob. 1 ;
109
+ cur_probs[ last_key_pos] . 2 += cur_prob. 2 ;
96
110
cur_probs[ i] . 0 = -1 ;
97
111
} else {
98
112
last_key_pos = i;
99
- last_key = cur_probs [ i ] . 0 ;
113
+ last_key = cur_prob . 0 ;
100
114
}
101
115
}
102
116
103
117
cur_probs. retain ( |x| x. 0 != -1 ) ;
104
118
cur_probs. sort_by ( |a, b| ( b. 1 + b. 2 ) . partial_cmp ( & ( a. 1 + a. 2 ) ) . unwrap ( ) ) ;
105
119
cur_probs. truncate ( beam_size) ;
120
+ if cur_probs. is_empty ( ) {
121
+ // we've run out of beams (beam_cut_threshold is probably too high)
122
+ return ( String :: new ( ) , Vec :: new ( ) ) ;
123
+ }
106
124
let top = cur_probs[ 0 ] . 1 + cur_probs[ 0 ] . 2 ;
107
125
for mut x in & mut cur_probs {
108
126
x. 1 /= top;
0 commit comments