Skip to content

Commit ecffeac

Browse files
Allow NaNs in input X (#362)
1 parent f1dd7c6 commit ecffeac

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

ngboost/ngboost.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""The NGBoost library"""
2+
23
# pylint: disable=line-too-long,too-many-instance-attributes,too-many-arguments
34
# pylint: disable=unused-argument,too-many-locals,too-many-branches,too-many-statements
45
# pylint: disable=unused-variable,invalid-unary-operand-type,attribute-defined-outside-init
@@ -342,7 +343,12 @@ def partial_fit(
342343
raise ValueError("y cannot be None")
343344

344345
X, Y = check_X_y(
345-
X, Y, accept_sparse=True, y_numeric=True, multi_output=self.multi_output
346+
X,
347+
Y,
348+
accept_sparse=True,
349+
force_all_finite="allow-nan",
350+
multi_output=self.multi_output,
351+
y_numeric=True,
346352
)
347353

348354
self.n_features = X.shape[1]
@@ -357,8 +363,9 @@ def partial_fit(
357363
X_val,
358364
Y_val,
359365
accept_sparse=True,
360-
y_numeric=True,
366+
force_all_finite="allow-nan",
361367
multi_output=self.multi_output,
368+
y_numeric=True,
362369
)
363370
val_params = self.pred_param(X_val)
364371
val_loss_list = []
@@ -490,7 +497,7 @@ def pred_dist(self, X, max_iter=None):
490497
A NGBoost distribution object
491498
"""
492499

493-
X = check_array(X, accept_sparse=True)
500+
X = check_array(X, accept_sparse=True, force_all_finite="allow-nan")
494501

495502
params = np.asarray(self.pred_param(X, max_iter))
496503
dist = self.Dist(params.T)
@@ -537,7 +544,7 @@ def predict(self, X, max_iter=None):
537544
Numpy array of the estimates of Y
538545
"""
539546

540-
X = check_array(X, accept_sparse=True)
547+
X = check_array(X, accept_sparse=True, force_all_finite="allow-nan")
541548

542549
return self.pred_dist(X, max_iter=max_iter).predict()
543550

0 commit comments

Comments
 (0)