diff --git a/src/linearboost/linear_boost.py b/src/linearboost/linear_boost.py index d15c151..7ffe091 100644 --- a/src/linearboost/linear_boost.py +++ b/src/linearboost/linear_boost.py @@ -6,8 +6,8 @@ # See https://github.com/scikit-learn/scikit-learn/blob/main/COPYING for details. # # Additional code and modifications: -# - Hamidreza Keshavarz (hamid9@outlook.com) — machine learning logic, design, and new algorithms -# - Mehdi Samsami (mehdisamsami@live.com) — software refactoring, compatibility with scikit-learn framework, and packaging +# - Hamidreza Keshavarz (hamid9@outlook.com) — machine learning logic, design, and new algorithms +# - Mehdi Samsami (mehdisamsami@live.com) — software refactoring, compatibility with scikit-learn framework, and packaging # # The combined work is licensed under the MIT License. @@ -299,6 +299,20 @@ class LinearBoostClassifier(_DenseAdaBoostClassifier): - 'maxabs': Uses MaxAbsScaler. - 'robust': Applies RobustScaler. + kernel : {'linear', 'poly', 'rbf', 'sigmoid'} or callable, default='linear' + Specifies the kernel type to be used in the algorithm. + If a callable is given, it is used to pre-compute the kernel matrix. + + gamma : float, default=None + Kernel coefficient for 'rbf', 'poly' and 'sigmoid'. If None, then it is + set to 1.0 / n_features. + + degree : int, default=3 + Degree for 'poly' kernels. Ignored by other kernels. + + coef0 : float, default=1 + Independent term in kernel function. It is only significant in 'poly' and 'sigmoid'. + class_weight : {"balanced"}, dict or list of dicts, default=None Weights associated with classes in the form ``{class_label: weight}``. If not given, all classes are supposed to have weight one. @@ -385,6 +399,10 @@ class LinearBoostClassifier(_DenseAdaBoostClassifier): "learning_rate": [Interval(Real, 0, None, closed="neither")], "algorithm": [StrOptions({"SAMME", "SAMME.R"})], "scaler": [StrOptions({s for s in _scalers})], + "kernel": [StrOptions({"linear", "poly", "rbf", "sigmoid"}), callable], + "gamma": [Interval(Real, 0, None, closed="left"), None], + "degree": [Interval(Integral, 1, None, closed="left"), None], + "coef0": [Real, None], "class_weight": [ StrOptions({"balanced"}), dict, @@ -403,14 +421,24 @@ def __init__( scaler="minmax", class_weight=None, loss_function=None, + kernel="linear", + gamma=None, + degree=3, + coef0=1, ): super().__init__( - estimator=SEFR(), n_estimators=n_estimators, learning_rate=learning_rate + estimator=SEFR(kernel=kernel, gamma=gamma, degree=degree, coef0=coef0), + n_estimators=n_estimators, + learning_rate=learning_rate, ) self.algorithm = algorithm self.scaler = scaler self.class_weight = class_weight self.loss_function = loss_function + self.kernel = kernel + self.gamma = gamma + self.degree = degree + self.coef0 = coef0 if SKLEARN_V1_6_OR_LATER: diff --git a/src/linearboost/sefr.py b/src/linearboost/sefr.py index 60def0b..3d08dbd 100644 --- a/src/linearboost/sefr.py +++ b/src/linearboost/sefr.py @@ -8,11 +8,14 @@ from typing_extensions import Self import numpy as np +from numbers import Integral, Real from sklearn.base import BaseEstimator from sklearn.linear_model._base import LinearClassifierMixin +from sklearn.metrics.pairwise import pairwise_kernels from sklearn.utils.extmath import safe_sparse_dot from sklearn.utils.multiclass import check_classification_targets, type_of_target -from sklearn.utils.validation import _check_sample_weight +from sklearn.utils.validation import _check_sample_weight, check_is_fitted +from sklearn.utils._param_validation import Interval, StrOptions from ._utils import ( SKLEARN_V1_6_OR_LATER, @@ -38,28 +41,46 @@ class SEFR(LinearClassifierMixin, BaseEstimator): Parameters ---------- fit_intercept : bool, default=True - Specifies if a constant (a.k.a. bias or intercept) should be - added to the decision function. + Specifies if a constant (a.k.a. bias or intercept) should be + added to the decision function. + + kernel : {'linear', 'poly', 'rbf', 'sigmoid'} or callable, default='linear' + Specifies the kernel type to be used in the algorithm. + If a callable is given, it is used to pre-compute the kernel matrix. + + gamma : float, default=None + Kernel coefficient for 'rbf', 'poly' and 'sigmoid'. If None, then it is + set to 1.0 / n_features. + + degree : int, default=3 + Degree for 'poly' kernels. Ignored by other kernels. + + coef0 : float, default=1 + Independent term in kernel function. It is only significant in 'poly' and 'sigmoid'. Attributes ---------- classes_ : ndarray of shape (n_classes, ) - A list of class labels known to the classifier. + A list of class labels known to the classifier. - coef_ : ndarray of shape (1, n_features) - Coefficient of the features in the decision function. + coef_ : ndarray of shape (1, n_features) or (1, n_samples) + Coefficient of the features in the decision function. When a kernel is used, + the shape will be (1, n_samples). intercept_ : ndarray of shape (1,) - Intercept (a.k.a. bias) added to the decision function. + Intercept (a.k.a. bias) added to the decision function. - If `fit_intercept` is set to False, the intercept is set to zero. + If `fit_intercept` is set to False, the intercept is set to zero. n_features_in_ : int - Number of features seen during :term:`fit`. + Number of features seen during :term:`fit`. feature_names_in_ : ndarray of shape (`n_features_in_`,) - Names of features seen during :term:`fit`. Defined only when `X` - has feature names that are all strings. + Names of features seen during :term:`fit`. Defined only when `X` + has feature names that are all strings. + + X_fit_ : ndarray of shape (n_samples, n_features) + The training data, stored when a kernel is used. Notes ----- @@ -70,22 +91,35 @@ class SEFR(LinearClassifierMixin, BaseEstimator): >>> from linearboost import SEFR >>> from sklearn.datasets import load_breast_cancer >>> X, y = load_breast_cancer(return_X_y=True) - >>> clf = SEFR().fit(X, y) + >>> clf = SEFR(kernel='rbf').fit(X, y) >>> clf.predict(X[:2, :]) array([0, 0]) - >>> clf.predict_proba(X[:2, :]) - array([[1.00...e+000, 2.04...e-154], - [1.00...e+000, 1.63...e-165]]) >>> clf.score(X, y) - 0.86... + 0.89... """ _parameter_constraints: dict = { "fit_intercept": ["boolean"], + "kernel": [StrOptions({"linear", "poly", "rbf", "sigmoid"}), callable], + "gamma": [Interval(Real, 0, None, closed="left"), None], + "degree": [Interval(Integral, 1, None, closed="left"), None], + "coef0": [Real, None], } - def __init__(self, *, fit_intercept=True): + def __init__( + self, + *, + fit_intercept=True, + kernel="linear", + gamma=None, + degree=3, + coef0=1, + ): self.fit_intercept = fit_intercept + self.kernel = kernel + self.gamma = gamma + self.degree = degree + self.coef0 = coef0 if SKLEARN_V1_6_OR_LATER: @@ -145,6 +179,23 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]: return X, y + def _get_kernel_matrix(self, X, Y=None): + if Y is None: + Y = self.X_fit_ + + if callable(self.kernel): + return self.kernel(X, Y) + else: + return pairwise_kernels( + X, + Y, + metric=self.kernel, + filter_params=True, + gamma=self.gamma, + degree=self.degree, + coef0=self.coef0, + ) + @_fit_context(prefer_skip_nested_validation=True) def fit(self, X, y, sample_weight=None) -> Self: """ @@ -153,27 +204,33 @@ def fit(self, X, y, sample_weight=None) -> Self: Parameters ---------- X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training vector, where `n_samples` is the number of samples and - `n_features` is the number of features. + Training vector, where `n_samples` is the number of samples and + `n_features` is the number of features. y : array-like of shape (n_samples,) - Target vector relative to X. + Target vector relative to X. sample_weight : array-like of shape (n_samples,) default=None - Array of weights that are assigned to individual samples. - If not provided, then each sample is given unit weight. + Array of weights that are assigned to individual samples. + If not provided, then each sample is given unit weight. Returns ------- self - Fitted estimator. + Fitted estimator. """ _check_n_features(self, X=X, reset=True) _check_feature_names(self, X=X, reset=True) X, y = self._check_X_y(X, y) + self.X_fit_ = X self.classes_, y_ = np.unique(y, return_inverse=True) + if self.kernel == "linear": + K = X + else: + K = self._get_kernel_matrix(X) + pos_labels = y_ == 1 neg_labels = y_ == 0 @@ -193,13 +250,13 @@ def fit(self, X, y, sample_weight=None) -> Self: if np.all(pos_sample_weight == 0) or np.all(neg_sample_weight == 0): raise ValueError("SEFR requires 2 classes; got only 1 class.") - avg_pos = np.average(X[pos_labels, :], axis=0, weights=pos_sample_weight) - avg_neg = np.average(X[neg_labels, :], axis=0, weights=neg_sample_weight) + avg_pos = np.average(K[pos_labels, :], axis=0, weights=pos_sample_weight) + avg_neg = np.average(K[neg_labels, :], axis=0, weights=neg_sample_weight) self.coef_ = (avg_pos - avg_neg) / (avg_pos + avg_neg + 1e-7) self.coef_ = np.reshape(self.coef_, (1, -1)) if self.fit_intercept: - scores = safe_sparse_dot(X, self.coef_.T, dense_output=True) + scores = safe_sparse_dot(K, self.coef_.T, dense_output=True) pos_score_avg = np.average( scores[pos_labels][:, 0], weights=pos_sample_weight ) @@ -217,6 +274,17 @@ def fit(self, X, y, sample_weight=None) -> Self: return self + def decision_function(self, X): + check_is_fitted(self) + X = self._check_X(X) + if self.kernel == "linear": + K = X + else: + K = self._get_kernel_matrix(X) + return ( + safe_sparse_dot(K, self.coef_.T, dense_output=True) + self.intercept_ + ).ravel() + def predict_proba(self, X): """ Probability estimates. @@ -227,16 +295,22 @@ def predict_proba(self, X): Parameters ---------- X : array-like of shape (n_samples, n_features) - Vector to be scored, where `n_samples` is the number of samples and - `n_features` is the number of features. + Vector to be scored, where `n_samples` is the number of samples and + `n_features` is the number of features. Returns ------- T : array-like of shape (n_samples, n_classes) - Returns the probability of the sample for each class in the model, - where classes are ordered as they are in ``self.classes_``. + Returns the probability of the sample for each class in the model, + where classes are ordered as they are in ``self.classes_``. """ - score = self.decision_function(X) / np.linalg.norm(self.coef_) + check_is_fitted(self) + norm_coef = np.linalg.norm(self.coef_) + if norm_coef == 0: + # Handle the case of a zero-norm coefficient vector to avoid division by zero + score = self.decision_function(X) + else: + score = self.decision_function(X) / norm_coef proba = 1.0 / (1.0 + np.exp(-score)) return np.column_stack((1.0 - proba, proba)) @@ -250,13 +324,13 @@ def predict_log_proba(self, X): Parameters ---------- X : array-like of shape (n_samples, n_features) - Vector to be scored, where `n_samples` is the number of samples and - `n_features` is the number of features. + Vector to be scored, where `n_samples` is the number of samples and + `n_features` is the number of features. Returns ------- T : array-like of shape (n_samples, n_classes) - Returns the log-probability of the sample for each class in the - model, where classes are ordered as they are in ``self.classes_``. + Returns the log-probability of the sample for each class in the + model, where classes are ordered as they are in ``self.classes_``. """ return np.log(self.predict_proba(X))