yangfa-zhang/lunax

Lunax is a machine learning framework designed for processing and analyzing tabular data. The name comes from Luna 🐱, a beloved cat mascot at South China University of Technology. See the API documentation for details. ⭐️ Star it if you like it ⭐️


Installation

conda create -n your_env_name python=3.11
conda activate your_env_name
pip install lunax

Features

  • Data loading and pre-processing
  • Exploratory data analysis (EDA)
  • Multi-model training and hyperparameter tuning
  • Comprehensive model evaluation
  • Ensemble learning
  • Explainable AI (XAI)
  • Object-oriented design with unified interfaces for easy extension
  • Comprehensive unit testing with pytest for code quality control

Quick Start

Data Loading and Pre-processing

from lunax.data_processing import load_data, preprocess_data, split_data
df_train = load_data('train.csv')  # Parquet also works: load_data('train.parquet')
target = 'label_column_name'
# Pre-processing covers missing-value handling, feature encoding, and feature scaling.
df_train = preprocess_data(df_train, target)
X_train, X_val, y_train, y_val = split_data(df_train, target)
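
For readers who want a picture of what the three pre-processing steps named above usually involve, here is a minimal pure-Python sketch. It is not lunax's implementation: the helper names (`impute_mean`, `label_encode`, `standardize`) and the choice of mean imputation, integer label encoding, and z-score scaling are assumptions made for illustration.

```python
from statistics import mean, pstdev

def impute_mean(values):
    """Replace None entries with the mean of the observed values."""
    observed = [v for v in values if v is not None]
    fill = mean(observed)
    return [fill if v is None else v for v in values]

def label_encode(values):
    """Map each distinct category to an integer, in sorted order."""
    mapping = {cat: i for i, cat in enumerate(sorted(set(values)))}
    return [mapping[v] for v in values], mapping

def standardize(values):
    """Scale to zero mean and unit variance (population standard deviation)."""
    mu, sigma = mean(values), pstdev(values)
    if sigma == 0:
        return [0.0 for _ in values]  # constant column: nothing to scale
    return [(v - mu) / sigma for v in values]

# Hypothetical toy columns, not taken from any real dataset.
humidity = impute_mean([80.0, None, 60.0, 70.0])
codes, mapping = label_encode(["rain", "sun", "rain", "cloud"])
scaled = standardize(humidity)
```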

Exploratory Data Analysis

from lunax.viz import numeric_eda, categoric_eda
# df_test must already be loaded, e.g. df_test = load_data('test.csv')
numeric_eda([df_train, df_test], ['train', 'test'], target=target)    # numeric feature analysis
categoric_eda([df_train, df_test], ['train', 'test'], target=target)  # categorical feature analysis

Automated Machine Learning Modeling

from lunax.models import xgb_clf  # or xgb_reg, lgbm_reg, lgbm_clf, cat_clf, cat_reg
from lunax.hyper_opt import OptunaTuner
# Hyperparameter optimizer; n_trials is the number of optimization trials.
# model_class can also be "XGBRegressor", "LGBMRegressor", "LGBMClassifier", "CatClassifier", or "CatRegressor".
tuner = OptunaTuner(n_trials=10, model_class="XGBClassifier")
results = tuner.optimize(X_train, y_train, X_val, y_val)
best_params = results['best_params']
model = xgb_clf(best_params)
model.fit(X_train, y_train)
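
The tuner is used above as a black box that runs `n_trials` trials and returns a dictionary containing `best_params`. As a rough mental model only, the sketch below runs a plain random search that returns a result of the same shape. It uses neither Optuna nor lunax, and random search is a simpler technique than the samplers Optuna provides; `random_search`, `objective`, the `best_score` key, and the search space are all hypothetical.

```python
import random

def random_search(objective, space, n_trials=10, seed=0):
    """Sample n_trials parameter sets uniformly and keep the best-scoring one."""
    rng = random.Random(seed)
    best_params, best_score = None, float("-inf")
    for _ in range(n_trials):
        params = {name: rng.uniform(low, high) for name, (low, high) in space.items()}
        score = objective(params)
        if score > best_score:
            best_params, best_score = params, score
    return {"best_params": best_params, "best_score": best_score}

def objective(params):
    """Toy stand-in for a validation score; it peaks at learning_rate = 0.1."""
    return -(params["learning_rate"] - 0.1) ** 2

results = random_search(objective, {"learning_rate": (0.01, 0.3)}, n_trials=10)
best_params = results["best_params"]
```

In a real run the objective would train a model with `params` and return its score on the validation split.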

Model Evaluation

model.evaluate(X_val, y_val)
[lunax]> label information:
+---------+---------+
|   label |   count |
+=========+=========+
|       1 |     319 |
+---------+---------+
|       0 |     119 |
+---------+---------+
[lunax]> model evaluation results:
+-----------+------------+-------------+----------+------+
| metrics   |   accuracy |   precision |   recall |   f1 |
+===========+============+=============+==========+======+
| values    |       0.73 |        0.53 |     0.73 | 0.61 |
+-----------+------------+-------------+----------+------+
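
The figures in this sample output can be reproduced by hand. With 319 rows labelled 1 and 119 labelled 0, a classifier that predicts 1 for every row scores 319/438 ≈ 0.73 accuracy, and its support-weighted precision, recall, and F1 come out at 0.53, 0.73, and 0.61. The sketch below checks that arithmetic in plain Python. That the table reports weighted averages, and that this particular model predicted only the majority class, are inferences from the numbers rather than something the source states; `weighted_metrics` is a hypothetical helper.

```python
def weighted_metrics(y_true, y_pred):
    """Accuracy plus support-weighted precision, recall, and F1."""
    n = len(y_true)
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / n
    precision = recall = f1 = 0.0
    for label in set(y_true):
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        predicted = sum(p == label for p in y_pred)
        support = sum(t == label for t in y_true)
        prec = tp / predicted if predicted else 0.0  # undefined precision counted as 0
        rec = tp / support
        f = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        weight = support / n  # each class contributes in proportion to its size
        precision += weight * prec
        recall += weight * rec
        f1 += weight * f
    return accuracy, precision, recall, f1

y_true = [1] * 319 + [0] * 119
y_pred = [1] * 438  # always predict the majority class
print([round(m, 2) for m in weighted_metrics(y_true, y_pred)])
# [0.73, 0.53, 0.73, 0.61]
```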

Ensemble Learning

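from lunax.ensembles import HillClimbingEnsemble
from lunax.models import xgb_clf, lgbm_clf, cat_clf
model1 = xgb_clf()
model2 = lgbm_clf()
model3 = cat_clf()
# Use a separate loop variable so the tuned `model` from earlier is not overwritten.
for base_model in [model1, model2, model3]:
    base_model.fit(X_train, y_train)
ensemble = HillClimbingEnsemble(
    models=[model1, model2, model3],
    metric=['auc'],
    maximize=True
)
best_weights = ensemble.fit(X_val, y_val)
predictions = ensemble.predict(df_test)  # df_test: the loaded, pre-processed test set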
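
Hill-climbing ensembling is commonly implemented as greedy forward selection on validation predictions: start from the best single model, then repeatedly add (with replacement) whichever model most improves the metric of the averaged prediction, and stop when nothing helps. The sketch below shows that idea with AUC in plain Python. It is a generic illustration, not lunax's code; `auc`, `blend`, `hill_climb`, and the toy predictions are hypothetical.

```python
def auc(y_true, scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count half."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def blend(preds, counts):
    """Average the models' predictions, weighting each by how often it was picked."""
    total = sum(counts)
    return [sum(c * p[i] for c, p in zip(counts, preds)) / total
            for i in range(len(preds[0]))]

def hill_climb(preds, y_true, max_steps=20):
    """Greedy selection with replacement; returns normalised weights and the score."""
    singles = [auc(y_true, p) for p in preds]
    best = max(range(len(preds)), key=singles.__getitem__)
    counts = [0] * len(preds)
    counts[best] = 1
    best_score = singles[best]
    for _ in range(max_steps):
        trials = []
        for j in range(len(preds)):
            trial = counts.copy()
            trial[j] += 1
            trials.append((auc(y_true, blend(preds, trial)), j))
        score, j = max(trials)
        if score <= best_score:  # no candidate improves the metric
            break
        counts[j] += 1
        best_score = score
    total = sum(counts)
    return [c / total for c in counts], best_score

y_val = [1, 1, 1, 0, 0, 0]
preds = [
    [0.9, 0.8, 0.3, 0.4, 0.2, 0.1],  # model A: misranks one pair
    [0.6, 0.4, 0.9, 0.2, 0.7, 0.3],  # model B: misranks different pairs
    [0.5, 0.5, 0.5, 0.5, 0.5, 0.5],  # model C: uninformative
]
weights, score = hill_climb(preds, y_val)
print(weights, score)  # [0.5, 0.5, 0.0] 1.0
```

In this toy case models A and B make different mistakes, so their average ranks every pair correctly and the uninformative model C receives zero weight.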

Explainable AI (XAI)

from lunax.xai import TreeExplainer
explainer = TreeExplainer(model)
explainer.plot_summary(X_val)
importance = explainer.get_feature_importance(X_val)
[lunax]> Clear blue/red separation indicates a highly influential feature.

[lunax]> Feature Importance Ranking:
+----+---------------+---------------------+
|    |    Feature    |     Importance      |
+----+---------------+---------------------+
| 1  |     cloud     | 2.3085615634918213  |
| 2  |   sunshine    | 0.6377484202384949  |
| 3  |   dewpoint    | 0.5257667899131775  |
| 4  |   humidity    | 0.4827548861503601  |
| 5  |   windspeed   | 0.40086665749549866 |
| 6  |      id       | 0.38620123267173767 |
| 7  |   pressure    | 0.3780971169471741  |
| 8  |    mintemp    | 0.32988569140434265 |
| 9  |      day      | 0.30587586760520935 |
| 10 |    maxtemp    | 0.26082852482795715 |
| 11 | winddirection | 0.23236176371574402 |
| 12 |  temparature  | 0.17218443751335144 |
+----+---------------+---------------------+
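
A ranking like the one above is often obtained by averaging the absolute per-row attribution values (for example, SHAP values) of each feature. The source does not say how lunax computes its importance column, so treat the following as a generic sketch; `rank_features` and the toy attribution matrix are made up for illustration.

```python
def rank_features(attributions, feature_names):
    """Rank features by mean absolute attribution across rows, largest first."""
    n_rows = len(attributions)
    importance = {
        name: sum(abs(row[j]) for row in attributions) / n_rows
        for j, name in enumerate(feature_names)
    }
    return sorted(importance.items(), key=lambda item: item[1], reverse=True)

# Toy per-row attributions (rows = samples, columns = features); not real output.
attributions = [
    [2.0, -0.5, 0.1],
    [-3.0, 0.5, 0.3],
]
ranking = rank_features(attributions, ["cloud", "sunshine", "dewpoint"])
```

Taking absolute values before averaging keeps positive and negative contributions from cancelling out.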

Test Set Prediction

df_test = load_data('test.csv')
df_test = preprocess_data(df_test, target)  # pre-process the test set itself, not df_train
y_pred = model.predict(df_test)
# y_pred_proba = model.predict_proba(df_test)
