From 41efa9063be366b90ab8059df54af85010596ae1 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 23 May 2025 19:32:53 +0100 Subject: [PATCH 1/5] Factoring out spherical precompute torch autograd tests and marking slow --- tests/test_spherical_precompute.py | 391 +++++++++++++++-------------- 1 file changed, 208 insertions(+), 183 deletions(-) diff --git a/tests/test_spherical_precompute.py b/tests/test_spherical_precompute.py index 115a4b24..d6a8f526 100644 --- a/tests/test_spherical_precompute.py +++ b/tests/test_spherical_precompute.py @@ -6,7 +6,7 @@ from s2fft.base_transforms import spherical as base from s2fft.precompute_transforms import construct as c -from s2fft.precompute_transforms.spherical import forward, inverse +from s2fft.precompute_transforms.spherical import _kernel_functions, forward, inverse from s2fft.sampling import s2_samples as samples jax.config.update("jax_enable_x64", True) @@ -26,6 +26,68 @@ iter_to_test = [0, 1] +def get_flm_and_kernel( + flm_generator, + L, + spin, + sampling, + reality, + method, + recursion, + forward, + nside=None, +): + flm = flm_generator(L=L, spin=spin, reality=reality) + kfunc = _kernel_functions[method] + kernel = kfunc(L, spin, reality, sampling, nside, forward, recursion=recursion) + return flm, kernel + + +def get_tol(sampling): + return 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12 + + +def check_spin(recursion, spin): + if recursion.lower() == "price-mcewen" and abs(spin) >= PM_MAX_STABLE_SPIN: + pytest.skip( + f"price-mcewen recursion not accurate above |spin| = {PM_MAX_STABLE_SPIN}" + ) + + +def check_inverse_transform( + flm, kernel, L, spin, sampling, reality, method, nside=None +): + f = inverse( + torch.from_numpy(flm) if method == "torch" else flm, + L, + spin, + kernel, + sampling, + reality, + method, + nside, + ) + if method == "torch": + f = f.resolve_conj().numpy() + f_check = base.inverse(flm, L, spin, sampling, nside, reality) + tol = get_tol(sampling) + np.testing.assert_allclose(f, f_check, atol=tol, rtol=tol) + + +def check_forward_transform( + flm, kernel, L, spin, sampling, reality, method, nside=None +): + f = base.inverse(flm, L, spin, sampling, nside, reality) + flm_check = base.forward(f, L, spin, sampling, nside, reality) + if method == "torch": + f = torch.from_numpy(f) + flm_recov = forward(f, L, spin, kernel, sampling, reality, method, nside) + if method == "torch": + flm_recov = flm_recov.resolve_conj().numpy() + tol = get_tol(sampling) + np.testing.assert_allclose(flm_check, flm_recov, atol=tol, rtol=tol) + + @pytest.mark.parametrize("L", L_to_test) @pytest.mark.parametrize("spin", spin_to_test) @pytest.mark.parametrize("sampling", sampling_to_test) @@ -41,55 +103,34 @@ def test_transform_inverse( method: str, recursion: str, ): - if recursion.lower() == "price-mcewen" and abs(spin) >= PM_MAX_STABLE_SPIN: - pytest.skip( - f"price-mcewen recursion not accurate above |spin| = {PM_MAX_STABLE_SPIN}" - ) + check_spin(recursion, spin) + flm, kernel = get_flm_and_kernel( + flm_generator, L, spin, sampling, reality, method, recursion, forward=False + ) + check_inverse_transform(flm, kernel, L, spin, sampling, reality, method) - flm = flm_generator(L=L, spin=spin, reality=reality) - f_check = base.inverse(flm, L, spin, sampling, reality=reality) - kfunc = ( - c.spin_spherical_kernel_jax - if method.lower() == "jax" - else c.spin_spherical_kernel +@pytest.mark.slow +@pytest.mark.parametrize("L", L_to_test) +@pytest.mark.parametrize("spin", spin_to_test) +@pytest.mark.parametrize("sampling", 
sampling_to_test) +@pytest.mark.parametrize("reality", reality_to_test) +@pytest.mark.parametrize("recursion", recursions_to_test) +def test_transform_inverse_torch_gradcheck( + flm_generator, + L: int, + spin: int, + sampling: str, + reality: bool, + recursion: str, +): + method = "torch" + flm, kernel = get_flm_and_kernel( + flm_generator, L, spin, sampling, reality, method, recursion, forward=False ) - kernel = kfunc(L, spin, reality, sampling, forward=False, recursion=recursion) - - tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12 - if method.lower() == "torch": - f = inverse( - torch.from_numpy(flm), - L, - spin, - torch.from_numpy(kernel), - sampling, - reality, - method, - ) - # Test Transform - np.testing.assert_allclose( - f.resolve_conj().numpy(), f_check, atol=tol, rtol=tol - ) - # Test Gradients - flm_grad_test = torch.from_numpy(flm) - flm_grad_test.requires_grad = True - assert torch.autograd.gradcheck( - inverse, - ( - flm_grad_test, - L, - spin, - torch.from_numpy(kernel), - sampling, - reality, - method, - ), - ) - - else: - f = inverse(flm, L, spin, kernel, sampling, reality, method) - np.testing.assert_allclose(f, f_check, atol=tol, rtol=tol) + flm = torch.from_numpy(flm) + flm.requires_grad = True + torch.autograd.gradcheck(inverse, (flm, L, spin, kernel, sampling, reality, method)) @pytest.mark.parametrize("nside", nside_to_test) @@ -106,51 +147,54 @@ def test_transform_inverse_healpix( recursion: str, ): sampling = "healpix" + spin = 0 L = ratio * nside - flm = flm_generator(L=L, reality=reality) - f_check = base.inverse(flm, L, 0, sampling, nside, reality) + flm, kernel = get_flm_and_kernel( + flm_generator, + L, + spin, + sampling, + reality, + method, + recursion, + forward=False, + nside=nside, + ) + check_inverse_transform(flm, kernel, L, spin, sampling, reality, method, nside) + - kfunc = ( - c.spin_spherical_kernel_jax - if method.lower() == "jax" - else c.spin_spherical_kernel +@pytest.mark.slow +@pytest.mark.parametrize("nside", nside_to_test) +@pytest.mark.parametrize("ratio", L_to_nside_ratio) +@pytest.mark.parametrize("reality", reality_to_test) +@pytest.mark.parametrize("recursion", recursions_to_test) +def test_transform_inverse_healpix_torch_gradcheck( + flm_generator, + nside: int, + ratio: int, + reality: bool, + recursion: str, +): + method = "torch" + sampling = "healpix" + spin = 0 + L = ratio * nside + flm, kernel = get_flm_and_kernel( + flm_generator, + L, + spin, + sampling, + reality, + method, + recursion, + forward=False, + nside=nside, + ) + flm = torch.from_numpy(flm) + flm.requires_grad = True + torch.autograd.gradcheck( + inverse, (flm, L, spin, kernel, sampling, reality, method, nside) ) - kernel = kfunc(L, 0, reality, sampling, nside, False, recursion=recursion) - - tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12 - if method.lower() == "torch": - # Test Transform - f = inverse( - torch.from_numpy(flm), - L, - 0, - torch.from_numpy(kernel), - sampling, - reality, - method, - nside, - ) - np.testing.assert_allclose(f, f_check, atol=tol, rtol=tol) - - # Test Gradients - flm_grad_test = torch.from_numpy(flm) - flm_grad_test.requires_grad = True - assert torch.autograd.gradcheck( - inverse, - ( - flm_grad_test, - L, - 0, - torch.from_numpy(kernel), - sampling, - reality, - method, - nside, - ), - ) - else: - f = inverse(flm, L, 0, kernel, sampling, reality, method, nside) - np.testing.assert_allclose(f, f_check, atol=tol, rtol=tol) @pytest.mark.parametrize("L", L_to_test) @@ -168,56 +212,34 @@ def test_transform_forward( 
method: str, recursion: str, ): - if recursion.lower() == "price-mcewen" and abs(spin) >= PM_MAX_STABLE_SPIN: - pytest.skip( - f"price-mcewen recursion not accurate above |spin| = {PM_MAX_STABLE_SPIN}" - ) - - flm = flm_generator(L=L, spin=spin, reality=reality) + check_spin(recursion, spin) + flm, kernel = get_flm_and_kernel( + flm_generator, L, spin, sampling, reality, method, recursion, forward=True + ) + check_forward_transform(flm, kernel, L, spin, sampling, reality, method) - f = base.inverse(flm, L, spin, sampling, reality=reality) - flm_check = base.forward(f, L, spin, sampling, reality=reality) - kfunc = ( - c.spin_spherical_kernel_jax - if method.lower() == "jax" - else c.spin_spherical_kernel +@pytest.mark.slow +@pytest.mark.parametrize("L", L_to_test) +@pytest.mark.parametrize("spin", spin_to_test) +@pytest.mark.parametrize("sampling", sampling_to_test) +@pytest.mark.parametrize("reality", reality_to_test) +@pytest.mark.parametrize("recursion", recursions_to_test) +def test_transform_forward_torch_gradcheck( + flm_generator, + L: int, + spin: int, + sampling: str, + reality: bool, + recursion: str, +): + method = "torch" + flm, kernel = get_flm_and_kernel( + flm_generator, L, spin, sampling, reality, method, recursion, forward=True ) - kernel = kfunc(L, spin, reality, sampling, forward=True, recursion=recursion) - - tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12 - if method.lower() == "torch": - # Test Transform - flm_recov = forward( - torch.from_numpy(f), - L, - spin, - torch.from_numpy(kernel), - sampling, - reality, - method, - ) - - np.testing.assert_allclose(flm_check, flm_recov, atol=tol, rtol=tol) - - # Test Gradients - f_grad_test = torch.from_numpy(f) - f_grad_test.requires_grad = True - assert torch.autograd.gradcheck( - forward, - ( - f_grad_test, - L, - spin, - torch.from_numpy(kernel), - sampling, - reality, - method, - ), - ) - else: - flm_recov = forward(f, L, spin, kernel, sampling, reality, method) - np.testing.assert_allclose(flm_check, flm_recov, atol=tol, rtol=tol) + f = torch.from_numpy(base.inverse(flm, L, spin, sampling, reality=reality)) + f.requires_grad = True + torch.autograd.gradcheck(forward, (f, L, spin, kernel, sampling, reality, method)) @pytest.mark.parametrize("nside", nside_to_test) @@ -236,54 +258,57 @@ def test_transform_forward_healpix( iter: int, ): sampling = "healpix" + spin = 0 L = ratio * nside - flm = flm_generator(L=L, reality=True) - f = base.inverse(flm, L, 0, sampling, nside, reality) - flm_check = base.forward(f, L, 0, sampling, nside, reality, iter=iter) - - kfunc = ( - c.spin_spherical_kernel_jax - if method.lower() == "jax" - else c.spin_spherical_kernel + check_spin(recursion, spin) + flm, kernel = get_flm_and_kernel( + flm_generator, + L, + spin, + sampling, + reality, + method, + recursion, + forward=True, + nside=nside, + ) + check_forward_transform(flm, kernel, L, spin, sampling, reality, method, nside) + + +@pytest.mark.slow +@pytest.mark.parametrize("nside", nside_to_test) +@pytest.mark.parametrize("ratio", L_to_nside_ratio) +@pytest.mark.parametrize("reality", reality_to_test) +@pytest.mark.parametrize("recursion", recursions_to_test) +@pytest.mark.parametrize("iter", iter_to_test) +def test_transform_forward_healpix_torch_gradcheck( + flm_generator, + nside: int, + ratio: int, + reality: bool, + recursion: str, + iter: int, +): + method = "torch" + sampling = "healpix" + spin = 0 + L = ratio * nside + flm, kernel = get_flm_and_kernel( + flm_generator, + L, + spin, + sampling, + reality, + method, + 
recursion, + forward=True, + nside=nside, + ) + f = torch.from_numpy(base.inverse(flm, L, spin, sampling, nside, reality)) + f.requires_grad = True + torch.autograd.gradcheck( + forward, (f, L, spin, kernel, sampling, reality, method, nside) ) - kernel = kfunc(L, 0, reality, sampling, nside, True, recursion=recursion) - - tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12 - if method.lower() == "torch": - # Test Transform - flm_recov = forward( - torch.from_numpy(f), - L, - 0, - torch.from_numpy(kernel), - sampling, - reality, - method, - nside, - iter, - ) - np.testing.assert_allclose(flm_recov, flm_check, atol=tol, rtol=tol) - - # Test Gradients - f_grad_test = torch.from_numpy(f) - f_grad_test.requires_grad = True - assert torch.autograd.gradcheck( - forward, - ( - f_grad_test, - L, - 0, - torch.from_numpy(kernel), - sampling, - reality, - method, - nside, - iter, - ), - ) - else: - flm_recov = forward(f, L, 0, kernel, sampling, reality, method, nside, iter) - np.testing.assert_allclose(flm_recov, flm_check, atol=tol, rtol=tol) @pytest.mark.parametrize("spin", [0, 20, 30, -20, -30]) @@ -330,7 +355,7 @@ def test_transform_forward_high_spin( kernel = c.spin_spherical_kernel(L, spin, reality, sampling, forward=True) flm_recov = forward(f, L, spin, kernel, sampling, reality, "numpy") - tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12 + tol = get_tol(sampling) np.testing.assert_allclose(flm_recov, flm, atol=tol, rtol=tol) From 44b5847027d10d991d4026e6656efc9c8e38353c Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 23 May 2025 19:33:17 +0100 Subject: [PATCH 2/5] Factoring out Wigner precompute torch autograd tests and marking slow --- tests/test_wigner_precompute.py | 349 +++++++++++++++++--------------- 1 file changed, 186 insertions(+), 163 deletions(-) diff --git a/tests/test_wigner_precompute.py b/tests/test_wigner_precompute.py index 28ec66ba..eb6c910c 100644 --- a/tests/test_wigner_precompute.py +++ b/tests/test_wigner_precompute.py @@ -6,7 +6,7 @@ from s2fft.base_transforms import wigner as base from s2fft.precompute_transforms import construct as c -from s2fft.precompute_transforms.wigner import forward, inverse +from s2fft.precompute_transforms.wigner import _kernel_functions, forward, inverse from s2fft.sampling import so3_samples as samples jax.config.update("jax_enable_x64", True) @@ -21,6 +21,57 @@ modes_to_test = ["auto", "fft", "direct"] +def check_mode_and_sampling(mode, sampling): + if mode.lower() == "fft" and sampling.lower() not in ["mw", "mwss", "dh"]: + pytest.skip( + f"Fourier based Wigner computation not valid for sampling={sampling}" + ) + + +def get_flmn_and_kernel( + flmn_generator, L, N, sampling, reality, method, mode, forward, nside=None +): + flmn = flmn_generator(L=L, N=N, reality=reality) + kfunc = _kernel_functions[method] + kernel = kfunc(L, N, reality, sampling, nside, forward=forward, mode=mode) + return flmn, kernel + + +def check_inverse_transform(flmn, kernel, L, N, sampling, reality, method, nside=None): + f = inverse( + torch.from_numpy(flmn) if method == "torch" else flmn, + L, + N, + kernel, + sampling, + reality, + method, + nside, + ) + if method == "torch": + f = f.resolve_conj().numpy() + f_check = base.inverse(flmn, L, N, 0, sampling, reality, nside) + np.testing.assert_allclose(f, f_check, atol=1e-5, rtol=1e-5) + + +def check_forward_tranform(flmn, kernel, L, N, sampling, reality, method, nside=None): + f = base.inverse(flmn, L, N, sampling=sampling, reality=reality, nside=nside) + flmn_check = base.forward(f, L, N, 
sampling=sampling, reality=reality, nside=nside) + flmn = forward( + torch.from_numpy(f) if method == "torch" else f, + L, + N, + kernel, + sampling, + reality, + method, + nside, + ) + if method == "torch": + flmn = flmn.resolve_conj().numpy() + np.testing.assert_allclose(flmn, flmn_check, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("L", L_to_test) @pytest.mark.parametrize("N", N_to_test) @pytest.mark.parametrize("sampling", sampling_schemes) @@ -36,50 +87,37 @@ def test_inverse_wigner_transform( method: str, mode: str, ): - if mode.lower() == "fft" and sampling.lower() not in ["mw", "mwss", "dh"]: - pytest.skip( - f"Fourier based Wigner computation not valid for sampling={sampling}" - ) - - flmn = flmn_generator(L=L, N=N, reality=reality) + check_mode_and_sampling(mode, sampling) + flmn, kernel = get_flmn_and_kernel( + flmn_generator, L, N, sampling, reality, method, mode, forward=False + ) + check_inverse_transform(flmn, kernel, L, N, sampling, reality, method) - f = base.inverse(flmn, L, N, 0, sampling, reality) - - kfunc = c.wigner_kernel_jax if method == "jax" else c.wigner_kernel - kernel = kfunc(L, N, reality, sampling, forward=False, mode=mode) - - if method.lower() == "torch": - # Test Transform - f_check = inverse( - torch.from_numpy(flmn), - L, - N, - torch.from_numpy(kernel), - sampling, - reality, - method, - ) - np.testing.assert_allclose(f, f_check, atol=1e-5, rtol=1e-5) - - # Test Gradients - flmn_grad_test = torch.from_numpy(flmn) - flmn_grad_test.requires_grad = True - assert torch.autograd.gradcheck( - inverse, - ( - flmn_grad_test, - L, - N, - torch.from_numpy(kernel), - sampling, - reality, - method, - ), - ) - else: - f_check = inverse(flmn, L, N, kernel, sampling, reality, method) - np.testing.assert_allclose(f, f_check, atol=1e-5, rtol=1e-5) +@pytest.mark.slow +@pytest.mark.parametrize("L", L_to_test) +@pytest.mark.parametrize("N", N_to_test) +@pytest.mark.parametrize("sampling", sampling_schemes) +@pytest.mark.parametrize("reality", reality_to_test) +@pytest.mark.parametrize("mode", modes_to_test) +def test_inverse_wigner_transform_torch_gradcheck( + flmn_generator, + L: int, + N: int, + sampling: str, + reality: bool, + mode: str, +): + method = "torch" + check_mode_and_sampling(mode, sampling) + flmn, kernel = get_flmn_and_kernel( + flmn_generator, L, N, sampling, reality, method, mode, forward=False + ) + flmn = torch.from_numpy(flmn) + flmn.requires_grad = True + assert torch.autograd.gradcheck( + inverse, (flmn, L, N, kernel, sampling, reality, method) + ) @pytest.mark.parametrize("L", L_to_test) @@ -97,49 +135,38 @@ def test_forward_wigner_transform( method: str, mode: str, ): - if mode.lower() == "fft" and sampling.lower() not in ["mw", "mwss", "dh"]: - pytest.skip( - f"Fourier based Wigner computation not valid for sampling={sampling}" - ) - flmn = flmn_generator(L=L, N=N, reality=reality) + check_mode_and_sampling(mode, sampling) + flmn, kernel = get_flmn_and_kernel( + flmn_generator, L, N, sampling, reality, method, mode, forward=True + ) + check_forward_tranform(flmn, kernel, L, N, sampling, reality, method) + +@pytest.mark.slow +@pytest.mark.parametrize("L", L_to_test) +@pytest.mark.parametrize("N", N_to_test) +@pytest.mark.parametrize("sampling", sampling_schemes) +@pytest.mark.parametrize("reality", reality_to_test) +@pytest.mark.parametrize("mode", modes_to_test) +def test_forward_wigner_transform_torch_gradcheck( + flmn_generator, + L: int, + N: int, + sampling: str, + reality: bool, + mode: str, +): + method = "torch" + 
check_mode_and_sampling(mode, sampling) + flmn, kernel = get_flmn_and_kernel( + flmn_generator, L, N, sampling, reality, method, mode, forward=True + ) f = base.inverse(flmn, L, N, sampling=sampling, reality=reality) - flmn = base.forward(f, L, N, sampling=sampling, reality=reality) - - kfunc = c.wigner_kernel_jax if method == "jax" else c.wigner_kernel - kernel = kfunc(L, N, reality, sampling, forward=True, mode=mode) - - if method.lower() == "torch": - # Test Transform - flmn_check = forward( - torch.from_numpy(f), - L, - N, - torch.from_numpy(kernel), - sampling, - reality, - method, - ) - np.testing.assert_allclose(flmn, flmn_check, atol=1e-5, rtol=1e-5) - - # Test Gradients - f_grad_test = torch.from_numpy(f) - f_grad_test.requires_grad = True - assert torch.autograd.gradcheck( - forward, - ( - f_grad_test, - L, - N, - torch.from_numpy(kernel), - sampling, - reality, - method, - ), - ) - else: - flmn_check = forward(f, L, N, kernel, sampling, reality, method) - np.testing.assert_allclose(flmn, flmn_check, atol=1e-5, rtol=1e-5) + f = torch.from_numpy(f) + f.requires_grad = True + assert torch.autograd.gradcheck( + forward, (f, L, N, kernel, sampling, reality, method) + ) @pytest.mark.parametrize("nside", nside_to_test) @@ -156,48 +183,54 @@ def test_inverse_wigner_transform_healpix( method: str, ): sampling = "healpix" + mode = "auto" L = ratio * nside - flmn = flmn_generator(L=L, N=N, reality=reality) + flmn, kernel = get_flmn_and_kernel( + flmn_generator, + L, + N, + sampling, + reality, + method, + mode, + forward=False, + nside=nside, + ) + check_inverse_transform(flmn, kernel, L, N, sampling, reality, method, nside) - f = base.inverse(flmn, L, N, 0, sampling, reality, nside) - - kfunc = c.wigner_kernel_jax if method == "jax" else c.wigner_kernel - kernel = kfunc(L, N, reality, sampling, nside, forward=False) - - if method.lower() == "torch": - # Test Transform - f_check = inverse( - torch.from_numpy(flmn), - L, - N, - torch.from_numpy(kernel), - sampling, - reality, - method, - nside, - ) - np.testing.assert_allclose(np.real(f), np.real(f_check), atol=1e-5, rtol=1e-5) - - # Test Gradients - flmn_grad_test = torch.from_numpy(flmn) - flmn_grad_test.requires_grad = True - assert torch.autograd.gradcheck( - inverse, - ( - flmn_grad_test, - L, - N, - torch.from_numpy(kernel), - sampling, - reality, - method, - nside, - ), - ) - else: - f_check = inverse(flmn, L, N, kernel, sampling, reality, method, nside) - np.testing.assert_allclose(f, f_check, atol=1e-5, rtol=1e-5) +@pytest.mark.slow +@pytest.mark.parametrize("nside", nside_to_test) +@pytest.mark.parametrize("ratio", L_to_nside_ratio) +@pytest.mark.parametrize("N", N_to_test) +@pytest.mark.parametrize("reality", reality_to_test) +def test_inverse_wigner_transform_healpix_torch_gradcheck( + flmn_generator, + nside: int, + ratio: int, + N: int, + reality: bool, +): + method = "torch" + sampling = "healpix" + mode = "auto" + L = ratio * nside + flmn, kernel = get_flmn_and_kernel( + flmn_generator, + L, + N, + sampling, + reality, + method, + mode, + forward=False, + nside=nside, + ) + flmn = torch.from_numpy(flmn) + flmn.requires_grad = True + assert torch.autograd.gradcheck( + inverse, (flmn, L, N, kernel, sampling, reality, method, nside) + ) @pytest.mark.parametrize("nside", nside_to_test) @@ -214,49 +247,39 @@ def test_forward_wigner_transform_healpix( method: str, ): sampling = "healpix" + mode = "auto" L = ratio * nside - flmn = flmn_generator(L=L, N=N, reality=reality) + flmn, kernel = get_flmn_and_kernel( + flmn_generator, L, 
N, sampling, reality, method, mode, forward=True, nside=nside + ) + check_forward_tranform(flmn, kernel, L, N, sampling, reality, method, nside) - f = base.inverse(flmn, L, N, 0, sampling, reality, nside) - flmn_check = base.forward(f, L, N, 0, sampling, reality, nside) - - kfunc = c.wigner_kernel_jax if method == "jax" else c.wigner_kernel - kernel = kfunc(L, N, reality, sampling, nside, forward=True) - - if method.lower() == "torch": - # Test Transform - flmn = forward( - torch.from_numpy(f), - L, - N, - torch.from_numpy(kernel), - sampling, - reality, - method, - nside, - ) - np.testing.assert_allclose(flmn, flmn_check, atol=1e-5, rtol=1e-5) - - # Test Gradients - f_grad_test = torch.from_numpy(f) - f_grad_test.requires_grad = True - assert torch.autograd.gradcheck( - forward, - ( - f_grad_test, - L, - N, - torch.from_numpy(kernel), - sampling, - reality, - method, - nside, - ), - ) - else: - flmn = forward(f, L, N, kernel, sampling, reality, method, nside) - np.testing.assert_allclose(flmn, flmn_check, atol=1e-5, rtol=1e-5) +@pytest.mark.slow +@pytest.mark.parametrize("nside", nside_to_test) +@pytest.mark.parametrize("ratio", L_to_nside_ratio) +@pytest.mark.parametrize("N", N_to_test) +@pytest.mark.parametrize("reality", reality_to_test) +def test_forward_wigner_transform_healpix_torch_gradcheck( + flmn_generator, + nside: int, + ratio: int, + N: int, + reality: bool, +): + method = "torch" + sampling = "healpix" + mode = "auto" + L = ratio * nside + flmn, kernel = get_flmn_and_kernel( + flmn_generator, L, N, sampling, reality, method, mode, forward=True, nside=nside + ) + f = base.inverse(flmn, L, N, sampling=sampling, reality=reality, nside=nside) + f = torch.from_numpy(f) + f.requires_grad = True + assert torch.autograd.gradcheck( + forward, (f, L, N, kernel, sampling, reality, method, nside) + ) @pytest.mark.parametrize("L", [8, 16, 32]) From 8d2aae8cace080447882dd53de9f9bdde738d8a2 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 23 May 2025 19:33:33 +0100 Subject: [PATCH 3/5] Add slow test mark metadata --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6e05c936..c797c2b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ filterwarnings = [ "ignore::DeprecationWarning", "ignore:FutureWarning", ] +markers = ["slow: mark test as slow"] [tool.ruff] fix = true From 9554a7f3d278315f4cf66b95b386964e6da8e799 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 23 May 2025 19:34:09 +0100 Subject: [PATCH 4/5] Skip slow tests for pull-request triggered Actions jobs --- .github/workflows/tests.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index da86f094..b4ddd417 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -54,7 +54,13 @@ jobs: python -m pip install --upgrade pip pip install .[tests] + - name: Run tests (skipping slow) + if: github.event_name == 'pull_request' + run: | + pytest -v --cov-report=xml --cov=s2fft --cov-config=.coveragerc -m "no slow" + - name: Run tests + if: github.event_name != 'pull_request' run: | pytest -v --cov-report=xml --cov=s2fft --cov-config=.coveragerc From a3f7cf5adf60ce3f0c3dbc4b07513015395900f2 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 23 May 2025 19:40:56 +0100 Subject: [PATCH 5/5] Correct skip slow specification in workflow --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml 
b/.github/workflows/tests.yml index b4ddd417..d7d99af1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -57,7 +57,7 @@ jobs: - name: Run tests (skipping slow) if: github.event_name == 'pull_request' run: | - pytest -v --cov-report=xml --cov=s2fft --cov-config=.coveragerc -m "no slow" + pytest -v --cov-report=xml --cov=s2fft --cov-config=.coveragerc -m "not slow" - name: Run tests if: github.event_name != 'pull_request'
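
Usage sketch (not part of the patch series above): with the `slow` marker registered in pyproject.toml by PATCH 3/5, the factored-out torch gradcheck tests can be selected or deselected locally via pytest's standard `-m` marker expression, mirroring what the pull-request CI job does after PATCH 5/5. Only the `-m "not slow"` expression is taken from the workflow itself; the local invocations below (paths, verbosity, omitting coverage flags) are assumptions for illustration.

    # run the fast test suite locally, skipping the slow torch gradcheck tests
    pytest -v -m "not slow"

    # run only the slow torch gradcheck tests, e.g. before pushing to main
    pytest -v -m "slow"
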