Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 0 additions & 59 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,65 +107,6 @@ jupytext --to notebook example.py
jupytext --to py:percent example.ipynb
```

# Simple example

Let us import some dependencies and simulate a toy dataset $\mathcal{D}$.

```python
from jax import config

config.update("jax_enable_x64", True)

import gpjax as gpx
from jax import grad, jit
import jax.numpy as jnp
import jax.random as jr
import optax as ox

# Split the seed so the inputs and the observation noise come from independent
# random streams; drawing both from one key would correlate the two samples.
key = jr.key(123)
key, x_key, noise_key = jr.split(key, 3)

# Latent function that generates the observations.
f = lambda x: 10 * jnp.sin(x)

n = 50
# Sort along the datapoint axis. With shape (n, 1) the default last-axis sort
# acts on length-1 rows and would leave the inputs unordered.
x = jr.uniform(key=x_key, minval=-3.0, maxval=3.0, shape=(n, 1)).sort(axis=0)
y = f(x) + jr.normal(noise_key, shape=(n, 1))
D = gpx.Dataset(X=x, y=y)

# Construct the prior
meanf = gpx.mean_functions.Zero()
kernel = gpx.kernels.RBF()
prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)

# Define a likelihood
likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)

# Construct the posterior
posterior = prior * likelihood

# Define an optimiser
optimiser = ox.adam(learning_rate=1e-2)

# Obtain Type 2 MLEs of the hyperparameters
opt_posterior, history = gpx.fit(
model=posterior,
objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
train_data=D,
optim=optimiser,
num_iters=500,
safe=True,
key=key,
)

# Infer the predictive posterior distribution
xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
latent_dist = opt_posterior(xtest, D)
predictive_dist = opt_posterior.likelihood(latent_dist)

# Obtain the predictive mean and standard deviation
pred_mean = predictive_dist.mean()
pred_std = predictive_dist.stddev()
```

# Installation

## Stable version
Expand Down
6 changes: 3 additions & 3 deletions docs/scripts/sharp_bits_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@
plt.savefig("../_static/step_size_figure.png", bbox_inches="tight")

# %%
import tensorflow_probability.substrates.jax.bijectors as tfb
import numpyro.distributions.transforms as npt

bij = tfb.Exp()
bij = npt.ExpTransform()

x = np.linspace(0.05, 3.0, 6)
y = np.asarray(bij.inverse(x))
y = np.asarray(bij.inv(x))
lval = 0.5
rval = 0.52

Expand Down
2 changes: 1 addition & 1 deletion docs/sharp_bits.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ this value that we apply gradient updates to. When we wish to recover the constr
value, we apply the inverse of the bijector, which is the exponential function in this
case. This gives us back the blue cross.

In GPJax, we supply bijective functions using [Tensorflow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors).
In GPJax, we supply bijective functions using [NumPyro](https://num.pyro.ai/en/stable/distributions.html#transforms).


## Positive-definiteness
Expand Down
20 changes: 10 additions & 10 deletions examples/barycentres.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.6
# jupytext_version: 1.16.7
# kernelspec:
# display_name: gpjax
# language: python
Expand Down Expand Up @@ -41,7 +41,7 @@
import jax.scipy.linalg as jsl
from jaxtyping import install_import_hook
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax.distributions as tfd
import numpyro.distributions as npd

from examples.utils import use_mpl_style

Expand Down Expand Up @@ -161,7 +161,7 @@


# %%
def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
def fit_gp(x: jax.Array, y: jax.Array) -> npd.MultivariateNormal:
if y.ndim == 1:
y = y.reshape(-1, 1)
D = gpx.Dataset(X=x, y=y)
Expand Down Expand Up @@ -204,9 +204,9 @@ def sqrtm(A: jax.Array):


def wasserstein_barycentres(
distributions: tp.List[tfd.MultivariateNormalFullCovariance], weights: jax.Array
distributions: tp.List[npd.MultivariateNormal], weights: jax.Array
):
covariances = [d.covariance() for d in distributions]
covariances = [d.covariance_matrix for d in distributions]
cov_stack = jnp.stack(covariances)
stack_sqrt = jax.vmap(sqrtm)(cov_stack)

Expand All @@ -231,7 +231,7 @@ def step(covariance_candidate: jax.Array, idx: None):
# %%
weights = jnp.ones((n_datasets,)) / n_datasets

means = jnp.stack([d.mean() for d in posterior_preds])
means = jnp.stack([d.mean for d in posterior_preds])
barycentre_mean = jnp.tensordot(weights, means, axes=1)

step_fn = jax.jit(wasserstein_barycentres(posterior_preds, weights))
Expand All @@ -242,7 +242,7 @@ def step(covariance_candidate: jax.Array, idx: None):
)
L = jnp.linalg.cholesky(barycentre_covariance)

barycentre_process = tfd.MultivariateNormalTriL(barycentre_mean, L)
barycentre_process = npd.MultivariateNormal(barycentre_mean, scale_tril=L)

# %% [markdown]
# ## Plotting the result
Expand All @@ -254,16 +254,16 @@ def step(covariance_candidate: jax.Array, idx: None):

# %%
def plot(
dist: tfd.MultivariateNormalTriL,
dist: npd.MultivariateNormal,
ax,
color: str,
label: str = None,
ci_alpha: float = 0.2,
linewidth: float = 1.0,
zorder: int = 0,
):
mu = dist.mean()
sigma = dist.stddev()
mu = dist.mean
sigma = jnp.sqrt(dist.variance)
ax.plot(xtest, mu, linewidth=linewidth, color=color, label=label, zorder=zorder)
ax.fill_between(
xtest.squeeze(),
Expand Down
26 changes: 11 additions & 15 deletions examples/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.6
# jupytext_version: 1.16.7
# kernelspec:
# display_name: gpjax
# language: python
Expand Down Expand Up @@ -37,8 +37,8 @@
install_import_hook,
)
import matplotlib.pyplot as plt
import numpyro.distributions as npd
import optax as ox
import tensorflow_probability.substrates.jax as tfp

from examples.utils import use_mpl_style
from gpjax.lower_cholesky import lower_cholesky
Expand All @@ -50,7 +50,6 @@
import gpjax as gpx


tfd = tfp.distributions
identity_matrix = jnp.eye

# set the default style for plotting
Expand Down Expand Up @@ -120,7 +119,6 @@
# Optax's optimisers.

# %%

optimiser = ox.adam(learning_rate=0.01)

opt_posterior, history = gpx.fit(
Expand All @@ -140,8 +138,8 @@
map_latent_dist = opt_posterior.predict(xtest, train_data=D)
predictive_dist = opt_posterior.likelihood(map_latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
predictive_mean = predictive_dist.mean
predictive_std = jnp.sqrt(predictive_dist.variance)

fig, ax = plt.subplots()
ax.scatter(x, y, label="Observations", color=cols[0])
Expand Down Expand Up @@ -215,8 +213,6 @@
# datapoints below.

# %%


gram, cross_covariance = (kernel.gram, kernel.cross_covariance)
jitter = 1e-6

Expand Down Expand Up @@ -246,7 +242,7 @@ def loss(params, D):
L_inv = jsp.linalg.solve_triangular(L, identity_matrix(D.n), lower=True)
H_inv = jsp.linalg.solve_triangular(L.T, L_inv, lower=False)
LH = jnp.linalg.cholesky(H_inv)
laplace_approximation = tfd.MultivariateNormalTriL(f_hat.squeeze(), LH)
laplace_approximation = npd.MultivariateNormal(f_hat.squeeze(), scale_tril=LH)


# %% [markdown]
Expand All @@ -265,7 +261,7 @@ def loss(params, D):


# %%
def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNormalTriL:
def construct_laplace(test_inputs: Float[Array, "N D"]) -> npd.MultivariateNormal:
map_latent_dist = opt_posterior.predict(xtest, train_data=D)

Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs)
Expand All @@ -279,10 +275,10 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
# Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt
laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Kxt.T, H_inv), Kxx_inv_Kxt)

mean = map_latent_dist.mean()
covariance = map_latent_dist.covariance() + laplace_cov_term
mean = map_latent_dist.mean
covariance = map_latent_dist.covariance_matrix + laplace_cov_term
L = jnp.linalg.cholesky(covariance)
return tfd.MultivariateNormalTriL(jnp.atleast_1d(mean.squeeze()), L)
return npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), scale_tril=L)


# %% [markdown]
Expand All @@ -291,8 +287,8 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
laplace_latent_dist = construct_laplace(xtest)
predictive_dist = opt_posterior.likelihood(laplace_latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
predictive_mean = predictive_dist.mean
predictive_std = jnp.sqrt(predictive_dist.variance)

fig, ax = plt.subplots()
ax.scatter(x, y, label="Observations", color=cols[0])
Expand Down
8 changes: 4 additions & 4 deletions examples/collapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.6
# jupytext_version: 1.16.7
# kernelspec:
# display_name: gpjax_beartype
# language: python
Expand Down Expand Up @@ -161,10 +161,10 @@

inducing_points = opt_posterior.inducing_inputs.value

samples = latent_dist.sample(seed=key, sample_shape=(20,))
samples = latent_dist.sample(key=key, sample_shape=(20,))

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
predictive_mean = predictive_dist.mean
predictive_std = jnp.sqrt(predictive_dist.variance)

fig, ax = plt.subplots()

Expand Down
31 changes: 26 additions & 5 deletions examples/constructing_new_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.6
# jupytext_version: 1.16.7
# kernelspec:
# display_name: gpjax
# language: python
Expand All @@ -24,6 +24,7 @@
# %%
# Enable Float64 for more stable matrix inversions.
from jax import config
from jax.nn import softplus
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
Expand All @@ -32,7 +33,9 @@
install_import_hook,
)
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
import numpyro.distributions as npd
from numpyro.distributions import constraints
import numpyro.distributions.transforms as npt

from examples.utils import use_mpl_style
from gpjax.kernels.computations import DenseKernelComputation
Expand Down Expand Up @@ -225,9 +228,27 @@ def angular_distance(x, y, c):
return jnp.abs((x - y + c) % (c * 2) - c)


bij = tfb.SoftClip(low=jnp.array(4.0, dtype=jnp.float64))
class ShiftedSoftplusTransform(npt.ParameterFreeTransform):
    r"""Bijection from the real line onto the open half-line :math:`(4, \infty)`.

    The forward map is :math:`y = 4 + \log(1 + \exp(x))`, i.e. a softplus
    shifted up by 4; the inverse is :math:`x = \log(\exp(y - 4) - 1)`.
    Softplus is strictly positive, so the bound 4 itself is never attained.
    """

    # Lower bound of the codomain; shared by the forward and inverse maps so
    # the two cannot drift apart.
    _SHIFT = 4.0

    domain = constraints.real
    # A half-line support is expressed with ``greater_than`` rather than an
    # interval whose upper endpoint is infinite.
    codomain = constraints.greater_than(_SHIFT)

    def __call__(self, x):
        """Map an unconstrained value ``x`` to ``4 + softplus(x)``."""
        return self._SHIFT + softplus(x)

    def _inverse(self, y):
        """Map a constrained value ``y`` back to the real line."""
        # Undo the shift first, then invert the softplus.
        return npt._softplus_inv(y - self._SHIFT)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        """Return ``log |dy/dx|`` evaluated at ``x``.

        The shift is constant, so ``dy/dx = sigmoid(x)`` and
        ``log(sigmoid(x)) = -softplus(-x)``.
        """
        return -softplus(-x)


DEFAULT_BIJECTION["polar"] = ShiftedSoftplusTransform()


class Polar(gpx.kernels.AbstractKernel):
Expand Down Expand Up @@ -307,7 +328,7 @@ def __call__(

# %%
posterior_rv = opt_posterior.likelihood(opt_posterior.predict(angles, train_data=D))
mu = posterior_rv.mean()
mu = posterior_rv.mean
# Standard deviation as the square root of the marginal variance, the same way
# the other examples derive it from a predictive distribution.
one_sigma = jnp.sqrt(posterior_rv.variance)

# %%
Expand Down
6 changes: 3 additions & 3 deletions examples/deep_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.6
# jupytext_version: 1.16.7
# kernelspec:
# display_name: gpjax
# language: python
Expand Down Expand Up @@ -238,8 +238,8 @@ def __call__(self, x: jax.Array) -> jax.Array:
latent_dist = opt_posterior(xtest, train_data=D)
predictive_dist = opt_posterior.likelihood(latent_dist)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
predictive_mean = predictive_dist.mean
predictive_std = jnp.sqrt(predictive_dist.variance)

fig, ax = plt.subplots()
ax.plot(x, y, "o", label="Observations", color=cols[0])
Expand Down
Loading