Skip to content

Commit b8cb8f4

Browse files
v1.3.0 updates (#149)
* generalize method of accessing likelihood in apps * generalize subdirectory names used by video app * move version def from setup.py to __init__ * refactor train_hydra script to import a train function for greater flexibility
1 parent cd10c87 commit b8cb8f4

File tree

8 files changed

+377
-295
lines changed

8 files changed

+377
-295
lines changed

lightning_pose/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "1.3.0"

lightning_pose/apps/utils.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
@st.cache_resource
19-
def update_labeled_file_list(model_preds_folders: list, use_ood: bool = False):
19+
def update_labeled_file_list(model_preds_folders: List[str], use_ood: bool = False) -> List[list]:
2020
per_model_preds = []
2121
for model_pred_folder in model_preds_folders:
2222
# pull labeled results from each model folder
@@ -40,15 +40,20 @@ def update_labeled_file_list(model_preds_folders: list, use_ood: bool = False):
4040

4141

4242
@st.cache_resource
43-
def update_vid_metric_files_list(video: str, model_preds_folders: list):
43+
def update_vid_metric_files_list(
44+
video: str,
45+
model_preds_folders: List[str],
46+
video_subdir: str = "video_preds",
47+
) -> List[list]:
4448
per_vid_preds = []
4549
for model_preds_folder in model_preds_folders:
4650
# pull each prediction file associated with a particular video
4751
# wrap in Path so that it looks like an UploadedFile object
52+
video_dir = os.path.join(model_preds_folder, video_subdir)
53+
if not os.path.isdir(video_dir):
54+
continue
4855
model_preds = [
49-
f
50-
for f in os.listdir(os.path.join(model_preds_folder, "video_preds"))
51-
if os.path.isfile(os.path.join(model_preds_folder, "video_preds", f))
56+
f for f in os.listdir(video_dir) if os.path.isfile(os.path.join(video_dir, f))
5257
]
5358
ret_files = []
5459
for file in model_preds:
@@ -59,16 +64,17 @@ def update_vid_metric_files_list(video: str, model_preds_folders: list):
5964

6065

6166
@st.cache_resource
62-
def get_all_videos(model_preds_folders: list):
67+
def get_all_videos(model_preds_folders: List[str], video_subdir: str = "video_preds") -> list:
6368
# find each video that is predicted on by the models
6469
# wrap in Path so that it looks like an UploadedFile object
6570
# returned by streamlit's file_uploader
6671
ret_videos = set()
6772
for model_preds_folder in model_preds_folders:
73+
video_dir = os.path.join(model_preds_folder, video_subdir)
74+
if not os.path.isdir(video_dir):
75+
continue
6876
model_preds = [
69-
f
70-
for f in os.listdir(os.path.join(model_preds_folder, "video_preds"))
71-
if os.path.isfile(os.path.join(model_preds_folder, "video_preds", f))
77+
f for f in os.listdir(video_dir) if os.path.isfile(os.path.join(video_dir, f))
7278
]
7379
for file in model_preds:
7480
if "temporal" in file:
@@ -97,7 +103,7 @@ def concat_dfs(dframes: Dict[str, pd.DataFrame]) -> Tuple[pd.DataFrame, List[str
97103

98104

99105
@st.cache_data
100-
def get_df_box(df_orig, keypoint_names, model_names):
106+
def get_df_box(df_orig: pd.DataFrame, keypoint_names: list, model_names: list) -> pd.DataFrame:
101107
df_boxes = []
102108
for keypoint in keypoint_names:
103109
for model_curr in model_names:
@@ -112,7 +118,13 @@ def get_df_box(df_orig, keypoint_names, model_names):
112118

113119

114120
@st.cache_data
115-
def get_df_scatter(df_0, df_1, data_type, model_names, keypoint_names):
121+
def get_df_scatter(
122+
df_0: pd.DataFrame,
123+
df_1: pd.DataFrame,
124+
data_type: str,
125+
model_names: list,
126+
keypoint_names: list
127+
) -> pd.DataFrame:
116128
df_scatters = []
117129
for keypoint in keypoint_names:
118130
df_scatters.append(
@@ -147,7 +159,7 @@ def get_full_name(keypoint: str, coordinate: str, model: str) -> str:
147159
# ----------------------------------------------
148160
@st.cache_data
149161
def build_precomputed_metrics_df(
150-
dframes: Dict[str, pd.DataFrame], keypoint_names: List[str], **kwargs
162+
dframes: Dict[str, pd.DataFrame], keypoint_names: List[str], **kwargs,
151163
) -> dict:
152164
concat_dfs = defaultdict(list)
153165
for model_name, df_dict in dframes.items():
@@ -179,7 +191,7 @@ def build_precomputed_metrics_df(
179191

180192
@st.cache_data
181193
def get_precomputed_error(
182-
df: pd.DataFrame, keypoint_names: List[str], model_name: str
194+
df: pd.DataFrame, keypoint_names: List[str], model_name: str,
183195
) -> pd.DataFrame:
184196
# collect results
185197
df_ = df
@@ -192,17 +204,17 @@ def get_precomputed_error(
192204

193205
@st.cache_data
194206
def compute_confidence(
195-
df: pd.DataFrame, keypoint_names: List[str], model_name: str, **kwargs
207+
df: pd.DataFrame, keypoint_names: List[str], model_name: str, **kwargs,
196208
) -> pd.DataFrame:
209+
197210
if df.shape[1] % 3 == 1:
198-
# get rid of "set" column if present
199-
tmp = df.iloc[:, :-1].to_numpy().reshape(df.shape[0], -1, 3)
211+
# collect "set" column if present
200212
set = df.iloc[:, -1].to_numpy()
201213
else:
202-
tmp = df.to_numpy().reshape(df.shape[0], -1, 3)
203214
set = None
204215

205-
results = tmp[:, :, 2]
216+
mask = df.columns.get_level_values("coords").isin(["likelihood"])
217+
results = df.loc[:, mask].to_numpy()
206218

207219
# collect results
208220
df_ = pd.DataFrame(columns=keypoint_names)
@@ -219,7 +231,7 @@ def compute_confidence(
219231

220232
# ------------ utils related to model finding in dir ---------
221233
# write a function that finds all model folders in the model_dir
222-
def get_model_folders(model_dir):
234+
def get_model_folders(model_dir: str) -> List[str]:
223235
# strip trailing slash if present
224236
if model_dir[-1] == os.sep:
225237
model_dir = model_dir[:-1]
@@ -232,7 +244,7 @@ def get_model_folders(model_dir):
232244

233245

234246
# just to get the last two levels of the path
235-
def get_model_folders_vis(model_folders):
247+
def get_model_folders_vis(model_folders: List[str]) -> List[str]:
236248
fs = []
237249
for f in model_folders:
238250
fs.append(f.split("/")[-2:])

lightning_pose/apps/video_diagnostics.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333

3434
def run():
35+
3536
args = parser.parse_args()
3637

3738
st.title("Video Diagnostics")
@@ -53,23 +54,19 @@ def run():
5354
# get the last two levels of each path to be presented to user
5455
model_folders_vis = get_model_folders_vis(model_folders)
5556

56-
selected_models_vis = st.sidebar.multiselect(
57-
"Select models", model_folders_vis, default=None
58-
)
57+
selected_models_vis = st.sidebar.multiselect("Select models", model_folders_vis, default=None)
5958

6059
# append this to full path
61-
selected_models = [
62-
"/" + os.path.join(args.model_dir, f) for f in selected_models_vis
63-
]
60+
selected_models = ["/" + os.path.join(args.model_dir, f) for f in selected_models_vis]
6461

6562
# ----- selecting videos to analyze -----
66-
all_videos_: list = get_all_videos(selected_models)
63+
all_videos_: list = get_all_videos(selected_models, video_subdir=args.video_subdir)
6764

6865
# choose from the different videos that were predicted
6966
video_to_plot = st.sidebar.selectbox("Select a video:", [*all_videos_], key="video")
7067

7168
prediction_files = update_vid_metric_files_list(
72-
video=video_to_plot, model_preds_folders=selected_models
69+
video=video_to_plot, model_preds_folders=selected_models, video_subdir=args.video_subdir,
7370
)
7471

7572
model_names = copy.copy(selected_models_vis)
@@ -100,9 +97,7 @@ def run():
10097
dframe = pd.read_csv(model_pred_file_path, index_col=None)
10198
dframes_metrics[model_name][str(model_pred_file)] = dframe
10299
else:
103-
dframe = pd.read_csv(
104-
model_pred_file_path, header=[1, 2], index_col=0
105-
)
100+
dframe = pd.read_csv(model_pred_file_path, header=[1, 2], index_col=0)
106101
dframes_traces[model_name] = dframe
107102
dframes_metrics[model_name]["confidence"] = dframe
108103
# data_types = dframe.iloc[:, -1].unique()
@@ -221,6 +216,7 @@ def run():
221216
parser = argparse.ArgumentParser()
222217

223218
parser.add_argument("--model_dir", type=str, default=[])
219+
parser.add_argument("--video_subdir", type=str, default="video_preds")
224220
parser.add_argument("--make_dir", action="store_true", default=False)
225221

226222
run()

0 commit comments

Comments
 (0)