Add deterministic advi #564
Conversation
Check out this pull request on ReviewNB to see visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
What exactly needs jax? |
Thank you both very much for having a look so quickly! @jessegrabowski Good point, yes maybe! I'll take a look. @ricardoV94 Currently, JAX is used to compute the hvp and the jacobian of the objective. That involves computing a gradient for each of the fixed draws and then taking an average. What's quite nice in JAX is that this can be done easily with vmap: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-1d6e8b962a8c3ca803c55bea43c19863223ed50ae3814acc55424834ade1215cR44 That said, JAX isn't strictly necessary. Anything that can provide the DADVIFuns is fine: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-48ee4e85c0ff57f5b8af20dfd608bd0e37c3a2c76169a7bbe499e77ff3802d9dR13 In fact, I have code in the original research repo (https://github.com/martiningram/dadvi/blob/main/dadvi/objective_from_model.py#L5) that turns the regular hvp and gradient function into the DADVIFuns, but I think it'll be slower because of the for loops, e.g. here: https://github.com/martiningram/dadvi/blob/main/dadvi/objective_from_model.py#L56 Are you concerned about the JAX dependency? If so, maybe I could have a go at doing a JAX-free version using the code just mentioned and then only support JAX optionally. I do think it might be nice to have, since it's probably more efficient and would hopefully also run fast on GPUs. But interested in your thoughts. Also, I see one of the pre-commit checks seems to be failing. I can do the work to make the pre-commit hooks happy, sorry I haven't done that yet. |
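A minimal sketch (not the PR's code) of the pattern described above, where `per_draw_objective`, `eta`, and `draws` are hypothetical stand-ins: vmap averages per-draw quantities, and the hvp is built on top of the averaged gradient without ever forming a Hessian.

```python
import jax
import jax.numpy as jnp


def objective(eta, draws, per_draw_objective):
    # Average a (hypothetical) single-draw objective over the fixed draws;
    # vmap maps it over the leading draw axis without an explicit loop.
    return jnp.mean(jax.vmap(per_draw_objective, in_axes=(None, 0))(eta, draws))


def objective_grad(eta, draws, per_draw_objective):
    # Gradient of the averaged objective w.r.t. the variational parameters eta.
    return jax.grad(objective)(eta, draws, per_draw_objective)


def objective_hvp(eta, draws, per_draw_objective, v):
    # Hessian-vector product via forward-over-reverse differentiation.
    grad_fn = lambda e: jax.grad(objective)(e, draws, per_draw_objective)
    return jax.jvp(grad_fn, (eta,), (v,))[1]
```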
I think a jax dependency is fine. But if it's optional that's obviously even better! |
PyTensor has an equivalent. The reason I ask is that if you don't have anything jax-specific you can still end up using jax, but also C or numba, which may be better for certain users. |
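A small illustration of that point (not from the PR): the same PyTensor graph can be compiled for different backends just by choosing the compilation mode; the graph below is a made-up stand-in for the DADVI objective.

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
loss = pt.sum(x**2)  # stand-in for the real objective graph

f_default = pytensor.function([x], loss)              # default C/Python backend
f_jax = pytensor.function([x], loss, mode="JAX")      # JAX backend
f_numba = pytensor.function([x], loss, mode="NUMBA")  # Numba backend
```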
@ricardoV94 Oh cool, thanks, I didn't realise! I'll take a look and see if I can use those. I agree it would be nice to support as many users as possible. |
Happy to assist you. If you're vectorizing the Jacobian, everything is described here, although a bit scattered: https://pytensor.readthedocs.io/en/latest/tutorial/gradients.html |
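A minimal sketch of one way to do this in PyTensor (an assumption on my part about what is meant, not code from the PR): vectorize_graph rewrites a single-draw graph so that it maps over a batch of draws, much like jax.vmap; the logp here is a toy stand-in.

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

z = pt.vector("z")             # a single draw in the unconstrained space
neg_logp = 0.5 * pt.sum(z**2)  # stand-in for the model's negative log density

z_batch = pt.matrix("z_batch")  # all fixed draws stacked row-wise
# Rewrite the graph so it maps over the leading (draw) axis, like jax.vmap
neg_logp_batch = vectorize_graph(neg_logp, replace={z: z_batch})

mean_neg_logp = neg_logp_batch.mean()
grad = pt.grad(mean_neg_logp, z_batch)  # gradient for every draw, stacked
loss_and_grad = pytensor.function([z_batch], [mean_neg_logp, grad])
```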
Hey @ricardoV94 (and potentially others!), I think I could use your advice with the vectorisation. I think I've read enough that I could do it without vectorisation, using the functions mentioned earlier, but I'd really like to try to get this vectorised for speed. To explain a bit: the code expects the definition of a function of the variational parameters and the fixed set of draws.
The function should then return the estimate of the KL divergence using these draws, as well as its gradient with respect to the variational parameters. The KL divergence is the sum of the entropy of the approximation (a simple function of the variational parameters only) and the average of the log posterior densities from the draws. That's the part that I'd like to vectorise. Now in JAX, the way I do this is to: (1) write a function that computes the log posterior density of a single reparameterised draw given the variational parameters; (2) use vmap to map that function over all the fixed draws; (3) average the results and add the entropy term; and (4) take the gradient with respect to the variational parameters with grad.
Thanks to vmap, the per-draw function only ever has to handle a single draw, and there are no explicit loops.
This makes sense in my head, but the problem I see is that the pymc model's logp is a graph of the model's individual value variables rather than a function I can directly map over a batch of draws. So in essence, I think I need code to do the equivalent of that vmap step in pytensor: evaluate the model's log density (and its gradient) for a whole set of draws at once. Thanks a lot for your help :) |
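To make the construction above concrete, here is a rough sketch of that JAX recipe; `log_posterior`, `var_params`, and `fixed_draws` are illustrative names, and this is not the code in the PR.

```python
import jax
import jax.numpy as jnp


def make_loss_and_grad(log_posterior, fixed_draws):
    # log_posterior: flat unconstrained parameter vector -> scalar log density
    # fixed_draws:   (n_draws, d) array of standard normal draws, held fixed

    def loss(var_params):
        means, log_sds = var_params["mean"], var_params["log_sd"]

        # (1) log density of a single reparameterised draw
        def per_draw(z):
            theta = means + jnp.exp(log_sds) * z
            return log_posterior(theta)

        # (2)/(3) vmap over the fixed draws, average, and add the entropy
        mean_logp = jnp.mean(jax.vmap(per_draw)(fixed_draws))
        entropy = jnp.sum(log_sds)  # Gaussian entropy up to an additive constant

        # Objective to minimise: the negative of the sum described above
        return -(mean_logp + entropy)

    # (4) value and gradient w.r.t. the variational parameters
    return jax.value_and_grad(loss)
```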
The path followed by the laplace code is to get the logp of the pymc model, freeze the model and extract the negative logp, then create a flat vector input replacing the individual value inputs, then compile the loss_and_grads/hess/hessp functions (optionally in jax). My hope is that you can get the correct loss function for DADVI, and then you should be able to pass it directly into the same optimisation machinery. The 4 steps you outline seem correct to me. |
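A rough sketch of that path (my reading, with a toy model; the real laplace code in pymc-extras is more involved): pymc's join_nonshared_inputs replaces the individual value variables with one flat vector input, after which the compiled loss-and-gradient function can go straight into an off-the-shelf optimiser.

```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.blocking import DictToArrayBijection
from pymc.pytensorf import join_nonshared_inputs
from scipy.optimize import minimize

with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma")
    pm.Normal("y", mu=mu, sigma=sigma, observed=np.random.randn(50))

point = model.initial_point()

# Replace the individual value variables by a single flat vector input
[neg_logp], flat_input = join_nonshared_inputs(
    point=point, outputs=[-model.logp()], inputs=model.value_vars
)
neg_dlogp = pt.grad(neg_logp, flat_input)
loss_and_grad = pytensor.function([flat_input], [neg_logp, neg_dlogp])

# Off-the-shelf optimisation on the flattened parameter vector
x0 = DictToArrayBijection.map(point).data
result = minimize(loss_and_grad, x0, jac=True, method="L-BFGS-B")
```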
Thanks a lot @jessegrabowski. I'll give it a go! |
Hey all, I think I made good progress with the pytensor version. A first version is here: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-1b6e7da940ec73fce49f5e13ae1db5369ec011cb0b55974ec04d81e519e923f6R55 I think the only major thing missing is to transform the draws back into the constrained space from the unconstrained space. Is there a code snippet anyone could point me to? Thanks for your help and for all the helpful advice you've already given! |
You can make a pytensor function from the model value variables to the
output variables. An example of that is how get_jaxified_graph is used in
the jax based samplers
https://github.com/pymc-devs/pymc/blob/main/pymc/sampling/jax.py#L682
If you look in the source of get_jaxified_graph you can see how it's done
|
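For reference, a rough pytensor-only sketch of that idea, assuming a `model` and per-variable unconstrained draws (`one_unconstrained_draw` below) are available from the surrounding code; the jax-based samplers do the analogous thing through get_jaxified_graph.

```python
import pytensor
from pymc.util import get_default_varnames

# Constrained variables and deterministics we want back, dropping the
# transformed (*_log__, *_interval__, ...) copies
vars_to_sample = list(
    get_default_varnames(model.unobserved_value_vars, include_transformed=False)
)

# Function from the unconstrained value variables to the constrained outputs
constrain_fn = pytensor.function(
    model.value_vars, vars_to_sample, on_unused_input="ignore"
)

# Apply to one unconstrained draw (one array per value variable, in order)
constrained_values = constrain_fn(*one_unconstrained_draw)
```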
Hi everyone,
I'm one of the authors of the paper on deterministic ADVI. There is an open feature request for this in PyMC, so I thought I'd kick things off with this PR.
In simple terms, DADVI is like ADVI, but rather than using new draws to estimate its objective at each step, it uses a fixed set of draws throughout the optimisation. That means that (1) it can use regular off-the-shelf optimisers rather than stochastic optimisation, making convergence more reliable, and (2) it's possible to use techniques to improve the variance estimates. Both are covered in the paper, along with tools to assess how big the error from using fixed draws is.
This PR covers only the first part -- optimising ADVI with fixed draws. This is because I thought I'd start simple and because I'm hoping that it already addresses a real problem with ADVI, which is the difficulty in assessing convergence.
In addition to adding the code, there is an example notebook in notebooks/deterministic_advi_example.ipynb. It fits DADVI to the PyMC basic linear regression example. I can add more examples, but I thought I'd start simple. I mostly lifted the code from my research repository, so there are probably some style differences. Let me know what would be important to change.
Note that JAX is needed, but there shouldn't be any other dependencies.
Very keen to hear what you all think! :)
All the best,
Martin