Skip to content

Commit fcf77c9

Browse files
HEGAB7Mahmoud
andauthored
fix: make col_sample min equals to 1 (#385)
Co-authored-by: Mahmoud <mhegab@raisaenergy.com>
1 parent c143d4a commit fcf77c9

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

ngboost/ngboost.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class NGBoost:
3535
n_estimators : the number of boosting iterations to fit
3636
learning_rate : the learning rate
3737
minibatch_frac : the percent subsample of rows to use in each boosting iteration
38+
col_sample : the percent subsample of columns to use in each boosting iteration
3839
verbose : flag indicating whether output should be printed during fitting
3940
verbose_eval : increment (in boosting iterations) at which output should be printed
4041
tol : numerical tolerance to be used in optimization
@@ -147,7 +148,10 @@ def sample(self, X, Y, sample_weight, params):
147148
)
148149

149150
if self.col_sample != 1.0:
150-
col_size = int(self.col_sample * X.shape[1])
151+
if self.col_sample > 0.0:
152+
col_size = max(1, int(self.col_sample * X.shape[1]))
153+
else:
154+
col_size = 0
151155
col_idx = self.random_state.choice(
152156
np.arange(X.shape[1]), col_size, replace=False
153157
)

0 commit comments

Comments
 (0)