
Commit 291f3e7

update/put optimization defaults in one place (#288)
1 parent: 1d5418c

2 files changed: +12, -8 lines

lightning_pose/models/base.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@
 
 DEFAULT_LR_SCHEDULER_PARAMS = OmegaConf.create(
     {
-        "milestones": [100, 200, 300],
+        "milestones": [150, 200, 250],
         "gamma": 0.5,
     }
 )
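These defaults describe a multi-step schedule: the learning rate is multiplied by gamma at each epoch listed in milestones, so the new values halve the rate at epochs 150, 200, and 250 rather than 100, 200, and 300. A minimal sketch of that behavior, assuming the params feed torch's MultiStepLR (the parameter and optimizer below are placeholders, not from this commit):

import torch
from torch.optim.lr_scheduler import MultiStepLR

# Placeholder parameter so the optimizer has something to step.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)

# With the new defaults, the learning rate is halved (gamma=0.5)
# at epochs 150, 200, and 250.
scheduler = MultiStepLR(optimizer, milestones=[150, 200, 250], gamma=0.5)

for epoch in range(300):
    optimizer.step()
    scheduler.step()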

lightning_pose/utils/scripts.py

Lines changed: 11 additions & 7 deletions
@@ -40,6 +40,10 @@
     SemiSupervisedHeatmapTrackerMHCRNN,
     SemiSupervisedRegressionTracker,
 )
+from lightning_pose.models.base import (
+    _apply_defaults_for_lr_scheduler_params,
+    _apply_defaults_for_optimizer_params,
+)
 from lightning_pose.utils import io as io_utils
 from lightning_pose.utils.pca import KeypointPCA
 

@@ -345,15 +349,15 @@ def get_model(
     """Create model: regression or heatmap based, supervised or semi-supervised."""
 
     optimizer = cfg.training.get("optimizer", "Adam")
-    optimizer_params = cfg.training.get(
-        "optimizer_params",
-        OmegaConf.create({"learning_rate": 1e-4}),
+    optimizer_params = _apply_defaults_for_optimizer_params(
+        optimizer,
+        cfg.training.get("optimizer_params"),
     )
 
-    lr_scheduler = cfg.training.lr_scheduler
-
-    lr_scheduler_params = OmegaConf.to_object(
-        cfg.training.lr_scheduler_params[lr_scheduler]
+    lr_scheduler = cfg.training.get("lr_scheduler", "multisteplr")
+    lr_scheduler_params = _apply_defaults_for_lr_scheduler_params(
+        lr_scheduler,
+        cfg.training.get("lr_scheduler_params", {}).get(f"{lr_scheduler}")
     )
 
     semi_supervised = io_utils.check_if_semi_supervised(cfg.model.losses_to_use)
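The bodies of _apply_defaults_for_optimizer_params and _apply_defaults_for_lr_scheduler_params live in lightning_pose/models/base.py and are not shown in this diff. A plausible sketch, assuming each helper overlays user-supplied values on the module-level defaults via OmegaConf.merge (the default tables and merge behavior here are assumptions, not taken from this commit; only the call signatures are visible above):

from omegaconf import OmegaConf

# Module-level defaults kept in one place -- the point of this commit.
# The 1e-4 learning rate matches the old inline fallback in scripts.py.
DEFAULT_OPTIMIZER_PARAMS = OmegaConf.create({"learning_rate": 1e-4})
DEFAULT_LR_SCHEDULER_PARAMS = OmegaConf.create(
    {
        "milestones": [150, 200, 250],
        "gamma": 0.5,
    }
)


def _apply_defaults_for_optimizer_params(optimizer, optimizer_params=None):
    # Merge user overrides on top of the defaults; a real implementation
    # might vary the defaults by optimizer name (e.g. "Adam" vs "AdamW").
    return OmegaConf.merge(DEFAULT_OPTIMIZER_PARAMS, optimizer_params or {})


def _apply_defaults_for_lr_scheduler_params(lr_scheduler, lr_scheduler_params=None):
    # Same pattern for the scheduler; the caller passes None when the
    # config has no lr_scheduler_params entry for the chosen scheduler.
    return OmegaConf.merge(DEFAULT_LR_SCHEDULER_PARAMS, lr_scheduler_params or {})

Whatever the exact bodies look like, the effect on get_model is the same: it no longer hard-codes fallback values inline, and both config lookups now tolerate missing keys via cfg.training.get(...) instead of raising on direct attribute access.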
