Skip to content

Conversation

theo-brown
Copy link
Contributor

@theo-brown theo-brown commented May 19, 2025

Checklist

  • I've formatted the new code by running hatch run dev:format before committing.
  • I've added tests for new code.
  • I've added docstrings for the new code.

Description

L-BFGS tends to be the most commonly used optimizer for fitting GPs (used, for example, in GPytorch/Botorch).

gpx.fit_scipy under the hood used scipy.optimize.minimize, which defaults to running L-BFGS (note, not jax.scipy but the original scipy, as jax.scipy only has BFGS rather than the reduced memory L-BFGS).
Optax has an LBFGS implementation, but the API requires some additional arguments, namely value, grad, and value_fn in the optim.update step, that means that it can't be directly with gpx.fit. Hence, in this PR I've added a new function, fit_lbfgs, that's a hybrid between the fit_scipy and fit functions, using Optax's LBFGS solver in a lax.while loop.

Notes

  • There are still potential use cases where fit_scipy might be preferable. For example, on a CPU-only system with a small number of datapoints, fit_scipy is in fact faster than fit_lbfgs, as the compilation time of the lax loop is nontrivial. However, for larger numbers of datapoints or iterations, fit_lbfgs is faster, even on CPU.
  • I chose to use a lax.while rather than a lax.scan because it allows early termination. I deemed this to be more important than logging the full history of the objective values. It also more closely reflects the fit_scipy interface.

Happy to be corrected on any of the above! Let me know what you think.

@theo-brown
Copy link
Contributor Author

If it would be helpful/interesting, I can provide a speed comparison between fit_scipy and fit_lbfgs.

@thomaspinder
Copy link
Collaborator

Thanks for this @theo-brown - this is something I’d love to get merged. Can you take a look at the failing test please?

@theo-brown
Copy link
Contributor Author

Oops, good catch! Fixed :)

@thomaspinder thomaspinder merged commit b583bc2 into JaxGaussianProcesses:main May 20, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants