Skip to content

Commit 4f82339

Browse files
committed
Apply transformation first
1 parent ca4cf72 commit 4f82339

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

src/linearboost/linear_boost.py

Lines changed: 16 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -271,38 +271,36 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]:
271271
def fit(self, X, y, sample_weight=None) -> Self:
272272
if self.algorithm not in {"SAMME", "SAMME.R"}:
273273
raise ValueError("algorithm must be 'SAMME' or 'SAMME.R'")
274-
X, y = check_X_y(X, y, accept_sparse=True)
274+
275+
if self.scaler not in _scalers:
276+
raise ValueError('Invalid scaler provided; got "%s".' % self.scaler)
277+
278+
if self.scaler == "minmax":
279+
self.scaler_ = clone(_scalers["minmax"])
280+
else:
281+
self.scaler_ = make_pipeline(
282+
clone(_scalers[self.scaler]), clone(_scalers["minmax"])
283+
)
284+
X_transformed = self.scaler_.fit_transform(X)
275285

276286
if sample_weight is not None:
277287
sample_weight = np.asarray(sample_weight)
278-
if sample_weight.shape[0] != X.shape[0]:
288+
if sample_weight.shape[0] != X_transformed.shape[0]:
279289
raise ValueError(
280-
f"sample_weight.shape == {sample_weight.shape} is incompatible with X.shape == {X.shape}"
290+
f"sample_weight.shape == {sample_weight.shape} is incompatible with X.shape == {X_transformed.shape}"
281291
)
282-
# fix here
283292
nonzero_mask = (
284293
sample_weight.sum(axis=1) != 0
285294
if sample_weight.ndim > 1
286295
else sample_weight != 0
287296
)
288-
X = X[nonzero_mask]
297+
X_transformed = X_transformed[nonzero_mask]
289298
y = y[nonzero_mask]
290299
sample_weight = sample_weight[nonzero_mask]
291-
X, y = self._check_X_y(X, y)
300+
X_transformed, y = self._check_X_y(X_transformed, y)
292301
self.classes_ = np.unique(y)
293302
self.n_classes_ = self.classes_.shape[0]
294303

295-
if self.scaler not in _scalers:
296-
raise ValueError('Invalid scaler provided; got "%s".' % self.scaler)
297-
298-
if self.scaler == "minmax":
299-
self.scaler_ = clone(_scalers["minmax"])
300-
else:
301-
self.scaler_ = make_pipeline(
302-
clone(_scalers[self.scaler]), clone(_scalers["minmax"])
303-
)
304-
X_transformed = self.scaler_.fit_transform(X)
305-
306304
if self.class_weight is not None:
307305
valid_presets = ("balanced", "balanced_subsample")
308306
if (
@@ -326,7 +324,7 @@ def fit(self, X, y, sample_weight=None) -> Self:
326324
warnings.filterwarnings(
327325
"ignore",
328326
category=FutureWarning,
329-
message=".*parameter 'algorithm' is deprecated.*",
327+
message=".*parameter 'algorithm' may change in the future.*",
330328
)
331329
return super().fit(X_transformed, y, sample_weight)
332330

0 commit comments

Comments
 (0)