diff --git a/ngboost/ngboost.py b/ngboost/ngboost.py index 13f420d0..2406e574 100644 --- a/ngboost/ngboost.py +++ b/ngboost/ngboost.py @@ -35,6 +35,7 @@ class NGBoost: n_estimators : the number of boosting iterations to fit learning_rate : the learning rate minibatch_frac : the percent subsample of rows to use in each boosting iteration + col_sample : the percent subsample of columns to use in each boosting iteration verbose : flag indicating whether output should be printed during fitting verbose_eval : increment (in boosting iterations) at which output should be printed tol : numerical tolerance to be used in optimization @@ -147,7 +148,10 @@ def sample(self, X, Y, sample_weight, params): ) if self.col_sample != 1.0: - col_size = int(self.col_sample * X.shape[1]) + if self.col_sample > 0.0: + col_size = max(1, int(self.col_sample * X.shape[1])) + else: + col_size = 0 col_idx = self.random_state.choice( np.arange(X.shape[1]), col_size, replace=False )