Skip to content

Commit e1475a6

Browse files
committed
updated preprocessing.py
1 parent 7d62904 commit e1475a6

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

lkauto/preprocessing/preprocessing.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
git import logging
1+
import logging
22

33
import pandas as pd
44
from lkauto.preprocessing.pruning import min_ratings_per_user, max_ratings_per_user
5+
from lenskit.data import Dataset
6+
from lenskit.data import from_interactions_df
57

6-
7-
def preprocess_data(data: pd.DataFrame,
8+
def preprocess_data(data: Dataset, #data: pd.DataFrame
89
user_col: str,
910
item_col: str,
1011
rating_col: str = None,
@@ -13,7 +14,7 @@ def preprocess_data(data: pd.DataFrame,
1314
drop_na_values: bool = True,
1415
drop_duplicates: bool = True,
1516
min_interactions_per_user: int = None,
16-
max_interactions_per_user: int = None) -> pd.DataFrame:
17+
max_interactions_per_user: int = None) -> Dataset:
1718
"""Preprocess data for LensKit
1819
This method can perform the following steps based on the user input:
1920
1. rename columns to "user", "item", "rating", "timestamp"
@@ -54,6 +55,10 @@ def preprocess_data(data: pd.DataFrame,
5455
logger = logging.getLogger('lenskit-auto')
5556
logger.info('--Start Preprocessing--')
5657

58+
data = data.interaction_table(format='pandas')
59+
original_cols = data.columns.tolist()
60+
# print(original_cols)
61+
5762
# rename columns
5863
if include_timestamp:
5964
if rating_col is None:
@@ -92,4 +97,4 @@ def preprocess_data(data: pd.DataFrame,
9297

9398
logger.info('--End Preprocessing--')
9499

95-
return data
100+
return from_interactions_df(data)

0 commit comments

Comments
 (0)