Skip to content

Commit 809d9d9

Browse files
author
Max
committed
Updated ensemble_builder.py and greedy_ensemble_selection.py
1 parent 8ffdc5d commit 809d9d9

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

lkauto/ensemble/ensemble_builder.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
import pandas as pd
22
import numpy as np
3+
from lenskit.data import Dataset
4+
35
from lkauto.utils.filer import Filer
46
from lkauto.utils.get_model_from_cs import get_model_from_cs
57

68
from lkauto.ensemble.greedy_ensemble_selection import EnsembleSelection
79

810

9-
def build_ensemble(train: pd.DataFrame,
11+
def build_ensemble(train: Dataset,
1012
top_n_runs: pd.DataFrame,
1113
filer: Filer,
1214
ensemble_size: int,
1315
lenskit_metric,
1416
maximize_metric: bool):
1517
config_ids = top_n_runs.sort_values(by='error', ascending=True)['run_id']
16-
ensemble_y = train['rating']
17-
ensemble_X = []
18+
ensemble_y = train.interaction_table(format="pandas")['rating']
19+
ensemble_x = []
1820
val_indices = None
1921
bm_cs_list = []
2022

@@ -32,12 +34,12 @@ def build_ensemble(train: pd.DataFrame,
3234
val_indices = bm_pred[list(bm_pred)[0]]
3335

3436
# Append predictions to ensemble train X
35-
ensemble_X.append(np.array(bm_pred[list(bm_pred)[1]]))
37+
ensemble_x.append(np.array(bm_pred[list(bm_pred)[1]]))
3638

3739
ensemble_y = np.array(ensemble_y.loc[val_indices])
3840

3941
es = EnsembleSelection(ensemble_size=ensemble_size, lenskit_metric=lenskit_metric, maximize_metric=maximize_metric)
40-
es.ensemble_fit(ensemble_X, ensemble_y)
42+
es.ensemble_fit(ensemble_x, ensemble_y)
4143
es.base_models = [get_model_from_cs(cs, feedback='explicit') for cs, weight in zip(bm_cs_list, es.weights_) if weight > 0]
4244
es.old_to_new_idx = {old_i: new_i for new_i, old_i in enumerate([idx for idx, weight in enumerate(es.weights_) if weight > 0])}
4345

lkauto/ensemble/greedy_ensemble_selection.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
import pandas as pd
9+
from lenskit.data import Dataset
910

1011

1112
class EnsembleSelection:
@@ -39,13 +40,13 @@ def minimized_metric(y_ture, y_pred):
3940
# Will be filled later from external
4041
self.base_models = None
4142

42-
def fit(self, data: pd.DataFrame):
43+
def fit(self, data: Dataset):
4344
""" Fit base models (we assume the ensemble part, ensemble_fit, was already fitted here or is fitted later)
4445
4546
Parameters
4647
----------
47-
data: DataFrame
48-
Dataframe with columns "user", "item", "rating"
48+
data: Dataset
49+
Dataset with columns "user", "item", "rating"
4950
"""
5051
if self.base_models is None:
5152
raise ValueError("Base Models is None; we need a list of base models to fit them here!")
@@ -55,11 +56,11 @@ def fit(self, data: pd.DataFrame):
5556

5657
return self
5758

58-
def predict(self, X_data: pd.DataFrame):
59+
def predict(self, x_data: Dataset):
5960
"""
6061
Dataset with columns "user", "item"
6162
"""
62-
bm_preds = [bm.predict(X_data) for bm in self.base_models]
63+
bm_preds = [bm.predict(x_data) for bm in self.base_models]
6364
test_ind = bm_preds[0].index
6465
ens_predictions = self.ensemble_predict([np.array(bm_pred) for bm_pred in bm_preds])
6566

0 commit comments

Comments
 (0)