Skip to content

Commit 855dbf5

Browse files
committed
Fix compute context manager
1 parent dd40a17 commit 855dbf5

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

causing/model.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def __post_init__(self):
4747
self.graph.add_node(var)
4848
self.trans_graph = networkx.transitive_closure(self.graph, reflexive=True)
4949

50-
@np.errstate(all="raise")
5150
def compute(
5251
self,
5352
xdat: np.array,
@@ -59,7 +58,7 @@ def compute(
5958
fixed_to_yind: int = None,
6059
fixed_vals: list = None,
6160
# override default parameter values
62-
parameters: dict[str, float] = {},
61+
parameters: dict[str, float] | None = None,
6362
) -> np.array:
6463
"""Compute y values for given x values
6564
@@ -69,16 +68,19 @@ def compute(
6968
assert xdat.ndim == 2, f"xdat must be m*tau (is {xdat.ndim}-dimensional)"
7069
assert xdat.shape[0] == self.mdim, f"xdat must be m*tau (is {xdat.shape})"
7170
tau = xdat.shape[1]
71+
if parameters is None:
72+
parameters = {}
7273
parameters = self.parameters | parameters
7374

74-
yhat = np.array([[float("nan")] * tau] * len(self.yvars))
75-
for i, eq in enumerate(self._model_lam):
76-
if fixed_yind == i:
77-
yhat[i, :] = fixed_yval
78-
else:
79-
eq_inputs = np.array(
80-
[[*xval, *yval] for xval, yval in zip(xdat.T, yhat.T)]
81-
)
75+
with np.errstate(all="raise"):
76+
yhat = np.array([[float("nan")] * tau] * len(self.yvars))
77+
for i, eq in enumerate(self._model_lam):
78+
if fixed_yind == i:
79+
yhat[i, :] = fixed_yval
80+
else:
81+
eq_inputs = np.array(
82+
[[*xval, *yval] for xval, yval in zip(xdat.T, yhat.T)]
83+
)
8284
if fixed_to_yind == i:
8385
eq_inputs[:, fixed_from_ind] = fixed_vals
8486

0 commit comments

Comments
 (0)