Skip to content

Commit 60a8114

Browse files
committed
Restore indentation in tests to 4 spaces
1 parent 3abb71b commit 60a8114

File tree

1 file changed

+51
-47
lines changed

1 file changed

+51
-47
lines changed

tests/test_healpix_ffts.py

Lines changed: 51 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import pytest
55
import jax
66
from packaging.version import Version as _Version
7+
78
if _Version(jax.__version__) < _Version("0.4.32"):
8-
from jax.lib.xla_bridge import get_backend
9+
from jax.lib.xla_bridge import get_backend
910
else:
10-
from jax.extend.backend import get_backend
11+
from jax.extend.backend import get_backend
1112
gpu_available = get_backend().platform == "gpu"
1213

1314
jax.config.update("jax_enable_x64", True)
@@ -28,61 +29,64 @@
2829
@pytest.mark.parametrize("nside", nside_to_test)
2930
@pytest.mark.parametrize("reality", reality_to_test)
3031
def test_healpix_fft_jax_numpy_consistency(flm_generator, nside, reality):
31-
L = 2 * nside
32-
# Generate a random bandlimited signal
33-
flm = flm_generator(L=L, reality=reality)
34-
flm_hp = samples.flm_2d_to_hp(flm, L)
35-
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
36-
# Test consistency
37-
assert np.allclose(
38-
healpix_fft_numpy(f, L, nside, reality),
39-
healpix_fft_jax(f, L, nside, reality))
32+
L = 2 * nside
33+
# Generate a random bandlimited signal
34+
flm = flm_generator(L=L, reality=reality)
35+
flm_hp = samples.flm_2d_to_hp(flm, L)
36+
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
37+
# Test consistency
38+
assert np.allclose(
39+
healpix_fft_numpy(f, L, nside, reality), healpix_fft_jax(f, L, nside, reality)
40+
)
4041

4142

4243
@pytest.mark.parametrize("nside", nside_to_test)
4344
@pytest.mark.parametrize("reality", reality_to_test)
4445
def test_healpix_ifft_jax_numpy_consistency(flm_generator, nside, reality):
45-
L = 2 * nside
46-
# Generate a random bandlimited signal
47-
flm = flm_generator(L=L, reality=reality)
48-
flm_hp = samples.flm_2d_to_hp(flm, L)
49-
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
50-
ftm = healpix_fft_numpy(f, L, nside, reality)
51-
ftm_copy = np.copy(ftm)
52-
# Test consistency
53-
assert np.allclose(
54-
healpix_ifft_numpy(ftm, L, nside, reality),
55-
healpix_ifft_jax(ftm_copy, L, nside, reality),
56-
)
46+
L = 2 * nside
47+
# Generate a random bandlimited signal
48+
flm = flm_generator(L=L, reality=reality)
49+
flm_hp = samples.flm_2d_to_hp(flm, L)
50+
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
51+
ftm = healpix_fft_numpy(f, L, nside, reality)
52+
ftm_copy = np.copy(ftm)
53+
# Test consistency
54+
assert np.allclose(
55+
healpix_ifft_numpy(ftm, L, nside, reality),
56+
healpix_ifft_jax(ftm_copy, L, nside, reality),
57+
)
58+
5759

5860
@pytest.mark.skipif(not gpu_available, reason="GPU not available")
5961
@pytest.mark.parametrize("nside", nside_to_test)
6062
def test_healpix_fft_cuda(flm_generator, nside):
61-
L = 2 * nside
62-
# Generate a random bandlimited signal
63-
flm = flm_generator(L=L, reality=False)
64-
flm_hp = samples.flm_2d_to_hp(flm, L)
65-
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
66-
# Test consistency
67-
assert_allclose(
68-
healpix_fft_jax(f, L, nside, False),
69-
healpix_fft_cuda(f, L, nside, False),
70-
atol=1e-7 ,
71-
rtol=1e-7)
63+
L = 2 * nside
64+
# Generate a random bandlimited signal
65+
flm = flm_generator(L=L, reality=False)
66+
flm_hp = samples.flm_2d_to_hp(flm, L)
67+
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
68+
# Test consistency
69+
assert_allclose(
70+
healpix_fft_jax(f, L, nside, False),
71+
healpix_fft_cuda(f, L, nside, False),
72+
atol=1e-7,
73+
rtol=1e-7,
74+
)
75+
7276

7377
@pytest.mark.skipif(not gpu_available, reason="GPU not available")
7478
@pytest.mark.parametrize("nside", nside_to_test)
7579
def test_healpix_ifft_cuda(flm_generator, nside):
76-
L = 2 * nside
77-
# Generate a random bandlimited signal
78-
flm = flm_generator(L=L, reality=False)
79-
flm_hp = samples.flm_2d_to_hp(flm, L)
80-
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
81-
ftm = healpix_fft_jax(f, L, nside, False)
82-
# Test consistency
83-
assert_allclose(
84-
healpix_ifft_jax(ftm, L, nside, False).flatten(),
85-
healpix_ifft_cuda(ftm, L, nside, False).flatten(),
86-
atol=1e-7,
87-
rtol=1e-7)
88-
80+
L = 2 * nside
81+
# Generate a random bandlimited signal
82+
flm = flm_generator(L=L, reality=False)
83+
flm_hp = samples.flm_2d_to_hp(flm, L)
84+
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
85+
ftm = healpix_fft_jax(f, L, nside, False)
86+
# Test consistency
87+
assert_allclose(
88+
healpix_ifft_jax(ftm, L, nside, False).flatten(),
89+
healpix_ifft_cuda(ftm, L, nside, False).flatten(),
90+
atol=1e-7,
91+
rtol=1e-7,
92+
)

0 commit comments

Comments
 (0)