Skip to content

Commit 4b738a1

Browse files
committed
updated pruning.py
1 parent 7d62904 commit 4b738a1

File tree

1 file changed

+35
-31
lines changed

1 file changed

+35
-31
lines changed

lkauto/preprocessing/pruning.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,63 @@
1-
import pandas as pd
1+
# import pandas as pd
2+
from lenskit.data import Dataset,from_interactions_df
23

3-
4-
def min_ratings_per_user(df: pd.DataFrame, num_ratings: int, count_duplicates: bool = False):
4+
def min_ratings_per_user(dataset: Dataset, num_ratings: int, count_duplicates: bool = False):
55
"""Prune users with less than num_ratings ratings
66
77
Parameters
88
----------
9-
df: pd.DataFrame
10-
Dataframe with columns "user", "item", "rating"
9+
dataset: Dataset
10+
LensKit Dataset object containing user-item interactions with ratings
1111
num_ratings: int
1212
Minimum number of ratings per user
1313
count_duplicates: bool = False
1414
If True, all ratings are counted, otherwise only unique ratings are counted
1515
1616
Returns
1717
-------
18-
pd.DataFrame
19-
Dataframe with columns "user", "item", "rating"
18+
Dataset
19+
Filtered Dataset with only users meeting the minimum rating threshold
20+
the Dataset will contain the columns "user_id", "item_id", "rating"
2021
"""
21-
# get all relevant user_ids
22-
uids = (
23-
df['user']
24-
if count_duplicates
25-
else df.drop_duplicates(['user', 'item'])['user']
26-
)
27-
cnt_items_per_user = uids.value_counts()
28-
users_of_interest = list(cnt_items_per_user[cnt_items_per_user >= num_ratings].index)
22+
# get the user statistics from the dataset
23+
user_stats = dataset.user_stats()
24+
if count_duplicates:
25+
valid_users = user_stats[user_stats['count'] >= num_ratings].index # count: total number of ratings (including duplicates)
26+
else:
27+
valid_users = user_stats[user_stats['item_count'] >= num_ratings].index # item_count: number of unique items rated
28+
# convert the interaction table to a pandas DataFrame and filter by valid users
29+
users_of_interest = dataset.iteraction_table(format='pandas', original_ids=True)
30+
users_of_interest = users_of_interest[users_of_interest['user_id'].isin(valid_users)]
31+
return from_interactions_df(users_of_interest)
32+
2933

30-
return df[df['user'].isin(users_of_interest)]
3134

3235

33-
def max_ratings_per_user(df: pd.DataFrame, num_ratings: int, count_duplicates: bool = False):
36+
def max_ratings_per_user(dataset: Dataset, num_ratings: int, count_duplicates: bool = False):
3437
"""Prune users with more than num_ratings ratings
3538
3639
Parameters
3740
----------
38-
df: pd.DataFrame
39-
Dataframe with columns "user", "item", "rating"
41+
dataset: Dataset
42+
LensKit Dataset object containing user-item interactions with ratings
4043
num_ratings: int
4144
Minimum number of ratings per user
4245
count_duplicates: bool = False
4346
If True, all ratings are counted, otherwise only unique ratings are counted
4447
4548
Returns
4649
-------
47-
pd.DataFrame
48-
Dataframe with columns "user", "item", "rating"
50+
Dataset
51+
Filtered Dataset with only users meeting the minimum rating threshold
52+
the Dataset will contain the columns "user_id", "item_id", "rating"
4953
"""
50-
# get all relevant user_ids
51-
uids = (
52-
df['user']
53-
if count_duplicates
54-
else df.drop_duplicates(['user', 'item'])['user']
55-
)
56-
cnt_items_per_user = uids.value_counts()
57-
users_of_interest = list(cnt_items_per_user[cnt_items_per_user <= num_ratings].index)
58-
59-
return df[df['user'].isin(users_of_interest)]
54+
55+
user_stats = dataset.user_stats()
56+
if count_duplicates:
57+
valid_users = user_stats[user_stats['count'] <= num_ratings].index # count: total number of ratings (including duplicates)
58+
else:
59+
valid_users = user_stats[user_stats['item_count'] <= num_ratings].index # item_count: number of unique items rated
60+
# convert the interaction table to a pandas DataFrame and filter by valid users
61+
users_of_interest = dataset.iteraction_table(format='pandas', original_ids=True)
62+
users_of_interest = users_of_interest[users_of_interest['user_id'].isin(valid_users)]
63+
return from_interactions_df(users_of_interest)

0 commit comments

Comments
 (0)