Skip to content

Commit a6ed589

Browse files
remove frozen requirements (#105)
* clean up requirements * fix typechecking errors introduced by unfreezing requirements * replace skimage op with cv2
1 parent 6135e5a commit a6ed589

File tree

11 files changed

+89
-56
lines changed

11 files changed

+89
-56
lines changed

README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,15 @@ cd <SOME_FOLDER>
5757
git clone https://github.com/danbider/lightning-pose.git
5858
```
5959

60-
Then move into the newly-created repository folder, and install dependencies:
60+
Then move into the newly-created repository folder:
6161
```console
6262
cd lightning-pose
63-
pip install -e .
6463
```
64+
and install dependencies using one of the lines below that suits your needs best:
65+
* `pip install -e . `: basic installation, covers most use-cases (note the period!)
66+
* `pip install -e .[dev] `: basic install + dev tools
67+
* `pip install -e .[extra_models] `: basic install + tools for loading resnet-50 simclr weights
68+
* `pip install -e .[dev,extra_models] `: install all available requirements
6569

6670
This installation might take between 3-10 minutes, depending on your machine and internet connection.
6771

docs/contributing.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,18 @@ If you have found a bug or would like to request a minor change, please
66

77
In order to contribute code to the repo, please follow the steps below.
88

9-
You will also need to install the following dev packages:
9+
Whenever you initially install the lightning pose repo, instead of
1010
```bash
11-
pip install flake8 isort
11+
pip install -e .
12+
```
13+
run
14+
```bash
15+
pip install -e .[dev]
16+
```
17+
18+
Alternatively, if you have already installed the repo, install the following dev packages:
19+
```bash
20+
pip install black flake8 isort
1221
```
1322

1423
### Create a pull request

lightning_pose/models/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"resnet50",
77
"resnet101",
88
"resnet152",
9-
"resnet50_contrastive",
9+
"resnet50_contrastive", # needs extra install: pip install -e .[extra_models]
1010
"resnet50_animal_apose",
1111
"resnet50_animal_ap10k",
1212
"resnet50_human_jhmdb",
@@ -15,6 +15,6 @@
1515
"efficientnet_b0",
1616
"efficientnet_b1",
1717
"efficientnet_b2",
18-
"vit_h_sam",
19-
"vit_b_sam",
18+
# "vit_h_sam",
19+
# "vit_b_sam",
2020
]

lightning_pose/models/backbones/torchvision.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import OrderedDict
2+
from typing import Tuple
23

34
import torch
45
import torchvision.models as tvmodels
@@ -28,8 +29,14 @@ def build_backbone(
2829

2930
if backbone_arch == "resnet50_contrastive":
3031
# load resnet50 pretrained using SimCLR on imagenet
31-
from pl_bolts.models.self_supervised import SimCLR
32-
32+
try:
33+
from pl_bolts.models.self_supervised import SimCLR
34+
except ImportError:
35+
raise Exception(
36+
"lightning-bolts package is not installed.\n"
37+
"Run `pip install lightning-bolts` "
38+
"in order to access 'resnet50_contrastive' backbone"
39+
)
3340
ckpt_url = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt" # noqa: E501
3441
simclr = SimCLR.load_from_checkpoint(ckpt_url, strict=False)
3542
base = simclr.encoder

lightning_pose/models/heatmap_tracker.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Models that produce heatmaps of keypoints from images."""
22

3-
from typing import Optional, Tuple, Union
3+
from typing import Dict, Optional, Tuple, Union
44

55
import torch
66
from kornia.filters import filter2d
@@ -9,6 +9,7 @@
99
from omegaconf import DictConfig
1010
from torch import nn
1111
from torchtyping import TensorType
12+
from typeguard import typechecked
1213
from typing_extensions import Literal
1314

1415
from lightning_pose.data.utils import (
@@ -349,6 +350,7 @@ def predict_step(
349350
return predicted_keypoints, confidence
350351

351352

353+
@typechecked
352354
class SemiSupervisedHeatmapTracker(SemiSupervisedTrackerMixin, HeatmapTracker):
353355
"""Model produces heatmaps of keypoints from labeled/unlabeled images."""
354356

@@ -411,7 +413,7 @@ def __init__(
411413
# self.register_buffer("total_unsupervised_importance", torch.tensor(1.0))
412414
self.total_unsupervised_importance = torch.tensor(1.0)
413415

414-
def get_loss_inputs_unlabeled(self, batch: UnlabeledBatchDict) -> dict:
416+
def get_loss_inputs_unlabeled(self, batch: UnlabeledBatchDict) -> Dict:
415417
"""Return predicted heatmaps and their softmaxes (estimated keypoints)."""
416418
# images -> heatmaps
417419
predicted_heatmaps = self.forward(batch["frames"])

lightning_pose/models/heatmap_tracker_mhcrnn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Models that produce heatmaps of keypoints from images."""
22

3-
from typing import Optional, Tuple, Union
3+
from typing import Dict, Optional, Tuple, Union
44

55
import torch
66
from kornia.geometry.subpix import spatial_softmax2d
@@ -259,7 +259,7 @@ def __init__(
259259
# self.register_buffer("total_unsupervised_importance", torch.tensor(1.0))
260260
self.total_unsupervised_importance = torch.tensor(1.0)
261261

262-
def get_loss_inputs_unlabeled(self, batch: UnlabeledBatchDict) -> dict:
262+
def get_loss_inputs_unlabeled(self, batch: UnlabeledBatchDict) -> Dict:
263263
"""Return predicted heatmaps and their softmaxes (estimated keypoints)."""
264264
# images -> heatmaps
265265
pred_heatmaps_crnn, pred_heatmaps_sf = self.forward(batch["frames"])

lightning_pose/models/regression_tracker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Models that produce (x, y) coordinates of keypoints from images."""
22

3-
from typing import Optional, Tuple, Union
3+
from typing import Dict, Optional, Tuple, Union
44

55
import torch
66
from omegaconf import DictConfig
@@ -204,7 +204,7 @@ def __init__(
204204
self.total_unsupervised_importance = torch.tensor(1.0)
205205
# self.register_buffer("total_unsupervised_importance", torch.tensor(1.0))
206206

207-
def get_loss_inputs_unlabeled(self, batch: UnlabeledBatchDict) -> dict:
207+
def get_loss_inputs_unlabeled(self, batch: UnlabeledBatchDict) -> Dict:
208208
"""Return predicted heatmaps and their softmaxes (estimated keypoints)."""
209209
predicted_keypoints = self.forward(batch["frames"])
210210
# undo augmentation if needed

lightning_pose/utils/io.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def get_videos_in_dir(video_dir: str, return_mp4_only: bool = True) -> List[str]
190190

191191

192192
@typechecked
193-
def check_video_paths(video_paths: Union[List[str], str]) -> list:
193+
def check_video_paths(video_paths: Union[List[str], str]) -> List[str]:
194194
# get input data
195195
if isinstance(video_paths, list):
196196
# presumably a list of files
@@ -203,8 +203,7 @@ def check_video_paths(video_paths: Union[List[str], str]) -> list:
203203
filenames = get_videos_in_dir(video_paths)
204204
else:
205205
raise ValueError(
206-
"`video_paths_list` must be a list of files, a single file, "
207-
+ "or a directory name"
206+
"`video_paths_list` must be a list of files, a single file, or a directory name"
208207
)
209208
for filename in filenames:
210209
assert filename.endswith(".mp4"), "video files must be mp4 format!"

lightning_pose/utils/predictions.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
import time
55
from typing import List, Optional, Tuple, Union
66

7+
import cv2
78
import lightning.pytorch as pl
89
import matplotlib.pyplot as plt
910
import numpy as np
1011
import pandas as pd
1112
import torch
1213
from omegaconf import DictConfig, OmegaConf
1314
from pytorch_lightning import LightningModule
14-
from skimage.draw import disk
1515
from torchtyping import TensorType
1616
from tqdm import tqdm
1717
from typeguard import typechecked
@@ -690,7 +690,7 @@ def create_labeled_video(
690690
xs_arr,
691691
ys_arr,
692692
mask_array=None,
693-
dotsize=5,
693+
dotsize=4,
694694
colormap="cool",
695695
fps=None,
696696
filename="movie.mp4",
@@ -719,14 +719,10 @@ def create_labeled_video(
719719
colors = make_cmap(n_keypoints, cmap=colormap)
720720

721721
nx, ny = clip.size
722-
duration = int(clip.duration - clip.start)
722+
dur = int(clip.duration - clip.start)
723723
fps_og = clip.fps
724724

725-
print(
726-
"Duration of video [s]: {}, recorded with {} fps!".format(
727-
np.round(duration, 2), np.round(fps_og, 2)
728-
)
729-
)
725+
print(f"Duration of video [s]: {np.round(dur, 2)}, recorded at {np.round(fps_og, 2)} fps!")
730726

731727
# add marker to each frame t, where t is in sec
732728
def add_marker(get_frame, t):
@@ -742,8 +738,13 @@ def add_marker(get_frame, t):
742738
if mask_array[index, bpindex]:
743739
xc = min(int(xs_arr[index, bpindex]), nx - 1)
744740
yc = min(int(ys_arr[index, bpindex]), ny - 1)
745-
rr, cc = disk(center=(yc, xc), radius=dotsize, shape=(ny, nx))
746-
frame[rr, cc, :] = colors[bpindex]
741+
frame = cv2.circle(
742+
frame,
743+
center=(xc, yc),
744+
radius=dotsize,
745+
color=colors[bpindex].tolist(),
746+
thickness=-1
747+
)
747748
return frame
748749

749750
clip_marked = clip.fl(add_marker)

lightning_pose/utils/scripts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Helper functions to build pipeline components from config dictionary."""
22

33
import os
4-
from typing import Dict, Optional, Union
4+
from typing import Dict, List, Optional, Union
55

66
import imgaug.augmenters as iaa
77
import lightning.pytorch as pl

0 commit comments

Comments
 (0)