|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import sys |
| 4 | +import warnings |
3 | 5 | from numbers import Integral, Real
|
4 | 6 |
|
| 7 | +if sys.version_info >= (3, 11): |
| 8 | + from typing import Self |
| 9 | +else: |
| 10 | + from typing_extensions import Self |
| 11 | + |
5 | 12 | import numpy as np
|
6 | 13 | from sklearn.base import clone
|
7 | 14 | from sklearn.ensemble import AdaBoostClassifier
|
@@ -65,13 +72,18 @@ class LinearBoostClassifier(AdaBoostClassifier):
|
65 | 72 |
|
66 | 73 | algorithm : {'SAMME', 'SAMME.R'}, default='SAMME'
|
67 | 74 | If 'SAMME' then use the SAMME discrete boosting algorithm.
|
68 |
| - If 'SAMME.R' then use the SAMME.R real boosting algorithm. |
| 75 | + If 'SAMME.R' then use the SAMME.R real boosting algorithm |
| 76 | + (only available in scikit-learn < 1.6). |
69 | 77 | The SAMME.R algorithm typically converges faster than SAMME,
|
70 | 78 | achieving a lower test error with fewer boosting iterations.
|
71 | 79 |
|
72 |
| - .. deprecated:: sklearn 1.6 |
73 |
| - `algorithm` is deprecated and will be removed in sklearn 1.8. This |
74 |
| - estimator only implements the 'SAMME' algorithm. |
| 80 | + .. deprecated:: scikit-learn 1.4 |
| 81 | + `"SAMME.R"` is deprecated and will be removed in scikit-learn 1.6. |
| 82 | + '"SAMME"' will become the default. |
| 83 | +
|
| 84 | + .. deprecated:: scikit-learn 1.6 |
| 85 | + `algorithm` is deprecated and will be removed in scikit-learn 1.8. |
| 86 | + This estimator only implements the 'SAMME' algorithm in scikit-learn >= 1.6. |
75 | 87 |
|
76 | 88 | scaler : str, default='minmax'
|
77 | 89 | Specifies the scaler to apply to the data. Options include:
|
@@ -111,21 +123,21 @@ class LinearBoostClassifier(AdaBoostClassifier):
|
111 | 123 | where:
|
112 | 124 | - y_true: Ground truth (correct) target values.
|
113 | 125 | - y_pred: Estimated target values.
|
114 |
| - - sample_weight: Sample weights. |
| 126 | + - sample_weight: Sample weights (optional). |
115 | 127 |
|
116 | 128 | Attributes
|
117 | 129 | ----------
|
118 | 130 | estimator_ : estimator
|
119 | 131 | The base estimator (SEFR) from which the ensemble is grown.
|
120 | 132 |
|
121 |
| - .. versionadded:: sklearn 1.2 |
| 133 | + .. versionadded:: scikit-learn 1.2 |
122 | 134 | `base_estimator_` was renamed to `estimator_`.
|
123 | 135 |
|
124 | 136 | base_estimator_ : estimator
|
125 | 137 | The base estimator from which the ensemble is grown.
|
126 | 138 |
|
127 |
| - .. deprecated:: sklearn 1.2 |
128 |
| - `base_estimator_` is deprecated and will be removed in sklearn 1.4. |
| 139 | + .. deprecated:: scikit-learn 1.2 |
| 140 | + `base_estimator_` is deprecated and will be removed in scikit-learn 1.4. |
129 | 141 | Use `estimator_` instead.
|
130 | 142 |
|
131 | 143 | estimators_ : list of classifiers
|
@@ -176,10 +188,9 @@ class LinearBoostClassifier(AdaBoostClassifier):
|
176 | 188 | _parameter_constraints: dict = {
|
177 | 189 | "n_estimators": [Interval(Integral, 1, None, closed="left")],
|
178 | 190 | "learning_rate": [Interval(Real, 0, None, closed="neither")],
|
179 |
| - "algorithm": [ |
180 |
| - StrOptions({"SAMME", "SAMME.R"}), |
181 |
| - Hidden(StrOptions({"deprecated"})), |
182 |
| - ], |
| 191 | + "algorithm": [StrOptions({"SAMME"}), Hidden(StrOptions({"deprecated"}))] |
| 192 | + if SKLEARN_V1_6_OR_LATER |
| 193 | + else [StrOptions({"SAMME", "SAMME.R"})], |
183 | 194 | "scaler": [StrOptions({s for s in _scalers})],
|
184 | 195 | "class_weight": [
|
185 | 196 | StrOptions({"balanced_subsample", "balanced"}),
|
@@ -257,7 +268,7 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]:
|
257 | 268 |
|
258 | 269 | return X, y
|
259 | 270 |
|
260 |
| - def fit(self, X, y, sample_weight=None) -> "LinearBoostClassifier": |
| 271 | + def fit(self, X, y, sample_weight=None) -> Self: |
261 | 272 | X, y = self._check_X_y(X, y)
|
262 | 273 | self.classes_ = np.unique(y)
|
263 | 274 | self.n_classes_ = self.classes_.shape[0]
|
@@ -291,7 +302,14 @@ def fit(self, X, y, sample_weight=None) -> "LinearBoostClassifier":
|
291 | 302 | else:
|
292 | 303 | sample_weight = expanded_class_weight
|
293 | 304 |
|
294 |
| - return super().fit(X_transformed, y, sample_weight) |
| 305 | + with warnings.catch_warnings(): |
| 306 | + if SKLEARN_V1_6_OR_LATER: |
| 307 | + warnings.filterwarnings( |
| 308 | + "ignore", |
| 309 | + category=FutureWarning, |
| 310 | + message=".*parameter 'algorithm' is deprecated.*", |
| 311 | + ) |
| 312 | + return super().fit(X_transformed, y, sample_weight) |
295 | 313 |
|
296 | 314 | def _boost(self, iboost, X, y, sample_weight, random_state):
|
297 | 315 | estimator = self._make_estimator(random_state=random_state)
|
|
0 commit comments