16
16
17
17
18
18
@st .cache_resource
19
- def update_labeled_file_list (model_preds_folders : list , use_ood : bool = False ):
19
+ def update_labeled_file_list (model_preds_folders : List [ str ] , use_ood : bool = False ) -> List [ list ] :
20
20
per_model_preds = []
21
21
for model_pred_folder in model_preds_folders :
22
22
# pull labeled results from each model folder
@@ -40,15 +40,20 @@ def update_labeled_file_list(model_preds_folders: list, use_ood: bool = False):
40
40
41
41
42
42
@st .cache_resource
43
- def update_vid_metric_files_list (video : str , model_preds_folders : list ):
43
+ def update_vid_metric_files_list (
44
+ video : str ,
45
+ model_preds_folders : List [str ],
46
+ video_subdir : str = "video_preds" ,
47
+ ) -> List [list ]:
44
48
per_vid_preds = []
45
49
for model_preds_folder in model_preds_folders :
46
50
# pull each prediction file associated with a particular video
47
51
# wrap in Path so that it looks like an UploadedFile object
52
+ video_dir = os .path .join (model_preds_folder , video_subdir )
53
+ if not os .path .isdir (video_dir ):
54
+ continue
48
55
model_preds = [
49
- f
50
- for f in os .listdir (os .path .join (model_preds_folder , "video_preds" ))
51
- if os .path .isfile (os .path .join (model_preds_folder , "video_preds" , f ))
56
+ f for f in os .listdir (video_dir ) if os .path .isfile (os .path .join (video_dir , f ))
52
57
]
53
58
ret_files = []
54
59
for file in model_preds :
@@ -59,16 +64,17 @@ def update_vid_metric_files_list(video: str, model_preds_folders: list):
59
64
60
65
61
66
@st .cache_resource
62
- def get_all_videos (model_preds_folders : list ) :
67
+ def get_all_videos (model_preds_folders : List [ str ], video_subdir : str = "video_preds" ) -> list :
63
68
# find each video that is predicted on by the models
64
69
# wrap in Path so that it looks like an UploadedFile object
65
70
# returned by streamlit's file_uploader
66
71
ret_videos = set ()
67
72
for model_preds_folder in model_preds_folders :
73
+ video_dir = os .path .join (model_preds_folder , video_subdir )
74
+ if not os .path .isdir (video_dir ):
75
+ continue
68
76
model_preds = [
69
- f
70
- for f in os .listdir (os .path .join (model_preds_folder , "video_preds" ))
71
- if os .path .isfile (os .path .join (model_preds_folder , "video_preds" , f ))
77
+ f for f in os .listdir (video_dir ) if os .path .isfile (os .path .join (video_dir , f ))
72
78
]
73
79
for file in model_preds :
74
80
if "temporal" in file :
@@ -97,7 +103,7 @@ def concat_dfs(dframes: Dict[str, pd.DataFrame]) -> Tuple[pd.DataFrame, List[str
97
103
98
104
99
105
@st .cache_data
100
- def get_df_box (df_orig , keypoint_names , model_names ) :
106
+ def get_df_box (df_orig : pd . DataFrame , keypoint_names : list , model_names : list ) -> pd . DataFrame :
101
107
df_boxes = []
102
108
for keypoint in keypoint_names :
103
109
for model_curr in model_names :
@@ -112,7 +118,13 @@ def get_df_box(df_orig, keypoint_names, model_names):
112
118
113
119
114
120
@st .cache_data
115
- def get_df_scatter (df_0 , df_1 , data_type , model_names , keypoint_names ):
121
+ def get_df_scatter (
122
+ df_0 : pd .DataFrame ,
123
+ df_1 : pd .DataFrame ,
124
+ data_type : str ,
125
+ model_names : list ,
126
+ keypoint_names : list
127
+ ) -> pd .DataFrame :
116
128
df_scatters = []
117
129
for keypoint in keypoint_names :
118
130
df_scatters .append (
@@ -147,7 +159,7 @@ def get_full_name(keypoint: str, coordinate: str, model: str) -> str:
147
159
# ----------------------------------------------
148
160
@st .cache_data
149
161
def build_precomputed_metrics_df (
150
- dframes : Dict [str , pd .DataFrame ], keypoint_names : List [str ], ** kwargs
162
+ dframes : Dict [str , pd .DataFrame ], keypoint_names : List [str ], ** kwargs ,
151
163
) -> dict :
152
164
concat_dfs = defaultdict (list )
153
165
for model_name , df_dict in dframes .items ():
@@ -179,7 +191,7 @@ def build_precomputed_metrics_df(
179
191
180
192
@st .cache_data
181
193
def get_precomputed_error (
182
- df : pd .DataFrame , keypoint_names : List [str ], model_name : str
194
+ df : pd .DataFrame , keypoint_names : List [str ], model_name : str ,
183
195
) -> pd .DataFrame :
184
196
# collect results
185
197
df_ = df
@@ -192,17 +204,17 @@ def get_precomputed_error(
192
204
193
205
@st .cache_data
194
206
def compute_confidence (
195
- df : pd .DataFrame , keypoint_names : List [str ], model_name : str , ** kwargs
207
+ df : pd .DataFrame , keypoint_names : List [str ], model_name : str , ** kwargs ,
196
208
) -> pd .DataFrame :
209
+
197
210
if df .shape [1 ] % 3 == 1 :
198
- # get rid of "set" column if present
199
- tmp = df .iloc [:, :- 1 ].to_numpy ().reshape (df .shape [0 ], - 1 , 3 )
211
+ # collect "set" column if present
200
212
set = df .iloc [:, - 1 ].to_numpy ()
201
213
else :
202
- tmp = df .to_numpy ().reshape (df .shape [0 ], - 1 , 3 )
203
214
set = None
204
215
205
- results = tmp [:, :, 2 ]
216
+ mask = df .columns .get_level_values ("coords" ).isin (["likelihood" ])
217
+ results = df .loc [:, mask ].to_numpy ()
206
218
207
219
# collect results
208
220
df_ = pd .DataFrame (columns = keypoint_names )
@@ -219,7 +231,7 @@ def compute_confidence(
219
231
220
232
# ------------ utils related to model finding in dir ---------
221
233
# write a function that finds all model folders in the model_dir
222
- def get_model_folders (model_dir ) :
234
+ def get_model_folders (model_dir : str ) -> List [ str ] :
223
235
# strip trailing slash if present
224
236
if model_dir [- 1 ] == os .sep :
225
237
model_dir = model_dir [:- 1 ]
@@ -232,7 +244,7 @@ def get_model_folders(model_dir):
232
244
233
245
234
246
# just to get the last two levels of the path
235
- def get_model_folders_vis (model_folders ) :
247
+ def get_model_folders_vis (model_folders : List [ str ]) -> List [ str ] :
236
248
fs = []
237
249
for f in model_folders :
238
250
fs .append (f .split ("/" )[- 2 :])
0 commit comments