|
4 | 4 | import pytest
|
5 | 5 | import jax
|
6 | 6 | from packaging.version import Version as _Version
|
| 7 | + |
7 | 8 | if _Version(jax.__version__) < _Version("0.4.32"):
|
8 |
| - from jax.lib.xla_bridge import get_backend |
| 9 | + from jax.lib.xla_bridge import get_backend |
9 | 10 | else:
|
10 |
| - from jax.extend.backend import get_backend |
| 11 | + from jax.extend.backend import get_backend |
11 | 12 | gpu_available = get_backend().platform == "gpu"
|
12 | 13 |
|
13 | 14 | jax.config.update("jax_enable_x64", True)
|
|
28 | 29 | @pytest.mark.parametrize("nside", nside_to_test)
|
29 | 30 | @pytest.mark.parametrize("reality", reality_to_test)
|
30 | 31 | def test_healpix_fft_jax_numpy_consistency(flm_generator, nside, reality):
|
31 |
| - L = 2 * nside |
32 |
| - # Generate a random bandlimited signal |
33 |
| - flm = flm_generator(L=L, reality=reality) |
34 |
| - flm_hp = samples.flm_2d_to_hp(flm, L) |
35 |
| - f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) |
36 |
| - # Test consistency |
37 |
| - assert np.allclose( |
38 |
| - healpix_fft_numpy(f, L, nside, reality), |
39 |
| - healpix_fft_jax(f, L, nside, reality)) |
| 32 | + L = 2 * nside |
| 33 | + # Generate a random bandlimited signal |
| 34 | + flm = flm_generator(L=L, reality=reality) |
| 35 | + flm_hp = samples.flm_2d_to_hp(flm, L) |
| 36 | + f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) |
| 37 | + # Test consistency |
| 38 | + assert np.allclose( |
| 39 | + healpix_fft_numpy(f, L, nside, reality), healpix_fft_jax(f, L, nside, reality) |
| 40 | + ) |
40 | 41 |
|
41 | 42 |
|
42 | 43 | @pytest.mark.parametrize("nside", nside_to_test)
|
43 | 44 | @pytest.mark.parametrize("reality", reality_to_test)
|
44 | 45 | def test_healpix_ifft_jax_numpy_consistency(flm_generator, nside, reality):
|
45 |
| - L = 2 * nside |
46 |
| - # Generate a random bandlimited signal |
47 |
| - flm = flm_generator(L=L, reality=reality) |
48 |
| - flm_hp = samples.flm_2d_to_hp(flm, L) |
49 |
| - f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) |
50 |
| - ftm = healpix_fft_numpy(f, L, nside, reality) |
51 |
| - ftm_copy = np.copy(ftm) |
52 |
| - # Test consistency |
53 |
| - assert np.allclose( |
54 |
| - healpix_ifft_numpy(ftm, L, nside, reality), |
55 |
| - healpix_ifft_jax(ftm_copy, L, nside, reality), |
56 |
| - ) |
| 46 | + L = 2 * nside |
| 47 | + # Generate a random bandlimited signal |
| 48 | + flm = flm_generator(L=L, reality=reality) |
| 49 | + flm_hp = samples.flm_2d_to_hp(flm, L) |
| 50 | + f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) |
| 51 | + ftm = healpix_fft_numpy(f, L, nside, reality) |
| 52 | + ftm_copy = np.copy(ftm) |
| 53 | + # Test consistency |
| 54 | + assert np.allclose( |
| 55 | + healpix_ifft_numpy(ftm, L, nside, reality), |
| 56 | + healpix_ifft_jax(ftm_copy, L, nside, reality), |
| 57 | + ) |
| 58 | + |
57 | 59 |
|
58 | 60 | @pytest.mark.skipif(not gpu_available, reason="GPU not available")
|
59 | 61 | @pytest.mark.parametrize("nside", nside_to_test)
|
60 | 62 | def test_healpix_fft_cuda(flm_generator, nside):
|
61 |
| - L = 2 * nside |
62 |
| - # Generate a random bandlimited signal |
63 |
| - flm = flm_generator(L=L, reality=False) |
64 |
| - flm_hp = samples.flm_2d_to_hp(flm, L) |
65 |
| - f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) |
66 |
| - # Test consistency |
67 |
| - assert_allclose( |
68 |
| - healpix_fft_jax(f, L, nside, False), |
69 |
| - healpix_fft_cuda(f, L, nside, False), |
70 |
| - atol=1e-7 , |
71 |
| - rtol=1e-7) |
| 63 | + L = 2 * nside |
| 64 | + # Generate a random bandlimited signal |
| 65 | + flm = flm_generator(L=L, reality=False) |
| 66 | + flm_hp = samples.flm_2d_to_hp(flm, L) |
| 67 | + f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) |
| 68 | + # Test consistency |
| 69 | + assert_allclose( |
| 70 | + healpix_fft_jax(f, L, nside, False), |
| 71 | + healpix_fft_cuda(f, L, nside, False), |
| 72 | + atol=1e-7, |
| 73 | + rtol=1e-7, |
| 74 | + ) |
| 75 | + |
72 | 76 |
|
73 | 77 | @pytest.mark.skipif(not gpu_available, reason="GPU not available")
|
74 | 78 | @pytest.mark.parametrize("nside", nside_to_test)
|
75 | 79 | def test_healpix_ifft_cuda(flm_generator, nside):
|
76 |
| - L = 2 * nside |
77 |
| - # Generate a random bandlimited signal |
78 |
| - flm = flm_generator(L=L, reality=False) |
79 |
| - flm_hp = samples.flm_2d_to_hp(flm, L) |
80 |
| - f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) |
81 |
| - ftm = healpix_fft_jax(f, L, nside, False) |
82 |
| - # Test consistency |
83 |
| - assert_allclose( |
84 |
| - healpix_ifft_jax(ftm, L, nside, False).flatten(), |
85 |
| - healpix_ifft_cuda(ftm, L, nside, False).flatten(), |
86 |
| - atol=1e-7, |
87 |
| - rtol=1e-7) |
88 |
| - |
| 80 | + L = 2 * nside |
| 81 | + # Generate a random bandlimited signal |
| 82 | + flm = flm_generator(L=L, reality=False) |
| 83 | + flm_hp = samples.flm_2d_to_hp(flm, L) |
| 84 | + f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) |
| 85 | + ftm = healpix_fft_jax(f, L, nside, False) |
| 86 | + # Test consistency |
| 87 | + assert_allclose( |
| 88 | + healpix_ifft_jax(ftm, L, nside, False).flatten(), |
| 89 | + healpix_ifft_cuda(ftm, L, nside, False).flatten(), |
| 90 | + atol=1e-7, |
| 91 | + rtol=1e-7, |
| 92 | + ) |
0 commit comments