Skip to content

Commit 7ae0adf

Browse files
authored
Merge pull request #429 from stephen-huan/kernel-diagonal
feat(gpjax/kernels/base.py): add diagonal
2 parents b69be96 + e697365 commit 7ae0adf

File tree

5 files changed

+101
-4
lines changed

5 files changed

+101
-4
lines changed

gpjax/kernels/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]):
6363
def gram(self, x: Num[Array, "N D"]):
6464
return self.compute_engine.gram(self, x)
6565

66+
def diagonal(self, x: Num[Array, "N D"]):
67+
return self.compute_engine.diagonal(self, x)
68+
6669
def slice_input(self, x: Float[Array, "... D"]) -> Float[Array, "... Q"]:
6770
r"""Slice out the relevant columns of the input matrix.
6871

gpjax/kernels/computations/basis_functions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from cola import PSD
1313
from cola.ops import (
1414
Dense,
15+
Diagonal,
1516
LinearOperator,
1617
)
1718

@@ -58,6 +59,20 @@ def gram(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> LinearOperator:
5859
z1 = self.compute_features(kernel, inputs)
5960
return PSD(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))
6061

62+
def diagonal(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> Diagonal:
63+
r"""For a given kernel, compute the elementwise diagonal of the
64+
NxN gram matrix on an input matrix of shape NxD.
65+
66+
Args:
67+
kernel (AbstractKernel): the kernel function.
68+
inputs (Float[Array, "N D"]): The input matrix.
69+
70+
Returns
71+
-------
72+
Diagonal: The computed diagonal variance entries.
73+
"""
74+
return super().diagonal(kernel.base_kernel, inputs)
75+
6176
def compute_features(
6277
self, kernel: Kernel, x: Float[Array, "N D"]
6378
) -> Float[Array, "N L"]:

tests/test_kernels/test_approximations.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from typing import Tuple
22

3-
from cola.ops import Dense
3+
from cola.ops import (
4+
Dense,
5+
Diagonal,
6+
)
47
import jax
58
from jax import config
69
import jax.numpy as jnp
@@ -63,6 +66,32 @@ def test_gram(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: i
6366
assert jnp.all(evals > 0)
6467

6568

69+
@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52])
70+
@pytest.mark.parametrize("num_basis_fns", [2, 10, 20])
71+
@pytest.mark.parametrize("n_dims", [1, 2, 5])
72+
@pytest.mark.parametrize("n_data", [50, 100])
73+
def test_diagonal(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: int):
74+
key = jr.key(123)
75+
x = jr.uniform(key, shape=(n_data, 1), minval=-3.0, maxval=3.0).reshape(-1, 1)
76+
if n_dims > 1:
77+
x = jnp.hstack([x] * n_dims)
78+
base_kernel = kernel(active_dims=list(range(n_dims)))
79+
approximate = RFF(base_kernel=base_kernel, num_basis_fns=num_basis_fns)
80+
81+
linop = approximate.diagonal(x)
82+
83+
# Check the return type
84+
assert isinstance(linop, Diagonal)
85+
86+
Kxx = linop.diag + _jitter
87+
88+
# Check that the shape is correct
89+
assert Kxx.shape == (n_data,)
90+
91+
# Check that the diagonal is positive
92+
assert jnp.all(Kxx > 0)
93+
94+
6695
@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52])
6796
@pytest.mark.parametrize("num_basis_fns", [2, 10, 20])
6897
@pytest.mark.parametrize("n_dims", [1, 2, 5])

tests/test_kernels/test_nonstationary.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
from itertools import product
1717
from typing import List
1818

19-
from cola.ops import LinearOperator
19+
from cola.ops import (
20+
Diagonal,
21+
LinearOperator,
22+
)
2023
import jax
2124
from jax import config
2225
import jax.numpy as jnp
@@ -125,9 +128,28 @@ def test_gram(self, dim: int, n: int) -> None:
125128

126129
# Test gram matrix
127130
Kxx = kernel.gram(x)
131+
Kxx_cross = kernel.cross_covariance(x, x)
128132
assert isinstance(Kxx, LinearOperator)
129133
assert Kxx.shape == (n, n)
130134
assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0)
135+
assert jnp.allclose(Kxx_cross, Kxx.to_dense())
136+
137+
@pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}")
138+
@pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}")
139+
def test_diagonal(self, dim: int, n: int) -> None:
140+
# Initialise kernel
141+
kernel: AbstractKernel = self.kernel()
142+
143+
# Inputs
144+
x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)
145+
146+
# Test diagonal
147+
Kxx = kernel.diagonal(x)
148+
Kxx_gram = jnp.diagonal(kernel.gram(x).to_dense())
149+
assert isinstance(Kxx, Diagonal)
150+
assert Kxx.shape == (n, n)
151+
assert jnp.all(Kxx.diag + 1e-6 > 0.0)
152+
assert jnp.allclose(Kxx_gram, Kxx.diag)
131153

132154
@pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}")
133155
@pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}")
@@ -139,11 +161,14 @@ def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None:
139161
# Inputs
140162
a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim)
141163
b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim)
164+
c = jnp.vstack((a, b))
142165

143166
# Test cross-covariance
144167
Kab = kernel.cross_covariance(a, b)
168+
Kab_gram = kernel.gram(c).to_dense()[:n_a, n_a:]
145169
assert isinstance(Kab, jnp.ndarray)
146170
assert Kab.shape == (n_a, n_b)
171+
assert jnp.allclose(Kab, Kab_gram)
147172

148173

149174
def prod(inp):
@@ -216,4 +241,4 @@ def test_values_by_monte_carlo_in_special_case(self, order: int) -> None:
216241
integrands = H_a * H_b * (weights_a**order) * (weights_b**order)
217242
Kab_approx = 2.0 * jnp.mean(integrands)
218243

219-
assert jnp.max(Kab_approx - Kab_exact) < 1e-4
244+
assert jnp.max(jnp.abs(Kab_approx - Kab_exact)) < 1e-4

tests/test_kernels/test_stationary.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from dataclasses import is_dataclass
1818
from itertools import product
1919

20-
from cola.ops import LinearOperator
20+
from cola.ops import (
21+
Diagonal,
22+
LinearOperator,
23+
)
2124
import jax
2225
from jax import config
2326
import jax.numpy as jnp
@@ -129,9 +132,28 @@ def test_gram(self, dim: int, n: int) -> None:
129132

130133
# Test gram matrix
131134
Kxx = kernel.gram(x)
135+
Kxx_cross = kernel.cross_covariance(x, x)
132136
assert isinstance(Kxx, LinearOperator)
133137
assert Kxx.shape == (n, n)
134138
assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0)
139+
assert jnp.allclose(Kxx_cross, Kxx.to_dense())
140+
141+
@pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}")
142+
@pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}")
143+
def test_diagonal(self, dim: int, n: int) -> None:
144+
# Initialise kernel
145+
kernel: AbstractKernel = self.kernel()
146+
147+
# Inputs
148+
x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)
149+
150+
# Test diagonal
151+
Kxx = kernel.diagonal(x)
152+
Kxx_gram = jnp.diagonal(kernel.gram(x).to_dense())
153+
assert isinstance(Kxx, Diagonal)
154+
assert Kxx.shape == (n, n)
155+
assert jnp.all(Kxx.diag + 1e-6 > 0.0)
156+
assert jnp.allclose(Kxx_gram, Kxx.diag)
135157

136158
@pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}")
137159
@pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}")
@@ -143,11 +165,14 @@ def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None:
143165
# Inputs
144166
a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim)
145167
b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim)
168+
c = jnp.vstack((a, b))
146169

147170
# Test cross-covariance
148171
Kab = kernel.cross_covariance(a, b)
172+
Kab_gram = kernel.gram(c).to_dense()[:n_a, n_a:]
149173
assert isinstance(Kab, jnp.ndarray)
150174
assert Kab.shape == (n_a, n_b)
175+
assert jnp.allclose(Kab, Kab_gram)
151176

152177
def test_spectral_density(self):
153178
# Initialise kernel

0 commit comments

Comments
 (0)