Describe the bug
ArviZ's from_numpyro function works with many NumPyro inference objects, but not with numpyro.contrib.nested_sampling, NumPyro's officially supported bridge to JaxNS. The reason is that NestedSampler.get_samples requires two positional arguments, rng_key (a PRNG key) and num_samples, which from_numpyro does not supply.
To Reproduce
>>> import arviz
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.contrib.nested_sampling import NestedSampler
>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(0), (2000, 3))
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(1))
>>> def model(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(0, 1).expand([3]))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)),
...                           obs=labels)
...
>>> ns = NestedSampler(model)
>>> ns.run(random.PRNGKey(2), data, labels)
>>> arviz.from_numpyro(ns)
TypeError: NestedSampler.get_samples() missing 2 required positional arguments: 'rng_key' and 'num_samples'
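Until this is supported, one possible workaround is to draw the samples explicitly with ns.get_samples(rng_key, num_samples) and build the InferenceData by hand. The sketch below only covers the reshaping step, assuming the ArviZ 0.x arviz.from_dict(posterior=...) interface, which reads the first two axes of each array as (chain, draw); NestedSampler.get_samples returns flat draws, so a chain axis has to be added. The helper name to_single_chain is hypothetical.

```python
from typing import Dict, Mapping

import numpy as np


def to_single_chain(samples: Mapping[str, object]) -> Dict[str, np.ndarray]:
    """Hypothetical helper: add a leading chain axis to every sample array.

    Input arrays have shape (draw, *event_shape), as returned by
    NestedSampler.get_samples; output arrays have shape (1, draw, *event_shape),
    the (chain, draw, ...) layout that ArviZ 0.x from_dict expects.
    """
    # np.asarray also converts JAX arrays to NumPy arrays.
    return {name: np.asarray(value)[np.newaxis, ...] for name, value in samples.items()}
```

For the model above, something like arviz.from_dict(posterior=to_single_chain(ns.get_samples(random.PRNGKey(3), num_samples=1000))) should then produce an InferenceData object. Treating all draws as a single chain means chain-based diagnostics such as R-hat are not meaningful here.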
Expected behavior
ArviZ dispatches on the type of the NumPyro inference object and, for a NestedSampler, passes an rng_key and num_samples to NestedSampler.get_samples. Alternatively, from_numpyro could gain a keyword argument (for example, args) for specifying these values.
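The proposed type dispatch could take roughly the following shape. This is a minimal sketch using functools.singledispatch, not ArviZ's actual code: draw_samples, StubMCMC and StubNestedSampler are hypothetical names, and the stub classes only mimic the get_samples signatures of numpyro.infer.MCMC and numpyro.contrib.nested_sampling.NestedSampler. A real implementation would register the actual NestedSampler class instead of the stub.

```python
from functools import singledispatch
from typing import Any, Dict, Optional


# Hypothetical stand-ins: only their get_samples signatures matter here.
class StubMCMC:
    def get_samples(self, group_by_chain: bool = False) -> Dict[str, Any]:
        return {"coefs": [1.0, 2.0, 3.0]}


class StubNestedSampler:
    def get_samples(self, rng_key: Any, num_samples: int) -> Dict[str, Any]:
        # The real NestedSampler resamples its weighted samples at this point.
        return {"coefs": [0.0] * num_samples}


@singledispatch
def draw_samples(posterior: Any, rng_key: Any = None,
                 num_samples: Optional[int] = None) -> Dict[str, Any]:
    """Default handler: get_samples() needs no extra arguments."""
    return posterior.get_samples()


@draw_samples.register
def _(posterior: StubNestedSampler, rng_key: Any = None,
      num_samples: Optional[int] = None) -> Dict[str, Any]:
    """Nested-sampler handler: rng_key and num_samples are mandatory."""
    if rng_key is None or num_samples is None:
        raise ValueError("NestedSampler requires rng_key and num_samples")
    return posterior.get_samples(rng_key, num_samples)
```

Raising an explicit ValueError when the extra arguments are missing gives users a clearer message than the TypeError from get_samples.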
vandalt