Skip to content

Commit 266f3d9

Browse files
authored
Upgrade some union, list, dict annotations to 3.10 (#239)
* Upgrade some union, list, dict annotations to 3.10 * specify 3.10 as minimum python version * fixup setup.py * fixup predictions.py * remove more list, union * remove int
1 parent 1496b1f commit 266f3d9

26 files changed

+306
-316
lines changed

cli/friendly.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import argparse
22
import shutil
33
import sys
4-
from typing import List
54

65

76
class ArgumentParser(argparse.ArgumentParser):
@@ -39,14 +38,14 @@ def __init__(self, *args, **kwargs):
3938
class _HelpFormatter(argparse.HelpFormatter):
4039
"""Modifications on help text formatting for easier readability."""
4140

42-
def _split_lines(self, text: str, width: int) -> List[str]:
41+
def _split_lines(self, text: str, width: int) -> list[str]:
4342
"""Modified to preserve newlines and long words."""
4443
# First split into paragraphs, then wrap each separately:
4544
# https://docs.python.org/3/library/textwrap.html#textwrap.TextWrapper.replace_whitespace
4645
paragraphs = text.splitlines()
4746
import textwrap
4847

49-
lines: List[str] = []
48+
lines: list[str] = []
5049
for p in paragraphs:
5150
p_lines = textwrap.wrap(
5251
p, width, break_long_words=False, break_on_hyphens=False

lightning_pose/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
__version__ = "1.6.1"
1+
__version__ = "1.7.0"
22

33
from pathlib import Path
4+
45
from omegaconf import OmegaConf
56

67
LP_ROOT_PATH = (Path(__file__).parent.parent).absolute()

lightning_pose/apps/utils.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from collections import defaultdict
55
from pathlib import Path
6-
from typing import Dict, List, Tuple
6+
from typing import Tuple
77

88
import pandas as pd
99
import streamlit as st
@@ -16,7 +16,7 @@
1616

1717

1818
@st.cache_resource
19-
def update_labeled_file_list(model_preds_folders: List[str], use_ood: bool = False) -> List[list]:
19+
def update_labeled_file_list(model_preds_folders: list[str], use_ood: bool = False) -> list[list]:
2020
per_model_preds = []
2121
for model_pred_folder in model_preds_folders:
2222
# pull labeled results from each model folder
@@ -42,9 +42,9 @@ def update_labeled_file_list(model_preds_folders: List[str], use_ood: bool = Fal
4242
@st.cache_resource
4343
def update_vid_metric_files_list(
4444
video: str,
45-
model_preds_folders: List[str],
45+
model_preds_folders: list[str],
4646
video_subdir: str = "video_preds",
47-
) -> List[list]:
47+
) -> list[list]:
4848
per_vid_preds = []
4949
for model_preds_folder in model_preds_folders:
5050
# pull each prediction file associated with a particular video
@@ -65,7 +65,7 @@ def update_vid_metric_files_list(
6565

6666

6767
@st.cache_resource
68-
def get_all_videos(model_preds_folders: List[str], video_subdir: str = "video_preds") -> list:
68+
def get_all_videos(model_preds_folders: list[str], video_subdir: str = "video_preds") -> list:
6969
# find each video that is predicted on by the models
7070
# wrap in Path so that it looks like an UploadedFile object
7171
# returned by streamlit's file_uploader
@@ -89,7 +89,7 @@ def get_all_videos(model_preds_folders: List[str], video_subdir: str = "video_pr
8989

9090

9191
@st.cache_data
92-
def concat_dfs(dframes: Dict[str, pd.DataFrame]) -> Tuple[pd.DataFrame, List[str]]:
92+
def concat_dfs(dframes: dict[str, pd.DataFrame]) -> Tuple[pd.DataFrame, list[str]]:
9393
counter = 0
9494
for model_name, dframe in dframes.items():
9595
mask = dframe.columns.get_level_values("coords").isin(["x", "y", "likelihood"])
@@ -143,7 +143,7 @@ def get_df_scatter(
143143
return pd.concat(df_scatters)
144144

145145

146-
def get_col_names(keypoint: str, coordinate: str, models: List[str]) -> List[str]:
146+
def get_col_names(keypoint: str, coordinate: str, models: list[str]) -> list[str]:
147147
return [get_full_name(keypoint, coordinate, model) for model in models]
148148

149149

@@ -162,7 +162,7 @@ def get_full_name(keypoint: str, coordinate: str, model: str) -> str:
162162
# ----------------------------------------------
163163
@st.cache_data
164164
def build_precomputed_metrics_df(
165-
dframes: Dict[str, pd.DataFrame], keypoint_names: List[str], **kwargs,
165+
dframes: dict[str, pd.DataFrame], keypoint_names: list[str], **kwargs,
166166
) -> dict:
167167
concat_dfs = defaultdict(list)
168168
for model_name, df_dict in dframes.items():
@@ -194,7 +194,7 @@ def build_precomputed_metrics_df(
194194

195195
@st.cache_data
196196
def get_precomputed_error(
197-
df: pd.DataFrame, keypoint_names: List[str], model_name: str,
197+
df: pd.DataFrame, keypoint_names: list[str], model_name: str,
198198
) -> pd.DataFrame:
199199
# collect results
200200
df_ = df
@@ -206,7 +206,7 @@ def get_precomputed_error(
206206

207207
@st.cache_data
208208
def compute_confidence(
209-
df: pd.DataFrame, keypoint_names: List[str], model_name: str, **kwargs,
209+
df: pd.DataFrame, keypoint_names: list[str], model_name: str, **kwargs,
210210
) -> pd.DataFrame:
211211

212212
if df.shape[1] % 3 == 1:
@@ -236,7 +236,7 @@ def get_model_folders(
236236
model_dir: str,
237237
require_predictions: bool = True,
238238
require_tb_logs: bool = False,
239-
) -> List[str]:
239+
) -> list[str]:
240240
"""Find all model folders in a higher-level directory, conditional on directory contents."""
241241
# strip trailing slash if present
242242
if model_dir[-1] == os.sep:
@@ -260,7 +260,7 @@ def get_model_folders(
260260

261261

262262
# just to get the last two levels of the path
263-
def get_model_folders_vis(model_folders: List[str]) -> List[str]:
263+
def get_model_folders_vis(model_folders: list[str]) -> list[str]:
264264
fs = []
265265
for f in model_folders:
266266
fs.append(f.split("/")[-2:])

lightning_pose/data/dali.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Data pipelines based on efficient video reading by nvidia dali package."""
22

33
import os
4-
from typing import Dict, List, Literal, Optional, Union
4+
from typing import Literal
55

66
import numpy as np
77
import nvidia.dali.fn as fn
@@ -25,14 +25,14 @@
2525
# cannot typecheck due to way pipeline_def decorator consumes additional args
2626
@pipeline_def
2727
def video_pipe(
28-
filenames: Union[List[str], str],
29-
resize_dims: Optional[List[int]] = None,
28+
filenames: list[str] | str,
29+
resize_dims: list[int] | None = None,
3030
random_shuffle: bool = False,
3131
sequence_length: int = 16,
3232
pad_sequences: bool = True,
3333
initial_fill: int = 16,
34-
normalization_mean: List[float] = _IMAGENET_MEAN,
35-
normalization_std: List[float] = _IMAGENET_STD,
34+
normalization_mean: list[float] = _IMAGENET_MEAN,
35+
normalization_std: list[float] = _IMAGENET_STD,
3636
name: str = "reader",
3737
step: int = 1,
3838
pad_last_batch: bool = False,
@@ -174,7 +174,7 @@ def __len__(self) -> int:
174174
@staticmethod
175175
def _dali_output_to_tensors(
176176
batch: list
177-
) -> Union[UnlabeledBatchDict, MultiviewUnlabeledBatchDict]:
177+
) -> UnlabeledBatchDict | MultiviewUnlabeledBatchDict:
178178

179179
# always batch_size=1
180180

@@ -229,7 +229,7 @@ def _dali_output_to_tensors(
229229
frames=frames, transforms=transforms, bbox=bbox, is_multiview=True,
230230
)
231231

232-
def __next__(self) -> Union[UnlabeledBatchDict, MultiviewUnlabeledBatchDict]:
232+
def __next__(self) -> UnlabeledBatchDict | MultiviewUnlabeledBatchDict:
233233
batch = super().__next__()
234234
return self._dali_output_to_tensors(batch=batch)
235235

@@ -245,10 +245,10 @@ def __init__(
245245
self,
246246
train_stage: Literal["predict", "train"],
247247
model_type: Literal["base", "context"],
248-
filenames: Union[List[str], List[List[str]]],
249-
resize_dims: List[int],
250-
dali_config: Union[dict, DictConfig] = None,
251-
imgaug: Optional[str] = "default",
248+
filenames: list[str] | list[list[str]],
249+
resize_dims: list[int],
250+
dali_config: dict | DictConfig = None,
251+
imgaug: str | None = "default",
252252
num_threads: int = 1,
253253
) -> None:
254254

@@ -315,9 +315,9 @@ def num_iters(self) -> int:
315315

316316
def _setup_pipe_dict(
317317
self,
318-
filenames: Union[List[str], List[List[str]]],
318+
filenames: list[str] | list[list[str]],
319319
imgaug: str,
320-
) -> Dict[str, dict]:
320+
) -> dict[str, dict]:
321321
"""All of the pipeline args in one place."""
322322
# When running with multiple GPUs, the LOCAL_RANK variable correctly
323323
# contains the DDP Local Rank, which is also the cuda device index.

lightning_pose/data/datamodules.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Data modules split a dataset into train, val, and test modules."""
22

33
import copy
4-
from typing import List, Literal, Optional, Union
4+
from typing import Literal
55

66
import imgaug.augmenters as iaa
77
import lightning.pytorch as pl
@@ -36,9 +36,9 @@ def __init__(
3636
test_batch_size: int = 1,
3737
num_workers: int = 8,
3838
train_probability: float = 0.8,
39-
val_probability: Optional[float] = None,
40-
test_probability: Optional[float] = None,
41-
train_frames: Optional[Union[float, int]] = None,
39+
val_probability: float | None = None,
40+
test_probability: float | None = None,
41+
train_frames: float | int | None = None,
4242
torch_seed: int = 42,
4343
) -> None:
4444
"""Data module splits a dataset into train, val, and test data loaders.
@@ -75,7 +75,7 @@ def __init__(
7575
self.test_dataset = None # populated by self.setup()
7676
self.torch_seed = torch_seed
7777

78-
def setup(self, stage: Optional[str] = None) -> None: # stage arg needed for ptl
78+
def setup(self, stage: str | None = None) -> None: # stage arg needed for ptl
7979

8080
datalen = self.dataset.__len__()
8181
print(f"Number of labeled images in the full dataset (train+val+test): {datalen}")
@@ -170,17 +170,17 @@ class UnlabeledDataModule(BaseDataModule):
170170
def __init__(
171171
self,
172172
dataset: torch.utils.data.Dataset,
173-
video_paths_list: Union[List[str], str],
174-
dali_config: Union[dict, DictConfig],
175-
view_names: Optional[List[str]] = None,
173+
video_paths_list: list[str] | str,
174+
dali_config: dict | DictConfig,
175+
view_names: list[str] | None = None,
176176
train_batch_size: int = 16,
177177
val_batch_size: int = 16,
178178
test_batch_size: int = 1,
179179
num_workers: int = 8,
180180
train_probability: float = 0.8,
181-
val_probability: Optional[float] = None,
182-
test_probability: Optional[float] = None,
183-
train_frames: Optional[float] = None,
181+
val_probability: float | None = None,
182+
test_probability: float | None = None,
183+
train_frames: float | None = None,
184184
torch_seed: int = 42,
185185
imgaug: Literal["default", "dlc", "dlc-top-down"] = "default",
186186
) -> None:

lightning_pose/data/datasets.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import os
55
from pathlib import Path
6-
from typing import Callable, List, Literal, Optional, Tuple, Union
6+
from typing import Callable, List, Literal, Tuple
77

88
import numpy as np
99
import pandas as pd
@@ -36,8 +36,8 @@ def __init__(
3636
self,
3737
root_directory: str,
3838
csv_path: str,
39-
header_rows: Optional[List[int]] = [0, 1, 2],
40-
imgaug_transform: Optional[Callable] = None,
39+
header_rows: list[int] | None = [0, 1, 2],
40+
imgaug_transform: Callable | None = None,
4141
do_context: bool = False,
4242
) -> None:
4343
"""Initialize a dataset for regression (rather than heatmap) models.
@@ -193,8 +193,8 @@ def __init__(
193193
self,
194194
root_directory: str,
195195
csv_path: str,
196-
header_rows: Optional[List[int]] = [0, 1, 2],
197-
imgaug_transform: Optional[Callable] = None,
196+
header_rows: list[int] | None = [0, 1, 2],
197+
imgaug_transform: Callable | None = None,
198198
downsample_factor: Literal[1, 2, 3] = 2,
199199
do_context: bool = False,
200200
uniform_heatmaps: bool = False,
@@ -308,13 +308,13 @@ class MultiviewHeatmapDataset(torch.utils.data.Dataset):
308308
def __init__(
309309
self,
310310
root_directory: str,
311-
csv_paths: List[str],
312-
view_names: List[str],
313-
header_rows: Optional[List[int]] = [0, 1, 2],
311+
csv_paths: list[str],
312+
view_names: list[str],
313+
header_rows: list[int] | None = [0, 1, 2],
314314
downsample_factor: Literal[1, 2, 3] = 2,
315315
uniform_heatmaps: bool = False,
316316
do_context: bool = False,
317-
imgaug_transform: Optional[Callable] = None,
317+
imgaug_transform: Callable | None = None,
318318
) -> None:
319319
"""Initialize the MultiViewHeatmap Dataset.
320320
@@ -424,10 +424,12 @@ def output_shape(self) -> tuple:
424424
def num_views(self) -> int:
425425
return len(self.view_names)
426426

427-
def fusion(self, datadict: dict) -> Tuple[
428-
Union[
429-
TensorType["num_views", "RGB":3, "image_height", "image_width", float],
430-
TensorType["num_views", "frames", "RGB":3, "image_height", "image_width", float]
427+
def fusion(
428+
self, datadict: dict
429+
) -> Tuple[
430+
TensorType["num_views", "RGB":3, "image_height", "image_width", float]
431+
| TensorType[
432+
"num_views", "frames", "RGB":3, "image_height", "image_width", float
431433
],
432434
TensorType["keypoints"],
433435
TensorType["num_views", "heatmap_height", "heatmap_width", float],
@@ -489,6 +491,6 @@ def __getitem__(self, idx: int) -> MultiviewHeatmapLabeledExampleDict:
489491
bbox=bboxes,
490492
idxs=idx,
491493
num_views=self.num_views, # int
492-
concat_order=concat_order, # List[str]
493-
view_names=self.view_names, # List[str]
494+
concat_order=concat_order, # list[str]
495+
view_names=self.view_names, # list[str]
494496
)

0 commit comments

Comments
 (0)