Skip to content

Commit c552ce5

Browse files
committed
use the extend module only for jax 0.4.32
1 parent 75906c2 commit c552ce5

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/test_healpix_ffts.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import healpy as hp
44
import pytest
55
import jax
6-
from jax.extend.backend import get_backend
6+
if jax.__version__ < "0.4.32":
7+
from jax.lib.xla_bridge import get_backend
8+
else:
9+
from jax.extend.backend import get_backend
710
gpu_available = get_backend().platform == "gpu"
811

912
jax.config.update("jax_enable_x64", True)

0 commit comments

Comments
 (0)