
Commit b828504

uv ruff applied

1 parent 75add27

2 files changed: +366 -363 lines


src/linearboost/linear_boost.py

Lines changed: 23 additions & 20 deletions
@@ -216,18 +216,17 @@ def __init__(
     ):
         if algorithm not in {"SAMME", "SAMME.R"}:
             raise ValueError("algorithm must be 'SAMME' or 'SAMME.R'")
-
+
         super().__init__(
-            estimator=SEFR(),
-            n_estimators=n_estimators,
-            learning_rate=learning_rate
+            estimator=SEFR(), n_estimators=n_estimators, learning_rate=learning_rate
         )
         self.algorithm = algorithm
         self.scaler = scaler
         self.class_weight = class_weight
         self.loss_function = loss_function
 
     if SKLEARN_V1_6_OR_LATER:
+
         def __sklearn_tags__(self):
             tags = super().__sklearn_tags__()
             tags.input_tags.sparse = False
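
This hunk only collapses the super().__init__ call onto one line; the booster is still built around a SEFR base estimator with the n_estimators, learning_rate, and algorithm arguments shown above. A minimal usage sketch, assuming the class defined in linear_boost.py is exported as LinearBoostClassifier (the import path and class name are assumptions, not taken from this diff):

    from sklearn.datasets import make_classification
    from linearboost.linear_boost import LinearBoostClassifier  # assumed import path

    # Toy binary classification problem (the estimator is tagged binary-only below).
    X, y = make_classification(n_samples=200, n_features=10, random_state=0)

    clf = LinearBoostClassifier(n_estimators=50, learning_rate=1.0, algorithm="SAMME.R")
    clf.fit(X, y)
    print(clf.predict(X[:5]))
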
@@ -236,7 +235,6 @@ def __sklearn_tags__(self):
             tags.classifier_tags.poor_score = True
             return tags
 
-
     def _more_tags(self) -> dict[str, bool]:
         return {
             "binary_only": True,
@@ -306,7 +304,6 @@ def fit(self, X, y, sample_weight=None) -> Self:
                 sample_weight = sample_weight * expanded_class_weight
             else:
                 sample_weight = expanded_class_weight
-
 
         with warnings.catch_warnings():
             if SKLEARN_V1_6_OR_LATER:
@@ -316,7 +313,7 @@ def fit(self, X, y, sample_weight=None) -> Self:
                     message=".*parameter 'algorithm' is deprecated.*",
                 )
             return super().fit(X_transformed, y, sample_weight)
-
+
     def _samme_proba(self, estimator, n_classes, X):
         """Calculate algorithm 4, step 2, equation c) of Zhu et al [1].
 
@@ -345,7 +342,9 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
             y_pred = estimator.predict(X)
 
             incorrect = y_pred != y
-            estimator_error = np.mean(np.average(incorrect, weights=sample_weight, axis=0))
+            estimator_error = np.mean(
+                np.average(incorrect, weights=sample_weight, axis=0)
+            )
 
             if estimator_error <= 0:
                 return sample_weight, 1.0, 0.0
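
The wrapped estimator_error expression above is still the weighted misclassification rate: np.average weights each incorrect prediction by the sample's current boosting weight. A small standalone illustration with made-up numbers (not code from this repository):

    import numpy as np

    # Three of five samples are wrong, but they carry only 20% of the total weight.
    incorrect = np.array([True, False, True, False, True])
    sample_weight = np.array([0.05, 0.40, 0.05, 0.40, 0.10])

    estimator_error = np.mean(np.average(incorrect, weights=sample_weight, axis=0))
    print(estimator_error)  # 0.2
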
@@ -355,20 +354,23 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
                 return None, None, None
 
             # Compute SEFR-specific weight update
-            estimator_weight = self.learning_rate * np.log((1 - estimator_error) / estimator_error)
+            estimator_weight = self.learning_rate * np.log(
+                (1 - estimator_error) / estimator_error
+            )
 
             if iboost < self.n_estimators - 1:
                 sample_weight = np.exp(
-                    np.log(sample_weight) + estimator_weight * incorrect * (sample_weight > 0)
+                    np.log(sample_weight)
+                    + estimator_weight * incorrect * (sample_weight > 0)
                 )
 
             return sample_weight, estimator_weight, estimator_error
-
+
         else:  # standard SAMME
             y_pred = estimator.predict(X)
             incorrect = y_pred != y
             estimator_error = np.mean(np.average(incorrect, weights=sample_weight))
-
+
             if estimator_error <= 0:
                 return sample_weight, 1.0, 0.0
             if estimator_error >= 0.5:
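
The reformatted lines in the SAMME.R branch carry the usual AdaBoost arithmetic: the estimator weight is the learning-rate-scaled log-odds of the weighted error, and misclassified samples with non-zero weight are boosted exponentially for the next round. A toy restatement of just that arithmetic (values invented, not repository code); the standard SAMME branch in the next hunk applies the same log-odds formula with a 1e-10 floor on the error and then renormalizes the weights:

    import numpy as np

    learning_rate = 1.0
    estimator_error = 0.2                      # weighted error from the current round
    incorrect = np.array([True, False, True, False, True])
    sample_weight = np.array([0.05, 0.40, 0.05, 0.40, 0.10])

    # alpha = lr * log((1 - err) / err): more accurate estimators get larger weight.
    estimator_weight = learning_rate * np.log((1 - estimator_error) / estimator_error)

    # exp(log(w) + alpha * incorrect) boosts the weight of misclassified samples.
    sample_weight = np.exp(
        np.log(sample_weight) + estimator_weight * incorrect * (sample_weight > 0)
    )
    print(estimator_weight)   # ~1.386 for err = 0.2
    print(sample_weight)      # misclassified samples now weigh 4x more
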
@@ -378,17 +380,18 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
                     "BaseClassifier in AdaBoostClassifier ensemble is worse than random, ensemble cannot be fit."
                 )
                 return None, None, None
-
-            estimator_weight = (self.learning_rate *
-                np.log((1. - estimator_error) / max(estimator_error, 1e-10)))
-
+
+            estimator_weight = self.learning_rate * np.log(
+                (1.0 - estimator_error) / max(estimator_error, 1e-10)
+            )
+
             sample_weight *= np.exp(estimator_weight * incorrect)
 
             # Normalize sample weights
             sample_weight /= np.sum(sample_weight)
 
             return sample_weight, estimator_weight, estimator_error
-
+
     def decision_function(self, X):
         check_is_fitted(self)
         X_transformed = self.scaler_.transform(X)
@@ -399,7 +402,8 @@ def decision_function(self, X):
             n_classes = len(classes)
 
             pred = sum(
-                self._samme_proba(estimator, n_classes, X_transformed) for estimator in self.estimators_
+                self._samme_proba(estimator, n_classes, X_transformed)
+                for estimator in self.estimators_
             )
             pred /= self.estimator_weights_.sum()
             if n_classes == 2:
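
For SAMME.R, the decision values are built exactly as this hunk shows: the per-estimator _samme_proba contributions are summed over the fitted ensemble and divided by the total estimator weight (the n_classes == 2 branch then collapses the result into a single margin, which lies outside this hunk). A rough sketch of only the aggregation step, with invented arrays standing in for the _samme_proba outputs and self.estimator_weights_:

    import numpy as np

    # Pretend _samme_proba already returned one (n_samples, n_classes) array per estimator.
    per_estimator_scores = [
        np.array([[0.2, -0.2], [-0.5, 0.5]]),
        np.array([[0.1, -0.1], [-0.3, 0.3]]),
    ]
    estimator_weights = np.array([1.4, 0.9])

    pred = sum(per_estimator_scores)   # elementwise sum across the ensemble
    pred /= estimator_weights.sum()    # normalize by the total estimator weight
    print(pred)
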
@@ -410,7 +414,7 @@
         else:
             # Standard SAMME algorithm from AdaBoostClassifier (discrete)
             return super().decision_function(X_transformed)
-
+
     def predict(self, X):
         """Predict classes for X.
 
@@ -434,4 +438,3 @@ def predict(self, X):
             return self.classes_.take(pred > 0, axis=0)
 
         return self.classes_.take(np.argmax(pred, axis=1), axis=0)
-
