Skip to content

Commit d899e14

Browse files
authored
Regularization Selection Fixes (#79)
* regression test: parametric ROM + regselect * delay test case processing * fix unit tests related to test case TypeErrors * ROM.fit() calls model.fit(), not _fit_solver() * BayesianROM regselect: catch LinAlg errors and warnings * regselect: reset solver; fix regselect for interpolatory models * literature: Gkimisis 2025 * v0.5.15 -> v0.5.16
1 parent 6cb936a commit d899e14

File tree

15 files changed

+241
-94
lines changed

15 files changed

+241
-94
lines changed

docs/literature.bib

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,3 +666,14 @@ @article{kang2025semiconductor
666666
doi = {10.48550/arXiv.2504.03990},
667667
category = {application}
668668
}
669+
670+
@article{gkimisis2025spatiallylocalized,
671+
title = {Non-intrusive reduced-order modeling for dynamical systems with spatially localized features},
672+
author = {Leonidas Gkimisis and Nicole Aretz and Marco Tezzele and Thomas Richter and Peter Benner and Karen E. Willcox},
673+
journal = {Computer Methods in Applied Mechanics and Engineering},
674+
volume = {444},
675+
pages = {118115},
676+
year = {2025},
677+
doi = {10.1016/j.cma.2025.118115},
678+
category = {method}
679+
}

docs/source/opinf/changelog.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,17 @@
55
New versions may introduce substantial new features or API adjustments.
66
:::
77

8+
## Version 0.5.16
9+
10+
Backend improvements to the regularization selection procedure.
11+
12+
- `ROM.fit()` calls `ROM.model.fit()` instead of `ROM.model._fit_solver()`, which works better for inheritance.
13+
- `BayesianROM.fit_regselect_*()` catches errors and suppresses warnings from the least-squares solver.
14+
- Fixed a bug for interpolatory parametric models where `fit_regselect_*()` would only update the regularizer for the overall `model.solver`, not for each of the individual solvers for the models to be interpolated.
15+
- Fixed a bug related to test cases being processed before the basis was initialized.
16+
- Removed unnecessary abstract methods from `models.mono._base._OpInfModel`.
17+
- Small updates to the literature page.
18+
819
## Version 0.5.15
920

1021
- Improvement to `fit_regselect_*()` so that the regularization does not have to be initialized before fitting the model. This fixes a longstanding chicken/egg problem and makes using `fit_regselect_*()` much less cumbersome.

src/opinf/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
https://github.com/Willcox-Research-Group/rom-operator-inference-Python3
88
"""
99

10-
__version__ = "0.5.15"
10+
__version__ = "0.5.16"
1111

1212
from . import (
1313
basis,

src/opinf/lstsq/_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ def _load_dict(hf, attr="options"):
287287
options[key] = None if value == "NULL" else value
288288
return options
289289

290+
def reset(self) -> None:
291+
"""Reset the solver by deleting data matrices."""
292+
SolverTemplate.__init__(self)
293+
290294
def copy(self):
291295
"""Make a copy of the solver."""
292296
return copy.deepcopy(self)

src/opinf/lstsq/_tikhonov.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ def regresidual(self, Ohat: np.ndarray) -> np.ndarray:
113113
raise NotImplementedError # pragma: no cover
114114

115115
# Persistence -------------------------------------------------------------
116+
def reset(self) -> None:
117+
"""Reset the solver by deleting data matrices and the regularizer."""
118+
super().reset()
119+
self.regularizer = None
120+
116121
def _save(self, savefile, overwrite=False, extras=tuple()):
117122
"""Serialize the solver, saving it in HDF5 format.
118123
The model can be recovered with the :meth:`_load()` class method.

src/opinf/models/mono/_base.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -391,31 +391,6 @@ def _check_is_trained(self):
391391

392392
# Fitting -----------------------------------------------------------------
393393
@abc.abstractmethod
394-
def _process_fit_arguments(self, *args, **kwargs):
395-
"""Prepare training data, validate and set dimensions, etc."""
396-
pass
397-
398-
@abc.abstractmethod
399-
def _assemble_data_matrix(self, *args, **kwargs):
400-
"""Construct the Operator Inference data matrix."""
401-
pass
402-
403-
@abc.abstractmethod
404-
def _extract_operators(self, *args, **kwargs):
405-
"""Unpack the Operator Inference solution, the operator matrix."""
406-
pass
407-
408-
@abc.abstractmethod
409-
def _fit_solver(self, *args, **kwargs):
410-
"""Initialize the regression solver."""
411-
pass
412-
413-
@abc.abstractmethod
414-
def refit(self, *args, **kwargs):
415-
"""Solve the regression and unpack the results."""
416-
pass
417-
418-
@abc.abstractmethod
419-
def fit(self, *args, **kwargs):
394+
def fit(self, *args, **kwargs): # pragma: no cover
420395
"""Learn model operators from data."""
421396
pass

src/opinf/roms/_base.py

Lines changed: 72 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -433,22 +433,27 @@ def _process_training_data(
433433

434434
return states, lhs, inputs
435435

436-
def _process_test_cases(self, test_cases, TestCaseClass):
437-
if test_cases is not None:
438-
if isinstance(test_cases, TestCaseClass):
439-
test_cases = [test_cases]
440-
processed_test_cases = []
441-
for tcase in test_cases:
442-
if not isinstance(tcase, TestCaseClass):
443-
raise TypeError(
444-
"test cases must be "
445-
f"'utils.{TestCaseClass.__name__}' objects"
446-
)
447-
processed_test_cases.append(
448-
tcase.copy(self.encode(tcase.initial_conditions))
436+
def _process_test_cases(
437+
self,
438+
test_cases,
439+
TestCaseClass: utils._gridsearch._RegTest,
440+
):
441+
if test_cases is None:
442+
return []
443+
444+
if isinstance(test_cases, TestCaseClass):
445+
test_cases = [test_cases]
446+
processed_test_cases = []
447+
for tcase in test_cases:
448+
if not isinstance(tcase, TestCaseClass):
449+
raise TypeError(
450+
"test cases must be "
451+
f"'utils.{TestCaseClass.__name__}' objects"
449452
)
450-
return processed_test_cases
451-
return []
453+
processed_test_cases.append(
454+
tcase.copy(self.encode(tcase.initial_conditions))
455+
)
456+
return processed_test_cases
452457

453458
def _get_stability_limits(self, states, stability_margin):
454459
shifts = [np.mean(Q, axis=1).reshape((-1, 1)) for Q in states]
@@ -466,18 +471,25 @@ def _get_stability_limits(self, states, stability_margin):
466471
limits[ell] = np.inf
467472
return shifts, limits
468473

469-
def _fit_solver(
474+
def _fit_model(
470475
self,
471476
parameters,
472477
states,
473478
lhs,
474479
inputs,
475-
fit_transformer,
476-
fit_basis,
480+
fit_transformer: bool,
481+
fit_basis: bool,
482+
solver_only: bool = False,
477483
):
478484
"""Process the training data and fit the model solver.
479485
Returns the processed training data.
480486
487+
Parameters
488+
----------
489+
solver_only : bool
490+
If ``True``, call ``self.model._fit_solver()`` instead of
491+
``self.model.fit()``. This is useful for regularization selection.
492+
If ``False`` (default), call ``self.model.fit()``.
481493
"""
482494
self._check_fit_args(lhs=lhs, inputs=inputs)
483495
if parameters is None:
@@ -495,11 +507,17 @@ def _fit_solver(
495507

496508
if parameters is None:
497509
inputdata = None if inputs is None else np.hstack(inputs)
498-
self.model._fit_solver(
499-
np.hstack(states), np.hstack(lhs), inputdata
500-
)
510+
if solver_only:
511+
self.model._fit_solver(
512+
np.hstack(states), np.hstack(lhs), inputdata
513+
)
514+
else:
515+
self.model.fit(np.hstack(states), np.hstack(lhs), inputdata)
501516
else:
502-
self.model._fit_solver(parameters, states, lhs, inputs)
517+
if solver_only:
518+
self.model._fit_solver(parameters, states, lhs, inputs)
519+
else:
520+
self.model.fit(parameters, states, lhs, inputs)
503521

504522
return states
505523

@@ -649,23 +667,29 @@ def fit_regselect_continuous(
649667
raise ValueError("argument 'test_time_length' must be nonnegative")
650668
if regularizer_factory is None:
651669
regularizer_factory = _identity
652-
processed_test_cases = self._process_test_cases(
653-
test_cases, utils.ContinuousRegTest
654-
)
670+
671+
# Reset the solver (in case the basis dimension changed between calls).
672+
interp = modutils.is_interpolatory(self.model)
673+
if hasattr(self.model, "reset"):
674+
self.model.solver.reset()
655675

656676
# Fit the model for the first time.
657-
self._fit_solver(
677+
self._fit_model(
658678
parameters=parameters,
659679
states=states,
660680
lhs=ddts,
661681
inputs=inputs,
662682
fit_transformer=fit_transformer,
663683
fit_basis=fit_basis,
684+
solver_only=True,
664685
)
665686

666687
# Set up the regularization selection.
667688
states_ = [self.encode(Q) for Q in states]
668689
shifts, limits = self._get_stability_limits(states_, stability_margin)
690+
processed_test_cases = self._process_test_cases(
691+
test_cases, utils.ContinuousRegTest
692+
)
669693

670694
def unstable(_Q, ell, size):
671695
"""Return ``True`` if the solution is unstable."""
@@ -695,7 +719,12 @@ def unstable(_Q, ell, size):
695719

696720
def update_model(reg_params):
697721
"""Reset the regularizer and refit the model operators."""
698-
self.model.solver.regularizer = regularizer_factory(reg_params)
722+
reg = regularizer_factory(reg_params)
723+
if interp:
724+
for solver in self.model.solvers:
725+
solver.regularizer = reg
726+
else:
727+
self.model.solver.regularizer = reg
699728
self.model.refit()
700729

701730
def training_error(reg_params):
@@ -861,22 +890,28 @@ def fit_regselect_discrete(
861890
)
862891
if regularizer_factory is None:
863892
regularizer_factory = _identity
864-
processed_test_cases = self._process_test_cases(
865-
test_cases, utils.DiscreteRegTest
866-
)
893+
894+
# Reset the solver (in case the basis dimension changed between calls).
895+
interp = modutils.is_interpolatory(self.model)
896+
if hasattr(self.model, "reset"):
897+
self.model.solver.reset()
867898

868899
# Fit the model for the first time and get compressed training data.
869-
states_ = self._fit_solver(
900+
states_ = self._fit_model(
870901
parameters=parameters,
871902
states=states,
872903
lhs=None,
873904
inputs=inputs,
874905
fit_transformer=fit_transformer,
875906
fit_basis=fit_basis,
907+
solver_only=True,
876908
)
877909

878910
# Set up the regularization selection.
879911
shifts, limits = self._get_stability_limits(states_, stability_margin)
912+
processed_test_cases = self._process_test_cases(
913+
test_cases, utils.DiscreteRegTest
914+
)
880915

881916
def unstable(_Q, ell):
882917
"""Return ``True`` if the solution is unstable."""
@@ -897,7 +932,12 @@ def unstable(_Q, ell):
897932

898933
def update_model(reg_params):
899934
"""Reset the regularizer and refit the model operators."""
900-
self.model.solver.regularizer = regularizer_factory(reg_params)
935+
reg = regularizer_factory(reg_params)
936+
if interp:
937+
for solver in self.model.solvers:
938+
solver.regularizer = reg
939+
else:
940+
self.model.solver.regularizer = reg
901941
self.model.refit()
902942

903943
def training_error(reg_params):

src/opinf/roms/_bayes.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
import warnings
1010
import numpy as np
11-
import scipy.linalg
1211
import scipy.stats
12+
import scipy.linalg
1313

1414
from .. import errors, lstsq, post, utils
1515
from ..models import _utils as modutils
@@ -224,12 +224,11 @@ def posterior(self) -> OperatorPosterior:
224224

225225
def _initialize_posterior(self):
226226
"""Set the operator posterior if numerically possible."""
227-
means, precisions = self.model.solver.posterior()
228227
try:
228+
means, precisions = self.model.solver.posterior()
229229
self.__posterior = OperatorPosterior(means, precisions)
230-
except np.linalg.LinAlgError as ex:
231-
if ex.args[0] == "Matrix is not positive definite":
232-
self.__posterior = None
230+
except np.linalg.LinAlgError:
231+
self.__posterior = None
233232

234233
def draw_operators(self):
235234
"""Set the :attr:`model` operators to a new random draw from the
@@ -292,23 +291,26 @@ def fit_regselect_continuous(
292291
raise ValueError("argument 'test_time_length' must be nonnegative")
293292
if regularizer_factory is None:
294293
regularizer_factory = _identity
295-
processed_test_cases = self._process_test_cases(
296-
test_cases, utils.ContinuousRegTest
297-
)
298294

299295
# Fit the model for the first time.
300-
self._fit_solver(
296+
if hasattr(self.model.solver, "reset"):
297+
self.model.solver.reset()
298+
self._fit_model(
301299
parameters=parameters,
302300
states=states,
303301
lhs=ddts,
304302
inputs=inputs,
305303
fit_transformer=fit_transformer,
306304
fit_basis=fit_basis,
305+
solver_only=True,
307306
)
308307

309308
# Set up the regularization selection.
310309
states_ = [self.encode(Q) for Q in states]
311310
shifts, limits = self._get_stability_limits(states_, stability_margin)
311+
processed_test_cases = self._process_test_cases(
312+
test_cases, utils.ContinuousRegTest
313+
)
312314

313315
def unstable(_Q, ell, size):
314316
"""Return ``True`` if the solution is unstable."""
@@ -339,7 +341,9 @@ def unstable(_Q, ell, size):
339341
def update_model(reg_params):
340342
"""Reset the regularizer and refit the model operators."""
341343
self.model.solver.regularizer = regularizer_factory(reg_params)
342-
self._initialize_posterior()
344+
with warnings.catch_warnings():
345+
warnings.simplefilter("ignore", scipy.linalg.LinAlgWarning)
346+
self._initialize_posterior()
343347

344348
def training_error(reg_params):
345349
"""Compute the training error for a single regularization
@@ -439,22 +443,25 @@ def fit_regselect_discrete(
439443
)
440444
if regularizer_factory is None:
441445
regularizer_factory = _identity
442-
processed_test_cases = self._process_test_cases(
443-
test_cases, utils.DiscreteRegTest
444-
)
445446

446447
# Fit the model for the first time.
447-
states_ = self._fit_solver(
448+
if hasattr(self.model.solver, "reset"):
449+
self.model.solver.reset()
450+
states_ = self._fit_model(
448451
parameters=parameters,
449452
states=states,
450453
lhs=None,
451454
inputs=inputs,
452455
fit_transformer=fit_transformer,
453456
fit_basis=fit_basis,
457+
solver_only=True,
454458
)
455459

456460
# Set up the regularization selection.
457461
shifts, limits = self._get_stability_limits(states_, stability_margin)
462+
processed_test_cases = self._process_test_cases(
463+
test_cases, utils.DiscreteRegTest
464+
)
458465

459466
def unstable(_Q, ell):
460467
"""Return ``True`` if the solution is unstable."""
@@ -476,7 +483,9 @@ def unstable(_Q, ell):
476483
def update_model(reg_params):
477484
"""Reset the regularizer and refit the model operators."""
478485
self.model.solver.regularizer = regularizer_factory(reg_params)
479-
self._initialize_posterior()
486+
with warnings.catch_warnings():
487+
warnings.simplefilter("ignore", scipy.linalg.LinAlgWarning)
488+
self._initialize_posterior()
480489

481490
def training_error(reg_params):
482491
"""Compute the mean training error for a single regularization

0 commit comments

Comments
 (0)