Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/constructing_new_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@

meanf = gpx.mean_functions.Zero()

for k, ax, c in zip(kernels, axes.ravel(), cols):
for k, ax, c in zip(kernels, axes.ravel(), cols, strict=False):
prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
rv = prior(x)
y = rv.sample(seed=key, sample_shape=(10,))
Expand Down
2 changes: 1 addition & 1 deletion examples/intro_to_gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@

cmap = mpl.colors.LinearSegmentedColormap.from_list("custom", ["white", cols[1]], N=256)

for a, t, d in zip([ax0, ax1, ax2], titles, dists):
for a, t, d in zip([ax0, ax1, ax2], titles, dists, strict=False):
d_prob = d.prob(jnp.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])).reshape(
xx.shape
)
Expand Down
7 changes: 2 additions & 5 deletions examples/intro_to_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@
from sklearn.preprocessing import StandardScaler

from examples.utils import use_mpl_style
from gpjax.parameters import (
PositiveReal,
Static,
)
from gpjax.parameters import Static
from gpjax.typing import Array

config.update("jax_enable_x64", True)
Expand Down Expand Up @@ -204,7 +201,7 @@

meanf = gpx.mean_functions.Zero()

for k, ax in zip(kernels, axes.ravel()):
for k, ax in zip(kernels, axes.ravel(), strict=False):
prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
rv = prior(x)
y = rv.sample(seed=key, sample_shape=(10,))
Expand Down
2 changes: 1 addition & 1 deletion examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def confidence_ellipse(x, y, ax, n_std=3.0, facecolor="none", **kwargs):

def clean_legend(ax):
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
by_label = dict(zip(labels, handles, strict=False))
ax.legend(by_label.values(), by_label.keys())
return ax

Expand Down
4 changes: 3 additions & 1 deletion gpjax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ def affine_transformation(x):

return vmap(affine_transformation)(Z)

def sample(self, seed: KeyArray, sample_shape: Tuple[int, ...]): # pylint: disable=useless-super-delegation
def sample(
self, seed: KeyArray, sample_shape: Tuple[int, ...]
): # pylint: disable=useless-super-delegation
r"""See `Distribution.sample`."""
return self._sample_n(
seed, sample_shape[0]
Expand Down
5 changes: 3 additions & 2 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from dataclasses import is_dataclass

try:
import beartype
from beartype.roar import BeartypeCallHintParamViolation
from jaxtyping import TypeCheckError

ValidationErrors = (ValueError, beartype.roar.BeartypeCallHintParamViolation)
ValidationErrors = (TypeError, BeartypeCallHintParamViolation, TypeCheckError)
except ImportError:
ValidationErrors = ValueError

Expand Down
3 changes: 2 additions & 1 deletion tests/test_decision_making/test_search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from jaxtyping import (
Array,
Float,
TypeCheckError,
)
import pytest

Expand Down Expand Up @@ -64,7 +65,7 @@ def test_continuous_search_space_dtype_consistency(
def test_continous_search_space_bounds_shape_consistency(
lower_bounds: Float[Array, " D1"], upper_bounds: Float[Array, " D2"]
):
with pytest.raises((BeartypeCallHintParamViolation, ValueError)):
with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError, ValueError)):
ContinuousSearchSpace(lower_bounds=lower_bounds, upper_bounds=upper_bounds)


Expand Down
3 changes: 2 additions & 1 deletion tests/test_gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@

try:
from beartype.roar import BeartypeCallHintParamViolation
from jaxtyping import TypeCheckError

ValidationErrors = (TypeError, BeartypeCallHintParamViolation)
ValidationErrors = (TypeError, BeartypeCallHintParamViolation, TypeCheckError)
except ImportError:
ValidationErrors = TypeError

Expand Down
2 changes: 1 addition & 1 deletion tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_transform(param, value):

# Test inverse transformation
it_params = transform(t_params, DEFAULT_BIJECTION, inverse=True)
assert it_params == params
assert repr(it_params) == repr(params)


@pytest.mark.parametrize(
Expand Down