Skip to content

Commit a11bb92

Browse files
author
Max
committed
Fixed imports in lkauto.py, changed one function in get_default_configuration_space.py to use Datasets
1 parent e943137 commit a11bb92

File tree

3 files changed

+22
-11
lines changed

3 files changed

+22
-11
lines changed

lkauto/lkauto.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
from ConfigSpace import ConfigurationSpace
55

6-
from utils.get_model_from_cs import get_model_from_cs
7-
from optimization_strategies.bayesian_optimization import bayesian_optimization
8-
from optimization_strategies.random_search import random_search
9-
from utils.filer import Filer
10-
from ensemble.ensemble_builder import build_ensemble
11-
from preprocessing.preprocessing import preprocess_data
12-
from utils.logging import get_logger
6+
from lkauto.utils.get_model_from_cs import get_model_from_cs
7+
from lkauto.optimization_strategies.bayesian_optimization import bayesian_optimization
8+
from lkauto.optimization_strategies.random_search import random_search
9+
from lkauto.utils.filer import Filer
10+
from lkauto.ensemble.ensemble_builder import build_ensemble
11+
from lkauto.preprocessing.preprocessing import preprocess_data
12+
from lkauto.utils.logging import get_logger
1313

1414
from lenskit.metrics import RMSE
1515
from lenskit.metrics import NDCG

lkauto/utils/validation_split.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from typing import Iterator
2+
13
import pandas as pd
24
import numpy as np
35
# from lenskit.crossfold import partition_rows
4-
from lenskit.splitting import crossfold_records
5-
from lenskit.data import from_interactions_df
6+
from lenskit.splitting import crossfold_records, crossfold_users, SampleFrac, TTSplit
7+
from lenskit.data import from_interactions_df, Dataset
68

79

810
def validation_split(data: pd.DataFrame, strategie: str = 'user_based', num_folds: int = 1,
@@ -99,7 +101,8 @@ def row_based_validation_split(data: pd.DataFrame, num_folds: int = 1, frac: flo
99101
return fold_indices
100102

101103

102-
def user_based_validation_split(data: pd.DataFrame, num_folds: int = 1, frac: float = 0.25, random_state=42) -> dict:
104+
def user_based_validation_split(data: Dataset, num_folds: int = 1, frac: float = 0.25, random_state=42) -> Iterator[
105+
TTSplit]:
103106
"""
104107
Returns a dictionary with the indices of the train and validation split for the given data.
105108
The dictionary has the following structure:
@@ -131,6 +134,8 @@ def user_based_validation_split(data: pd.DataFrame, num_folds: int = 1, frac: fl
131134
dict
132135
dictionary with the indices of the train and validation split for the given data.
133136
"""
137+
138+
"""
134139
# initialize a dictionary with the indices of the train and validation split for the given data
135140
fold_indices = {i: {"train": np.array([]), "validation": np.array([])} for i in
136141
range(num_folds)}
@@ -150,6 +155,12 @@ def user_based_validation_split(data: pd.DataFrame, num_folds: int = 1, frac: fl
150155
num_folds=num_folds)
151156
152157
return fold_indices
158+
"""
159+
160+
splits = crossfold_users(data=data, partitions=num_folds, method=SampleFrac(0.25))
161+
162+
return splits
163+
153164

154165

155166
def __holdout_validation_split(fold_indices: dict, data: pd.DataFrame, frac: float, random_state=42):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"swig",
1717
"smac==2.3.1",
1818
"matplotlib~=3.6",
19-
"lenskit==2025.2.0",
19+
"lenskit==2025.1.1",
2020
"numpy>=2.0.0",
2121
"tables~=3.8",
2222
"typing~=3.5"

0 commit comments

Comments
 (0)