@@ -47,7 +47,6 @@ def __post_init__(self):
             self.graph.add_node(var)
         self.trans_graph = networkx.transitive_closure(self.graph, reflexive=True)

-    @np.errstate(all="raise")
     def compute(
         self,
         xdat: np.array,
@@ -59,7 +58,7 @@ def compute(
         fixed_to_yind: int = None,
         fixed_vals: list = None,
         # override default parameter values
-        parameters: dict[str, float] = {},
+        parameters: dict[str, float] | None = None,
     ) -> np.array:
         """Compute y values for given x values

@@ -69,16 +68,19 @@ def compute(
         assert xdat.ndim == 2, f"xdat must be m*tau (is {xdat.ndim}-dimensional)"
         assert xdat.shape[0] == self.mdim, f"xdat must be m*tau (is {xdat.shape})"
         tau = xdat.shape[1]
+        if parameters is None:
+            parameters = {}
         parameters = self.parameters | parameters

-        yhat = np.array([[float("nan")] * tau] * len(self.yvars))
-        for i, eq in enumerate(self._model_lam):
-            if fixed_yind == i:
-                yhat[i, :] = fixed_yval
-            else:
-                eq_inputs = np.array(
-                    [[*xval, *yval] for xval, yval in zip(xdat.T, yhat.T)]
-                )
+        with np.errstate(all="raise"):
+            yhat = np.array([[float("nan")] * tau] * len(self.yvars))
+            for i, eq in enumerate(self._model_lam):
+                if fixed_yind == i:
+                    yhat[i, :] = fixed_yval
+                else:
+                    eq_inputs = np.array(
+                        [[*xval, *yval] for xval, yval in zip(xdat.T, yhat.T)]
+                    )

                if fixed_to_yind == i:
                    eq_inputs[:, fixed_from_ind] = fixed_vals

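The commit makes two independent changes to compute: the shared mutable default parameters={} becomes None (with the empty dict created per call), and the all="raise" error policy moves from a decorator into a with np.errstate(...) block that covers only the evaluation loop. Below is a minimal standalone sketch of both idioms; the free function, the "scale" parameter, and the log-based body are illustrative stand-ins, not the project's real code.

import numpy as np


def compute(xdat: np.ndarray, parameters: dict[str, float] | None = None) -> np.ndarray:
    # Defaulting to None avoids the classic pitfall of a dict default object
    # that is shared (and possibly mutated) across calls.
    if parameters is None:
        parameters = {}
    defaults = {"scale": 1.0}           # stands in for self.parameters
    parameters = defaults | parameters  # caller-supplied values override defaults

    # Inside this block, divide-by-zero, overflow, invalid operations, and
    # underflow raise FloatingPointError instead of passing as warnings.
    with np.errstate(all="raise"):
        return np.log(xdat) * parameters["scale"]


compute(np.array([[1.0, 2.0]]))       # works
# compute(np.array([[0.0, 1.0]]))     # raises FloatingPointError (log of zero)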