Skip to content

Commit 75906c2

Browse files
committed
skip cuda test if gpu not available
1 parent 2e4fc32 commit 75906c2

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/test_healpix_ffts.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import healpy as hp
44
import pytest
55
import jax
6+
from jax.extend.backend import get_backend
7+
gpu_available = get_backend().platform == "gpu"
68

79
jax.config.update("jax_enable_x64", True)
810
from s2fft.sampling import s2_samples as samples
@@ -49,7 +51,7 @@ def test_healpix_ifft_jax_numpy_consistency(flm_generator, nside, reality):
4951
healpix_ifft_jax(ftm_copy, L, nside, reality),
5052
)
5153

52-
54+
@pytest.mark.skipif(not gpu_available, reason="GPU not available")
5355
@pytest.mark.parametrize("nside", nside_to_test)
5456
def test_healpix_fft_cuda(flm_generator, nside):
5557
L = 2 * nside
@@ -64,7 +66,7 @@ def test_healpix_fft_cuda(flm_generator, nside):
6466
atol=1,
6567
rtol=1)
6668

67-
69+
@pytest.mark.skipif(not gpu_available, reason="GPU not available")
6870
@pytest.mark.parametrize("nside", nside_to_test)
6971
def test_healpix_ifft_cuda(flm_generator, nside):
7072
L = 2 * nside

0 commit comments

Comments
 (0)