Skip to content

Commit 672f767

Browse files
committed
add Self type annotations, improve sklearn compatibility and docstring
1 parent 5a4b99f commit 672f767

File tree

1 file changed

+32
-14
lines changed

1 file changed

+32
-14
lines changed

src/linearboost/linear_boost.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
from __future__ import annotations
22

3+
import sys
4+
import warnings
35
from numbers import Integral, Real
46

7+
if sys.version_info >= (3, 11):
8+
from typing import Self
9+
else:
10+
from typing_extensions import Self
11+
512
import numpy as np
613
from sklearn.base import clone
714
from sklearn.ensemble import AdaBoostClassifier
@@ -65,13 +72,18 @@ class LinearBoostClassifier(AdaBoostClassifier):
6572
6673
algorithm : {'SAMME', 'SAMME.R'}, default='SAMME'
6774
If 'SAMME' then use the SAMME discrete boosting algorithm.
68-
If 'SAMME.R' then use the SAMME.R real boosting algorithm.
75+
If 'SAMME.R' then use the SAMME.R real boosting algorithm
76+
(only available in scikit-learn < 1.6).
6977
The SAMME.R algorithm typically converges faster than SAMME,
7078
achieving a lower test error with fewer boosting iterations.
7179
72-
.. deprecated:: sklearn 1.6
73-
`algorithm` is deprecated and will be removed in sklearn 1.8. This
74-
estimator only implements the 'SAMME' algorithm.
80+
.. deprecated:: scikit-learn 1.4
81+
`"SAMME.R"` is deprecated and will be removed in scikit-learn 1.6.
82+
'"SAMME"' will become the default.
83+
84+
.. deprecated:: scikit-learn 1.6
85+
`algorithm` is deprecated and will be removed in scikit-learn 1.8.
86+
This estimator only implements the 'SAMME' algorithm in scikit-learn >= 1.6.
7587
7688
scaler : str, default='minmax'
7789
Specifies the scaler to apply to the data. Options include:
@@ -111,21 +123,21 @@ class LinearBoostClassifier(AdaBoostClassifier):
111123
where:
112124
- y_true: Ground truth (correct) target values.
113125
- y_pred: Estimated target values.
114-
- sample_weight: Sample weights.
126+
- sample_weight: Sample weights (optional).
115127
116128
Attributes
117129
----------
118130
estimator_ : estimator
119131
The base estimator (SEFR) from which the ensemble is grown.
120132
121-
.. versionadded:: sklearn 1.2
133+
.. versionadded:: scikit-learn 1.2
122134
`base_estimator_` was renamed to `estimator_`.
123135
124136
base_estimator_ : estimator
125137
The base estimator from which the ensemble is grown.
126138
127-
.. deprecated:: sklearn 1.2
128-
`base_estimator_` is deprecated and will be removed in sklearn 1.4.
139+
.. deprecated:: scikit-learn 1.2
140+
`base_estimator_` is deprecated and will be removed in scikit-learn 1.4.
129141
Use `estimator_` instead.
130142
131143
estimators_ : list of classifiers
@@ -176,10 +188,9 @@ class LinearBoostClassifier(AdaBoostClassifier):
176188
_parameter_constraints: dict = {
177189
"n_estimators": [Interval(Integral, 1, None, closed="left")],
178190
"learning_rate": [Interval(Real, 0, None, closed="neither")],
179-
"algorithm": [
180-
StrOptions({"SAMME", "SAMME.R"}),
181-
Hidden(StrOptions({"deprecated"})),
182-
],
191+
"algorithm": [StrOptions({"SAMME"}), Hidden(StrOptions({"deprecated"}))]
192+
if SKLEARN_V1_6_OR_LATER
193+
else [StrOptions({"SAMME", "SAMME.R"})],
183194
"scaler": [StrOptions({s for s in _scalers})],
184195
"class_weight": [
185196
StrOptions({"balanced_subsample", "balanced"}),
@@ -257,7 +268,7 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]:
257268

258269
return X, y
259270

260-
def fit(self, X, y, sample_weight=None) -> "LinearBoostClassifier":
271+
def fit(self, X, y, sample_weight=None) -> Self:
261272
X, y = self._check_X_y(X, y)
262273
self.classes_ = np.unique(y)
263274
self.n_classes_ = self.classes_.shape[0]
@@ -291,7 +302,14 @@ def fit(self, X, y, sample_weight=None) -> "LinearBoostClassifier":
291302
else:
292303
sample_weight = expanded_class_weight
293304

294-
return super().fit(X_transformed, y, sample_weight)
305+
with warnings.catch_warnings():
306+
if SKLEARN_V1_6_OR_LATER:
307+
warnings.filterwarnings(
308+
"ignore",
309+
category=FutureWarning,
310+
message=".*parameter 'algorithm' is deprecated.*",
311+
)
312+
return super().fit(X_transformed, y, sample_weight)
295313

296314
def _boost(self, iboost, X, y, sample_weight, random_state):
297315
estimator = self._make_estimator(random_state=random_state)

0 commit comments

Comments
 (0)