1
- git import logging
1
+ import logging
2
2
3
3
import pandas as pd
4
4
from lkauto .preprocessing .pruning import min_ratings_per_user , max_ratings_per_user
5
+ from lenskit .data import Dataset
6
+ from lenskit .data import from_interactions_df
5
7
6
-
7
- def preprocess_data (data : pd .DataFrame ,
8
+ def preprocess_data (data : Dataset , #data: pd.DataFrame
8
9
user_col : str ,
9
10
item_col : str ,
10
11
rating_col : str = None ,
@@ -13,7 +14,7 @@ def preprocess_data(data: pd.DataFrame,
13
14
drop_na_values : bool = True ,
14
15
drop_duplicates : bool = True ,
15
16
min_interactions_per_user : int = None ,
16
- max_interactions_per_user : int = None ) -> pd . DataFrame :
17
+ max_interactions_per_user : int = None ) -> Dataset :
17
18
"""Preprocess data for LensKit
18
19
This method can perform the following steps based on the user input:
19
20
1. rename columns to "user", "item", "rating", "timestamp"
@@ -54,6 +55,10 @@ def preprocess_data(data: pd.DataFrame,
54
55
logger = logging .getLogger ('lenskit-auto' )
55
56
logger .info ('--Start Preprocessing--' )
56
57
58
+ data = data .interaction_table (format = 'pandas' )
59
+ original_cols = data .columns .tolist ()
60
+ # print(original_cols)
61
+
57
62
# rename columns
58
63
if include_timestamp :
59
64
if rating_col is None :
@@ -92,4 +97,4 @@ def preprocess_data(data: pd.DataFrame,
92
97
93
98
logger .info ('--End Preprocessing--' )
94
99
95
- return data
100
+ return from_interactions_df ( data )
0 commit comments