File tree Expand file tree Collapse file tree 1 file changed +11
-7
lines changed Expand file tree Collapse file tree 1 file changed +11
-7
lines changed Original file line number Diff line number Diff line change 1
1
import numpy as np
2
2
import pandas as pd
3
3
4
- from src .config import CONFIG
5
-
6
4
from ._core import _transition_counts
7
5
8
- alpha = float (CONFIG ["model" ]["laplace_alpha" ])
9
-
10
6
11
7
def build_transition_matrices (df : pd .DataFrame , * , dtype = np .float32 ) -> np .ndarray :
12
8
counts = _transition_counts (df , dtype = dtype )
13
- counts += alpha
14
- counts /= counts .sum (axis = 2 , keepdims = True )
15
- return counts .astype (dtype )
9
+ row_sum = counts .sum (axis = 2 , keepdims = True )
10
+ empty = row_sum == 0
11
+ if np .any (empty ):
12
+ n_states = counts .shape [2 ]
13
+ idx_b , idx_i , _ = np .where (empty )
14
+ counts [idx_b , idx_i , :] = 0
15
+ counts [idx_b , idx_i , idx_i ] = 1
16
+ row_sum [empty ] = 1
17
+ probs = counts / row_sum
18
+
19
+ return probs .astype (dtype )
You can’t perform that action at this time.
0 commit comments