Skip to content

Commit 9b187fc

Browse files
author
Max
committed
Fixed problem where ensemble could not be built
1 parent 9a1b04d commit 9b187fc

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

lkauto/ensemble/greedy_ensemble_selection.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@ def __init__(self, ensemble_size: int, lenskit_metric,
2929

3030
self.ensemble_size = ensemble_size
3131

32+
self.lenskit_metric = lenskit_metric()
33+
3234
if maximize_metric:
3335
def minimized_metric(y_ture, y_pred):
34-
return -lenskit_metric.measure_list(y_pred, y_ture)
36+
return -self.lenskit_metric.measure_list(y_pred, y_ture)
3537
else:
3638
def minimized_metric(y_ture, y_pred):
37-
return lenskit_metric.measure_list(y_pred, y_ture)
39+
return self.lenskit_metric.measure_list(y_pred, y_ture)
3840

3941
self.metric = minimized_metric
4042

@@ -131,15 +133,12 @@ def _fast(self, predictions: List[np.ndarray], labels: np.ndarray) -> None:
131133
labels_df.insert(0, "item_id", labels_df.index)
132134

133135
fant_ensemble_prediction_df = pd.DataFrame(fant_ensemble_prediction)
134-
fant_ensemble_prediction_df.columns = ["rating"]
136+
fant_ensemble_prediction_df.columns = ["score"]
135137
fant_ensemble_prediction_df.insert(0, "item_id", fant_ensemble_prediction_df.index)
136138

137139
labels_il = ItemList.from_df(labels_df)
138140
fant_ensemble_prediction_il = ItemList.from_df(fant_ensemble_prediction_df)
139141

140-
print("!!! labels: \n", labels_df)
141-
print("!!! fant_ensemble_prediction: \n", fant_ensemble_prediction_df)
142-
143142
losses[j] = self.metric(labels_il, fant_ensemble_prediction_il)
144143

145144
all_best = np.argwhere(losses == np.nanmin(losses)).flatten()

0 commit comments

Comments
 (0)