Skip to content

Commit d8be38e

Browse files
removed lp
1 parent 067a50d commit d8be38e

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

src/markov/transitions.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import numpy as np
22
import pandas as pd
33

4-
from src.config import CONFIG
5-
64
from ._core import _transition_counts
75

8-
alpha = float(CONFIG["model"]["laplace_alpha"])
9-
106

117
def build_transition_matrices(df: pd.DataFrame, *, dtype=np.float32) -> np.ndarray:
128
counts = _transition_counts(df, dtype=dtype)
13-
counts += alpha
14-
counts /= counts.sum(axis=2, keepdims=True)
15-
return counts.astype(dtype)
9+
row_sum = counts.sum(axis=2, keepdims=True)
10+
empty = row_sum == 0
11+
if np.any(empty):
12+
n_states = counts.shape[2]
13+
idx_b, idx_i, _ = np.where(empty)
14+
counts[idx_b, idx_i, :] = 0
15+
counts[idx_b, idx_i, idx_i] = 1
16+
row_sum[empty] = 1
17+
probs = counts / row_sum
18+
19+
return probs.astype(dtype)

0 commit comments

Comments
 (0)