Add fit_lbfgs using Optax #514
Merged
Checklist
hatch run dev:format before committing.

Description
L-BFGS tends to be the most commonly used optimizer for fitting GPs (it is used, for example, in GPyTorch/BoTorch). Under the hood, gpx.fit_scipy uses scipy.optimize.minimize, which defaults to running L-BFGS (note: not jax.scipy but the original SciPy, as jax.scipy only has BFGS rather than the limited-memory L-BFGS).

Optax has an L-BFGS implementation, but its API requires some additional arguments in the optim.update step, namely value, grad, and value_fn, which means it can't be used directly with gpx.fit. Hence, in this PR I've added a new function, fit_lbfgs, that is a hybrid of the fit_scipy and fit functions, using Optax's L-BFGS solver in a lax.while_loop.

Notes
In some cases, fit_scipy might be preferable. For example, on a CPU-only system with a small number of datapoints, fit_scipy is in fact faster than fit_lbfgs, as the compilation time of the lax loop is nontrivial. However, for larger numbers of datapoints or iterations, fit_lbfgs is faster, even on CPU.

The loop uses lax.while_loop rather than lax.scan because it allows early termination. I deemed this to be more important than logging the full history of the objective values. It also more closely reflects the fit_scipy interface.

Happy to be corrected on any of the above! Let me know what you think.