Skip to content

Commit 1d5418c

Browse files
bug fix in left-right/top-down augmentation pipeline (#287)
1 parent 1fba888 commit 1d5418c

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

lightning_pose/data/augmentations.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,15 @@ def imgaug_transform(params_dict: dict | DictConfig) -> iaa.Sequential:
6363
transform_args = args.get("args", ())
6464
transform_kwargs = args.get("kwargs", {})
6565

66-
# make sure any lists are converted to tuples; DictConfig cannot load tuples from yaml
67-
# files, but no iaa args are lists
66+
# DictConfig cannot load tuples from yaml files
67+
# make sure any lists are converted to tuples
68+
# unless the list contains a single item, then pass through the item (hack for Rot90)
6869
for kw, arg in transform_kwargs.items():
6970
if isinstance(arg, list) or isinstance(arg, ListConfig):
70-
transform_kwargs[kw] = tuple(arg)
71+
if len(arg) == 1:
72+
transform_kwargs[kw] = arg[0]
73+
else:
74+
transform_kwargs[kw] = tuple(arg)
7175

7276
# add transform to pipeline
7377
if apply_prob == 0.0:
@@ -91,13 +95,13 @@ def expand_imgaug_str_to_dict(params: str) -> dict[str, Any]:
9195
pass # no augmentations
9296
elif params in ["dlc", "dlc-lr", "dlc-top-down"]:
9397

94-
# flip horizontally
95-
if params in ["dlc-lr", "dlc-top-down"]:
96-
params_dict["Fliplr"] = {"p": 1.0, "kwargs": {"p": 0.5}}
98+
# rotate 0 or 180 degrees
99+
if params in ["dlc-lr"]:
100+
params_dict["Rot90"] = {"p": 1.0, "kwargs": {"k": [[0, 2]]}}
97101

98-
# flip vertically
102+
# rotate 0, 90, 180, or 270 degrees
99103
if params in ["dlc-top-down"]:
100-
params_dict["Flipud"] = {"p": 1.0, "kwargs": {"p": 0.5}}
104+
params_dict["Rot90"] = {"p": 1.0, "kwargs": {"k": [[0, 1, 2, 3]]}}
101105

102106
# rotate
103107
rotation = 25 # rotation uniformly sampled from (-rotation, +rotation)

tests/utils/test_scripts.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def test_get_imgaug_transform_dlc(cfg):
122122
assert pipe.__str__().find("Resize") == -1
123123
assert pipe.__str__().find("Fliplr") == -1
124124
assert pipe.__str__().find("Flipud") == -1
125+
assert pipe.__str__().find("Rot90") == -1
125126
assert pipe.__str__().find("Affine") != -1
126127
assert pipe.__str__().find("MotionBlur") != -1
127128
assert pipe.__str__().find("CoarseDropout") != -1
@@ -142,8 +143,9 @@ def test_get_imgaug_transform_dlc_lr(cfg):
142143
cfg_tmp.training.imgaug = "dlc-lr"
143144
pipe = get_imgaug_transform(cfg_tmp)
144145
assert pipe.__str__().find("Resize") == -1
145-
assert pipe.__str__().find("Fliplr") != -1
146+
assert pipe.__str__().find("Fliplr") == -1
146147
assert pipe.__str__().find("Flipud") == -1
148+
assert pipe.__str__().find("Rot90(name=UnnamedRot90, parameters=[Choice(a=[0, 2]") != -1
147149
assert pipe.__str__().find("Affine") != -1
148150
assert pipe.__str__().find("MotionBlur") != -1
149151
assert pipe.__str__().find("CoarseDropout") != -1
@@ -164,8 +166,9 @@ def test_get_imgaug_transform_dlc_top_down(cfg):
164166
cfg_tmp.training.imgaug = "dlc-top-down"
165167
pipe = get_imgaug_transform(cfg_tmp)
166168
assert pipe.__str__().find("Resize") == -1
167-
assert pipe.__str__().find("Fliplr") != -1
168-
assert pipe.__str__().find("Flipud") != -1
169+
assert pipe.__str__().find("Fliplr") == -1
170+
assert pipe.__str__().find("Flipud") == -1
171+
assert pipe.__str__().find("Rot90(name=UnnamedRot90, parameters=[Choice(a=[0, 1, 2, 3]") != -1
169172
assert pipe.__str__().find("Affine") != -1
170173
assert pipe.__str__().find("MotionBlur") != -1
171174
assert pipe.__str__().find("CoarseDropout") != -1

0 commit comments

Comments
 (0)