
Numpyro SVI conversion #2453

@kylejcaron


Tell us about it

It would be convenient to have a way to convert the output of a NumPyro model fit with variational inference (SVI) to an ArviZ InferenceData object. This would be analogous to PyMC, which also uses InferenceData for its VI results and sets chains=1.
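Treating VI output as a single chain comes down to giving every posterior array a leading chain axis of length 1, so draws shaped (draw, *shape) become (chain, draw, *shape), the layout ArviZ uses. A minimal sketch, assuming the variational draws are already a dict of NumPy arrays; add_chain_axis is a hypothetical helper name, not an existing API:

```python
import numpy as np


def add_chain_axis(samples):
    """Turn {name: (draw, *shape)} into {name: (1, draw, *shape)}.

    VI yields one set of independent draws, so it is represented as chains=1.
    """
    return {name: np.asarray(values)[np.newaxis, ...] for name, values in samples.items()}


draws = {
    "mu": np.random.default_rng(0).normal(size=500),         # scalar site
    "beta": np.random.default_rng(1).normal(size=(500, 3)),  # vector site
}
posterior = add_chain_axis(draws)
print({name: values.shape for name, values in posterior.items()})
# {'mu': (1, 500), 'beta': (1, 500, 3)}
```

A dict in this layout is what az.from_dict(posterior=...) expects in ArviZ 0.x.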

The downside is that many ArviZ diagnostics (for example R-hat and effective sample size) are meant for MCMC samplers such as HMC, not for VI, which could confuse inexperienced users.

Thoughts on implementation

This might be tricky: an MCMC (e.g. HMC) object stores everything ArviZ needs, but the SVI object does not. The SVIRunResult returned by svi.run (a namedtuple containing the fitted params) is needed as well, together with the original SVI object (which holds the model and guide functions).

Possible Option 1:

az.from_numpyro(
    svi,
    svi_result=svi_result,
    prior=...,
    posterior_predictive=...,
    coords=...,
    ...
)

Possible Option 2:

az.from_numpyro_svi(
    svi,
    svi_result=svi_result,
    prior=...,
    posterior_predictive=...,
    coords=...,
    ...
)

Possible Option 3:

az.from_numpyro(
    SVIWrapper(svi, result=svi_result), # mimics the mcmc structure for getting samples
    prior=...,
    posterior_predictive=...,
    coords=...,
    ...
)
