|
40 | 40 | SemiSupervisedHeatmapTrackerMHCRNN,
|
41 | 41 | SemiSupervisedRegressionTracker,
|
42 | 42 | )
|
| 43 | +from lightning_pose.models.base import ( |
| 44 | + _apply_defaults_for_lr_scheduler_params, |
| 45 | + _apply_defaults_for_optimizer_params, |
| 46 | +) |
43 | 47 | from lightning_pose.utils import io as io_utils
|
44 | 48 | from lightning_pose.utils.pca import KeypointPCA
|
45 | 49 |
|
@@ -345,15 +349,15 @@ def get_model(
|
345 | 349 | """Create model: regression or heatmap based, supervised or semi-supervised."""
|
346 | 350 |
|
347 | 351 | optimizer = cfg.training.get("optimizer", "Adam")
|
348 |
| - optimizer_params = cfg.training.get( |
349 |
| - "optimizer_params", |
350 |
| - OmegaConf.create({"learning_rate": 1e-4}), |
| 352 | + optimizer_params = _apply_defaults_for_optimizer_params( |
| 353 | + optimizer, |
| 354 | + cfg.training.get("optimizer_params"), |
351 | 355 | )
|
352 | 356 |
|
353 |
| - lr_scheduler = cfg.training.lr_scheduler |
354 |
| - |
355 |
| - lr_scheduler_params = OmegaConf.to_object( |
356 |
| - cfg.training.lr_scheduler_params[lr_scheduler] |
| 357 | + lr_scheduler = cfg.training.get("lr_scheduler", "multisteplr") |
| 358 | + lr_scheduler_params = _apply_defaults_for_lr_scheduler_params( |
| 359 | + lr_scheduler, |
| 360 | + cfg.training.get("lr_scheduler_params", {}).get(f"{lr_scheduler}") |
357 | 361 | )
|
358 | 362 |
|
359 | 363 | semi_supervised = io_utils.check_if_semi_supervised(cfg.model.losses_to_use)
|
|
0 commit comments