Skip to content

Commit 091074c

Browse files
authored
Basis Backend Maintenance (#82)
* Basis.fit_compress() method * update basis tests * pragma: no cover to verify() methods * time() -> process_time() in TimedBlock * PODBasis handle sparse weights (w/ warning) if svdsolver='dense'
1 parent bd746c0 commit 091074c

File tree

12 files changed

+172
-221
lines changed

12 files changed

+172
-221
lines changed

src/opinf/basis/_base.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
]
77

88
import abc
9+
import copy
910
import numpy as np
1011
import scipy.linalg as la
1112

@@ -102,13 +103,14 @@ def compress(self, states):
102103
Parameters
103104
----------
104105
states : (n, ...) ndarray
105-
Matrix of `n`-dimensional state vectors, or a single state vector.
106+
Matrix of :math:`n`-dimensional state vectors, or a single state
107+
vector.
106108
107109
Returns
108110
-------
109111
states_compressed : (r, ...) ndarray
110-
Matrix of `r`-dimensional latent coordinate vectors, or a single
111-
coordinate vector.
112+
Matrix of :math:`r`-dimensional latent coordinate vectors, or a
113+
single coordinate vector.
112114
"""
113115
raise NotImplementedError # pragma: no cover
114116

@@ -119,20 +121,37 @@ def decompress(self, states_compressed, locs=None):
119121
Parameters
120122
----------
121123
states_compressed : (r, ...) ndarray
122-
Matrix of `r`-dimensional latent coordinate vectors, or a single
123-
coordinate vector.
124+
Matrix of :math:`r`-dimensional latent coordinate vectors, or a
125+
single coordinate vector.
124126
locs : slice or (p,) ndarray of integers or None
125127
If given, return the decompressed state at *only* the
126128
`p` specified locations (indices) described by ``locs``.
127129
128130
Returns
129131
-------
130132
states_decompressed : (n, ...) or (p, ...) ndarray
131-
Matrix of `n`-dimensional decompressed state vectors, or the `p`
132-
entries of such at the entries specified by ``locs``.
133+
Matrix of :math:`n`-dimensional decompressed state vectors, or the
134+
:math:`p` entries of such at the entries specified by ``locs``.
133135
"""
134136
raise NotImplementedError # pragma: no cover
135137

138+
def fit_compress(self, states):
139+
"""Construct the basis and map high-dimensional states to
140+
low-dimensional latent coordinates.
141+
142+
Parameters
143+
----------
144+
states : (n, k) ndarray
145+
Matrix of :math:`k` :math:`n`-dimensional snapshots.
146+
147+
Returns
148+
-------
149+
states_compressed : (r, k) ndarray
150+
Matrix of :math:`r`-dimensional latent coordinate vectors.
151+
"""
152+
self.fit(states)
153+
return self.compress(states)
154+
136155
# Projection --------------------------------------------------------------
137156
def project(self, state):
138157
"""Project a high-dimensional state vector to the subset of the
@@ -150,13 +169,14 @@ def project(self, state):
150169
Parameters
151170
----------
152171
states : (n, ...) ndarray
153-
Matrix of `n`-dimensional state vectors, or a single state vector.
172+
Matrix of :math:`n`-dimensional state vectors, or a single state
173+
vector.
154174
155175
Returns
156176
-------
157177
state_projected : (n, ...) ndarray
158-
Matrix of `n`-dimensional projected state vectors, or a single
159-
projected state vector.
178+
Matrix of :math:`n`-dimensional projected state vectors, or a
179+
single projected state vector.
160180
"""
161181
return self.decompress(self.compress(state))
162182

@@ -202,8 +222,12 @@ def load(cls, loadfile: str):
202222
"""Load a transformer from an HDF5 file."""
203223
raise NotImplementedError("use pickle/joblib") # pragma: no cover
204224

225+
def copy(self):
226+
"""Make a copy of the basis."""
227+
return copy.deepcopy(self)
228+
205229
# Verification ------------------------------------------------------------
206-
def verify(self):
230+
def verify(self): # pragma: no cover
207231
"""Verify that :meth:`compress()` and :meth:`decompress()` are
208232
consistent in the sense that the range of :meth:`decompress()` is in
209233
the domain of :meth:`compress()` and that :meth:`project()` defines
@@ -247,7 +271,9 @@ def verify(self):
247271
)
248272
print("compress() and decompress() are consistent")
249273

250-
def _verify_locs(self, states_compressed, states_projected):
274+
def _verify_locs(
275+
self, states_compressed, states_projected
276+
): # pragma: no cover
251277
"""Verification of decompress() with locs != None."""
252278
n = states_projected.shape[0]
253279
locs = np.sort(np.random.choice(n, size=(n // 3), replace=False))

src/opinf/basis/_pod.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,14 @@ def __init__(
293293
if weights.ndim == 1:
294294
self.__sqrt_weights = np.sqrt(weights)
295295
else: # (weights.ndim == 2, checked by LinearBasis)
296+
if sparse.issparse(weights):
297+
weights = weights.toarray()
298+
if weights.shape[0] > 100: # pragma: no cover
299+
warnings.warn(
300+
"computing the square root of a large weight matrix, "
301+
"consider using svdsolver='method-of-snapshots'",
302+
errors.OpInfWarning,
303+
)
296304
self.__sqrt_weights = la.sqrtm(weights)
297305
self.__sqrt_weights_cho = la.cho_factor(self.__sqrt_weights)
298306

src/opinf/ddt/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def mask(self, arr):
174174
raise NotImplementedError # pragma: no cover
175175

176176
# Verification ------------------------------------------------------------
177-
def verify_shapes(self, r: int = 5, m: int = 3):
177+
def verify_shapes(self, r: int = 5, m: int = 3): # pragma: no cover
178178
"""Verify that :meth:`estimate()` is consistent in the sense that the
179179
all outputs have the same number of columns. This method does **not**
180180
check the accuracy of the derivative estimation.

src/opinf/lstsq/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def copy(self):
296296
return copy.deepcopy(self)
297297

298298
# Verification ------------------------------------------------------------
299-
def verify(self):
299+
def verify(self): # pragma: no cover
300300
"""Verify the solver.
301301
302302
If the solver is already trained, check :meth:`solve()`,

src/opinf/operators/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def verify(
230230
k: int = 10,
231231
fdifftol: float = 1e-5,
232232
ntests: int = 4,
233-
) -> None:
233+
) -> None: # pragma: no cover
234234
"""Verify consistency between dimension properties and required
235235
methods.
236236
@@ -827,7 +827,7 @@ def verify(
827827
ntests: int = 4,
828828
r: int = 6,
829829
m: int = 3,
830-
) -> None:
830+
) -> None: # pragma: no cover
831831
"""Verify consistency between dimension properties and required
832832
methods.
833833

src/opinf/pre/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def load(cls, loadfile):
229229
raise NotImplementedError("use pickle/joblib") # pragma: no cover
230230

231231
# Verification ------------------------------------------------------------
232-
def verify(self, tol: float = 1e-4):
232+
def verify(self, tol: float = 1e-4): # pragma: no cover
233233
r"""Verify that :meth:`transform()` and :meth:`inverse_transform()`
234234
are consistent and that :meth:`transform_ddts()`, if implemented,
235235
is consistent with :meth:`transform()`.
@@ -333,7 +333,7 @@ def verify(self, tol: float = 1e-4):
333333
)
334334
print("transform() and transform_ddts() are consistent")
335335

336-
def _verify_locs(self, states, states_transformed):
336+
def _verify_locs(self, states, states_transformed): # pragma: no cover
337337
"""Verification for inverse_transform() with locs != None"""
338338
n = states.shape[0]
339339
locs = np.sort(np.random.choice(n, size=(n // 3), replace=False))

src/opinf/roms/_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ def _process_training_data(
404404
# Dimensionality reduction.
405405
if self.basis is not None:
406406
if fit_basis:
407+
# NOTE: self.basis.fit_compress() here?
407408
self.basis.fit(np.hstack(states))
408409
states = [self.basis.compress(Q) for Q in states]
409410
if lhs is not None:
@@ -734,7 +735,7 @@ def training_error(reg_params):
734735
"""
735736
try:
736737
update_model(reg_params)
737-
except Exception as ex:
738+
except Exception as ex: # pragma: no cover
738739
if verbose:
739740
print(f"{type(ex).__name__} in refit(): {ex}")
740741
return np.inf
@@ -947,7 +948,7 @@ def training_error(reg_params):
947948
"""
948949
try:
949950
update_model(reg_params)
950-
except Exception as ex:
951+
except Exception as ex: # pragma: no cover
951952
if verbose:
952953
print(f"{type(ex).__name__} in refit(): {ex}")
953954
return np.inf

src/opinf/utils/_timer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,15 @@ def __enter__(self):
139139
sys.stdout = self.__new_buffer = io.StringIO()
140140
if self.verbose:
141141
print(f"{self.message}...", end=self.__front, flush=True)
142-
self._tic = time.time()
142+
self._tic = time.process_time()
143143
if self.timelimit is not None:
144144
signal.signal(signal.SIGALRM, self._signal_handler)
145145
signal.alarm(self.timelimit)
146146
return self
147147

148148
def __exit__(self, exc_type, exc_value, exc_traceback):
149149
"""Calculate and report the elapsed time."""
150-
self._toc = time.time()
150+
self._toc = time.process_time()
151151
if self.timelimit is not None:
152152
signal.alarm(0)
153153
elapsed = self._toc - self._tic
@@ -170,7 +170,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
170170
self._reset_stdout()
171171
raise
172172
else: # If no exception, report execution time.
173-
if self.verbose:
173+
if self.verbose and self.message:
174174
print(f"done in {elapsed:.2f} s.", flush=True, end=self.__back)
175175
logging.info(f"{self.message}...done in {elapsed:.6f} s.")
176176
self.__elapsed = elapsed

0 commit comments

Comments
 (0)