Skip to content

Commit cbd6a44

Browse files
author
Max
committed
Now saves the trained recommender model
1 parent 9b187fc commit cbd6a44

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

lkauto/ensemble/ensemble_builder.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import pandas as pd
22
import numpy as np
3+
from lenskit import Pipeline
34
from lenskit.data import Dataset
5+
from typing import List
6+
7+
from lenskit.pipeline import predict_pipeline
48

59
from lkauto.utils.filer import Filer
610
from lkauto.utils.get_model_from_cs import get_model_from_cs
@@ -40,7 +44,8 @@ def build_ensemble(train: Dataset,
4044

4145
es = EnsembleSelection(ensemble_size=ensemble_size, lenskit_metric=lenskit_metric, maximize_metric=maximize_metric)
4246
es.ensemble_fit(ensemble_x, ensemble_y)
43-
es.base_models = [get_model_from_cs(cs, feedback='explicit') for cs, weight in zip(bm_cs_list, es.weights_) if weight > 0]
47+
es.base_models_tmp = [get_model_from_cs(cs, feedback='explicit') for cs, weight in zip(bm_cs_list, es.weights_) if weight > 0]
48+
es.base_models = models_to_pipelines(base_models=es.base_models_tmp)
4449
es.old_to_new_idx = {old_i: new_i for new_i, old_i in enumerate([idx for idx, weight in enumerate(es.weights_) if weight > 0])}
4550

4651
incumbent = {"model": str(es),
@@ -51,3 +56,12 @@ def build_ensemble(train: Dataset,
5156
"weights": list(es.weights_)}
5257

5358
return es, incumbent
59+
60+
def models_to_pipelines(base_models) -> List[Pipeline]:
61+
pipelines = []
62+
63+
for model in base_models:
64+
pipeline = predict_pipeline(model)
65+
pipelines.append(pipeline)
66+
67+
return pipelines

lkauto/lkauto.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,12 @@ def get_best_prediction_model(train: Dataset,
219219
logger.info('GES Ensemble Model')
220220
else:
221221
# build model from best model configuration found by SMAC
222-
model = get_model_from_cs(incumbent, feedback='explicit')
222+
# model = get_model_from_cs(incumbent, feedback='explicit')
223223
incumbent = incumbent.get_dictionary()
224224
logger.info('--Best Model--')
225225
logger.info(incumbent)
226226

227-
if save:
227+
if save and ensemble_size == 1:
228228
filer.save_model(model)
229229
filer.save_incumbent(incumbent)
230230
logger.info('Saved model and incumbent to ' + filer.output_directory_path)

0 commit comments

Comments
 (0)