Commit 3551865

run linting on test scripts
1 parent ba58339 commit 3551865

7 files changed, +57 -133 lines changed
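The diffs below are mechanical reflows: statements that were previously wrapped across several lines are collapsed onto one, and the test imports are regrouped. This is the pattern a formatter produces when rerun with a longer permitted line length (the new lines fit within roughly 88 characters; the exact tool and limit are assumptions, not stated in the commit). A minimal before/after sketch of the recurring change, using the np.zeros call that appears in construct.py with illustrative values:

import numpy as np

thetas = np.linspace(0, np.pi, 4)  # illustrative values only
L = 8

# Before: wrapped to respect the old, shorter line-length limit.
delta = np.zeros(
    (len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64
)

# After: the identical call on a single line under the wider limit.
delta = np.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64)
assert delta.shape == (4, 15, 15)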


s2fft/precompute_transforms/construct.py
Lines changed: 19 additions & 57 deletions

@@ -74,9 +74,7 @@ def spin_spherical_kernel(
     if recursion.lower() == "auto":
         # This mode automatically determines which recursion is best suited for the
         # current parameter configuration.
-        recursion = (
-            "risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen"
-        )
+        recursion = "risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen"
 
     dl = []
     m_start_ind = L - 1 if reality else 0
@@ -111,13 +109,9 @@ def spin_spherical_kernel(
         # - The complexity of this approach is O(L^4).
         # - This approach is stable for arbitrary abs(spins) <= L.
         if sampling.lower() in ["healpix", "gl"]:
-            delta = np.zeros(
-                (len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64
-            )
+            delta = np.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64)
             for el in range(L):
-                delta = recursions.risbo.compute_full_vectorised(
-                    delta, thetas, L, el
-                )
+                delta = recursions.risbo.compute_full_vectorised(delta, thetas, L, el)
                 dl[:, el] = delta[:, m_start_ind:, L - 1 - spin]
 
         # MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated
@@ -144,19 +138,13 @@ def spin_spherical_kernel(
                     delta[:, L - 1 - spin],
                     1j ** (-spin - m_value[m_start_ind:]),
                 )
-                temp = np.einsum(
-                    "am,a->am", temp, np.exp(1j * m_value * thetas[0])
-                )
-                temp = np.fft.irfft(
-                    temp[L - 1 :], n=nsamps, axis=0, norm="forward"
-                )
+                temp = np.einsum("am,a->am", temp, np.exp(1j * m_value * thetas[0]))
+                temp = np.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")
 
                 dl[:, el] = temp[: len(thetas)]
 
         # Fold in normalisation to avoid recomputation at run-time.
-        dl = np.einsum(
-            "tlm,l->tlm", dl, np.sqrt((2 * np.arange(L) + 1) / (4 * np.pi))
-        )
+        dl = np.einsum("tlm,l->tlm", dl, np.sqrt((2 * np.arange(L) + 1) / (4 * np.pi)))
 
     else:
         raise ValueError(f"Recursion method {recursion} not recognised.")
@@ -234,9 +222,7 @@ def spin_spherical_kernel_jax(
     if recursion.lower() == "auto":
         # This mode automatically determines which recursion is best suited for the
         # current parameter configuration.
-        recursion = (
-            "risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen"
-        )
+        recursion = "risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen"
 
     dl = []
     m_start_ind = L - 1 if reality else 0
@@ -283,9 +269,7 @@ def spin_spherical_kernel_jax(
         # - The complexity of this approach is O(L^4).
         # - This approach is stable for arbitrary abs(spins) <= L.
         if sampling.lower() in ["healpix", "gl"]:
-            delta = jnp.zeros(
-                (len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64
-            )
+            delta = jnp.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64)
             vfunc = jax.vmap(
                 recursions.risbo_jax.compute_full, in_axes=(0, 0, None, None)
             )
@@ -309,32 +293,24 @@ def spin_spherical_kernel_jax(
 
             # Calculate the Fourier coefficients of the Wigner d-functions, delta(pi/2).
             for el in range(L):
-                delta = recursions.risbo_jax.compute_full(
-                    delta, jnp.pi / 2, L, el
-                )
+                delta = recursions.risbo_jax.compute_full(delta, jnp.pi / 2, L, el)
                 m_value = jnp.arange(-L + 1, L)
                 temp = jnp.einsum(
                     "am,a,m->am",
                     delta[:, m_start_ind:],
                     delta[:, L - 1 - spin],
                     1j ** (-spin - m_value[m_start_ind:]),
                 )
-                temp = jnp.einsum(
-                    "am,a->am", temp, jnp.exp(1j * m_value * thetas[0])
-                )
-                temp = jnp.fft.irfft(
-                    temp[L - 1 :], n=nsamps, axis=0, norm="forward"
-                )
+                temp = jnp.einsum("am,a->am", temp, jnp.exp(1j * m_value * thetas[0]))
+                temp = jnp.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")
 
                 dl = dl.at[:, el].set(temp[: len(thetas)])
 
     else:
         raise ValueError(f"Recursion method {recursion} not recognised.")
 
     # Fold in normalisation to avoid recomputation at run-time.
-    dl = jnp.einsum(
-        "tlm,l->tlm", dl, jnp.sqrt((2 * jnp.arange(L) + 1) / (4 * jnp.pi))
-    )
+    dl = jnp.einsum("tlm,l->tlm", dl, jnp.sqrt((2 * jnp.arange(L) + 1) / (4 * jnp.pi)))
 
     # Fold in quadrature to avoid recomputation at run-time.
     if forward:
@@ -433,9 +409,7 @@ def wigner_kernel(
     if mode.lower() == "direct":
         delta = np.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64)
         for el in range(L):
-            delta = recursions.risbo.compute_full_vectorised(
-                delta, thetas, L, el
-            )
+            delta = recursions.risbo.compute_full_vectorised(delta, thetas, L, el)
             dl[:, :, el] = np.moveaxis(delta, -1, 0)[L - 1 + n]
 
     # MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated
@@ -464,9 +438,7 @@ def wigner_kernel(
                 1j ** (-m_value),
                 1j ** (n),
             )
-            temp = np.einsum(
-                "amn,a->amn", temp, np.exp(1j * m_value * thetas[0])
-            )
+            temp = np.einsum("amn,a->amn", temp, np.exp(1j * m_value * thetas[0]))
             temp = np.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")
             dl[:, :, el] = np.moveaxis(temp[: len(thetas)], -1, 0)
 
@@ -574,12 +546,8 @@ def wigner_kernel_jax(
     # - The complexity of this approach is ALWAYS O(L^4).
     # - This approach is stable for arbitrary abs(spins) <= L.
     if mode.lower() == "direct":
-        delta = jnp.zeros(
-            (len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64
-        )
-        vfunc = jax.vmap(
-            recursions.risbo_jax.compute_full, in_axes=(0, 0, None, None)
-        )
+        delta = jnp.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64)
+        vfunc = jax.vmap(recursions.risbo_jax.compute_full, in_axes=(0, 0, None, None))
         for el in range(L):
             delta = vfunc(delta, thetas, L, el)
             dl = dl.at[:, :, el].set(jnp.moveaxis(delta, -1, 0)[L - 1 + n])
@@ -610,12 +578,8 @@ def wigner_kernel_jax(
                 1j ** (-m_value),
                 1j ** (n),
             )
-            temp = jnp.einsum(
-                "amn,a->amn", temp, jnp.exp(1j * m_value * thetas[0])
-            )
-            temp = jnp.fft.irfft(
-                temp[L - 1 :], n=nsamps, axis=0, norm="forward"
-            )
+            temp = jnp.einsum("amn,a->amn", temp, jnp.exp(1j * m_value * thetas[0]))
+            temp = jnp.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")
             dl = dl.at[:, :, el].set(jnp.moveaxis(temp[: len(thetas)], -1, 0))
 
     else:
@@ -646,9 +610,7 @@ def wigner_kernel_jax(
     return dl
 
 
-def healpix_phase_shifts(
-    L: int, nside: int, forward: bool = False
-) -> np.ndarray:
+def healpix_phase_shifts(L: int, nside: int, forward: bool = False) -> np.ndarray:
     r"""
     Generates a phase shift vector for HEALPix for all :math:`\theta` rings.
 
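For context, the reflowed "auto" branch above only chooses a recursion by spin magnitude: Risbo for large abs(spin) (stable for arbitrary abs(spin) <= L, per the comments in the diff), Price-McEwen otherwise. A standalone sketch of that selection follows; the threshold value is hypothetical, since the actual value of PM_MAX_STABLE_SPIN does not appear in this diff.

PM_MAX_STABLE_SPIN = 6  # hypothetical value; the real constant lives in construct.py


def select_recursion(spin: int) -> str:
    # Same expression as the reflowed line in spin_spherical_kernel above.
    return "risbo" if abs(spin) >= PM_MAX_STABLE_SPIN else "price-mcewen"


assert select_recursion(0) == "price-mcewen"
assert select_recursion(-8) == "risbo"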

s2fft/precompute_transforms/spherical.py
Lines changed: 9 additions & 25 deletions

@@ -65,13 +65,9 @@ def inverse(
     if method == "numpy":
         return inverse_transform(flm, kernel, L, sampling, reality, spin, nside)
     elif method == "jax":
-        return inverse_transform_jax(
-            flm, kernel, L, sampling, reality, spin, nside
-        )
+        return inverse_transform_jax(flm, kernel, L, sampling, reality, spin, nside)
     elif method == "torch":
-        return inverse_transform_torch(
-            flm, kernel, L, sampling, reality, spin, nside
-        )
+        return inverse_transform_torch(flm, kernel, L, sampling, reality, spin, nside)
     else:
         raise ValueError(f"Method {method} not recognised.")
 
@@ -193,9 +189,7 @@ def inverse_transform_jax(
     if sampling.lower() == "healpix":
         if reality:
             ftm = ftm.at[:, m_offset : m_start_ind + m_offset].set(
-                jnp.flip(
-                    jnp.conj(ftm[:, m_start_ind + m_offset + 1 :]), axis=-1
-                )
+                jnp.flip(jnp.conj(ftm[:, m_start_ind + m_offset + 1 :]), axis=-1)
             )
         f = hp.healpix_ifft(ftm, L, nside, "jax", reality)
 
@@ -252,9 +246,7 @@ def inverse_transform_torch(
     m_offset = 1 if sampling in ["mwss", "healpix"] else 0
     m_start_ind = L - 1 if reality else 0
 
-    ftm = torch.zeros(
-        samples.ftm_shape(L, sampling, nside), dtype=torch.complex128
-    )
+    ftm = torch.zeros(samples.ftm_shape(L, sampling, nside), dtype=torch.complex128)
     if sampling.lower() == "healpix":
         ftm[:, m_start_ind + m_offset :] += torch.einsum(
             "...tlm, ...lm -> ...tm", kernel, flm[:, m_start_ind:]
@@ -348,13 +340,9 @@ def forward(
     if method == "numpy":
         return forward_transform(f, kernel, L, sampling, reality, spin, nside)
     elif method == "jax":
-        return forward_transform_jax(
-            f, kernel, L, sampling, reality, spin, nside
-        )
+        return forward_transform_jax(f, kernel, L, sampling, reality, spin, nside)
     elif method == "torch":
-        return forward_transform_torch(
-            f, kernel, L, sampling, reality, spin, nside
-        )
+        return forward_transform_torch(f, kernel, L, sampling, reality, spin, nside)
     else:
         raise ValueError(f"Method {method} not recognised.")
 
@@ -495,8 +483,7 @@ def forward_transform_jax(
     if reality:
         flm = flm.at[:, :m_start_ind].set(
             jnp.flip(
-                (-1) ** (jnp.arange(1, L) % 2)
-                * jnp.conj(flm[:, m_start_ind + 1 :]),
+                (-1) ** (jnp.arange(1, L) % 2) * jnp.conj(flm[:, m_start_ind + 1 :]),
                 axis=-1,
             )
         )
@@ -564,9 +551,7 @@ def forward_transform_torch(
 
     flm = torch.zeros(samples.flm_shape(L), dtype=torch.complex128)
     if sampling.lower() == "healpix":
-        flm[:, m_start_ind:] = torch.einsum(
-            "...tlm, ...tm -> ...lm", kernel, ftm
-        )
+        flm[:, m_start_ind:] = torch.einsum("...tlm, ...tm -> ...lm", kernel, ftm)
     else:
         flm[:, m_start_ind:].real = torch.einsum(
             "...tlm, ...tm -> ...lm", kernel, ftm.real
@@ -577,8 +562,7 @@ def forward_transform_torch(
 
     if reality:
         flm[:, :m_start_ind] = torch.flip(
-            (-1) ** (torch.arange(1, L) % 2)
-            * torch.conj(flm[:, m_start_ind + 1 :]),
+            (-1) ** (torch.arange(1, L) % 2) * torch.conj(flm[:, m_start_ind + 1 :]),
             dims=[-1],
         )
 
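The inverse()/forward() reflows above all sit in the same dispatch-by-method pattern: a string selects the NumPy, JAX, or PyTorch implementation, and anything else raises a ValueError. A toy sketch of that pattern with stand-in callables (not the real s2fft transform signatures):

def _inverse_numpy(flm):  # stand-in for inverse_transform
    return ("numpy", flm)


def _inverse_jax(flm):  # stand-in for inverse_transform_jax
    return ("jax", flm)


def _inverse_torch(flm):  # stand-in for inverse_transform_torch
    return ("torch", flm)


def inverse(flm, method: str = "numpy"):
    if method == "numpy":
        return _inverse_numpy(flm)
    elif method == "jax":
        return _inverse_jax(flm)
    elif method == "torch":
        return _inverse_torch(flm)
    else:
        raise ValueError(f"Method {method} not recognised.")


assert inverse([1, 2, 3], method="jax")[0] == "jax"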

s2fft/recursions/risbo.py
Lines changed: 1 addition & 3 deletions

@@ -89,9 +89,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
 
                 ddj = dd[i, k] / j
 
-                dl[k - el + L - 1, i - el + L - 1] += (
-                    sqrt_jmi * sqrt_jmk * ddj * coshb
-                )
+                dl[k - el + L - 1, i - el + L - 1] += sqrt_jmi * sqrt_jmk * ddj * coshb
                 dl[k - el + L - 1, i + 1 - el + L - 1] -= (
                     sqrt_ip1 * sqrt_jmk * ddj * sinhb
                 )

s2fft/recursions/risbo_jax.py
Lines changed: 7 additions & 15 deletions

@@ -68,29 +68,21 @@ def compute_full(dl: jnp.ndarray, beta: float, L: int, el: int) -> jnp.ndarray:
     dlj = dl[k - (el - 1) + L - 1][:, i - (el - 1) + L - 1]
 
     dd = dd.at[:j, :j].add(
-        jnp.einsum("i,k->ki", sqrt_jmi, sqrt_jmk, optimize=True)
-        * dlj
-        * coshb
+        jnp.einsum("i,k->ki", sqrt_jmi, sqrt_jmk, optimize=True) * dlj * coshb
     )
     dd = dd.at[:j, 1 : j + 1].add(
-        jnp.einsum("i,k->ki", -sqrt_ip1, sqrt_jmk, optimize=True)
-        * dlj
-        * sinhb
+        jnp.einsum("i,k->ki", -sqrt_ip1, sqrt_jmk, optimize=True) * dlj * sinhb
    )
     dd = dd.at[1 : j + 1, :j].add(
-        jnp.einsum("i,k->ki", sqrt_jmi, sqrt_kp1, optimize=True)
-        * dlj
-        * sinhb
+        jnp.einsum("i,k->ki", sqrt_jmi, sqrt_kp1, optimize=True) * dlj * sinhb
     )
     dd = dd.at[1 : j + 1, 1 : j + 1].add(
-        jnp.einsum("i,k->ki", sqrt_ip1, sqrt_kp1, optimize=True)
-        * dlj
-        * coshb
+        jnp.einsum("i,k->ki", sqrt_ip1, sqrt_kp1, optimize=True) * dlj * coshb
     )
 
-    dl = dl.at[
-        -el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1
-    ].multiply(0.0)
+    dl = dl.at[-el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1].multiply(
+        0.0
+    )
 
     j = 2 * el
     i = jnp.arange(j)
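The collapsed expressions above all use jnp.einsum("i,k->ki", a, b), which is just an outer product with the output axes ordered (k, i). A quick NumPy check of that identity (NumPy is used here only to keep the snippet dependency-light):

import numpy as np

a = np.array([1.0, 2.0, 3.0])       # plays the role of sqrt_jmi
b = np.array([0.5, 1.5, 2.5, 3.5])  # plays the role of sqrt_jmk

# einsum("i,k->ki", a, b)[k, i] == a[i] * b[k], i.e. the outer product of b with a.
assert np.allclose(np.einsum("i,k->ki", a, b), np.outer(b, a))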

tests/test_spherical_precompute.py
Lines changed: 6 additions & 9 deletions

@@ -1,10 +1,11 @@
 import numpy as np
+import pyssht as ssht
 import pytest
 import torch
-from s2fft.precompute_transforms.spherical import inverse, forward
-from s2fft.precompute_transforms import construct as c
+
 from s2fft.base_transforms import spherical as base
-import pyssht as ssht
+from s2fft.precompute_transforms import construct as c
+from s2fft.precompute_transforms.spherical import forward, inverse
 from s2fft.sampling import s2_samples as samples
 
 L_to_test = [12]
@@ -47,9 +48,7 @@ def test_transform_inverse(
         if method.lower() == "jax"
         else c.spin_spherical_kernel
     )
-    kernel = kfunc(
-        L, spin, reality, sampling, forward=False, recursion=recursion
-    )
+    kernel = kfunc(L, spin, reality, sampling, forward=False, recursion=recursion)
 
     tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12
     if method.lower() == "torch":
@@ -178,9 +177,7 @@ def test_transform_forward(
         if method.lower() == "jax"
        else c.spin_spherical_kernel
     )
-    kernel = kfunc(
-        L, spin, reality, sampling, forward=True, recursion=recursion
-    )
+    kernel = kfunc(L, spin, reality, sampling, forward=True, recursion=recursion)
 
     tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12
     if method.lower() == "torch":
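The reflowed kernel construction above sits inside a small selection pattern: the test picks the JAX or NumPy kernel builder by method name and then calls it once with the (now single-line) argument list. A sketch with stand-in builders rather than the real c.spin_spherical_kernel functions:

def spin_spherical_kernel(L, spin, reality, sampling, forward, recursion):
    # Stand-in for c.spin_spherical_kernel; returns a tag for the demo.
    return {"impl": "numpy", "L": L, "recursion": recursion}


def spin_spherical_kernel_jax(L, spin, reality, sampling, forward, recursion):
    # Stand-in for c.spin_spherical_kernel_jax.
    return {"impl": "jax", "L": L, "recursion": recursion}


method = "jax"
kfunc = (
    spin_spherical_kernel_jax
    if method.lower() == "jax"
    else spin_spherical_kernel
)
kernel = kfunc(12, 0, True, "mw", forward=False, recursion="auto")
assert kernel["impl"] == "jax"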

tests/test_wigner_precompute.py
Lines changed: 5 additions & 6 deletions

@@ -1,11 +1,12 @@
 import numpy as np
 import pytest
+import so3
 import torch
-from s2fft.precompute_transforms.wigner import inverse, forward
-from s2fft.precompute_transforms import construct as c
+
 from s2fft.base_transforms import wigner as base
+from s2fft.precompute_transforms import construct as c
+from s2fft.precompute_transforms.wigner import forward, inverse
 from s2fft.sampling import so3_samples as samples
-import so3
 
 L_to_test = [6]
 N_to_test = [2, 6]
@@ -172,9 +173,7 @@ def test_inverse_wigner_transform_healpix(
         method,
         nside,
     )
-    np.testing.assert_allclose(
-        np.real(f), np.real(f_check), atol=1e-5, rtol=1e-5
-    )
+    np.testing.assert_allclose(np.real(f), np.real(f_check), atol=1e-5, rtol=1e-5)
 
     # Test Gradients
     flmn_grad_test = torch.from_numpy(flmn)
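The reflowed assertion above compares only the real parts of the two fields and uses both absolute and relative tolerances. A tiny self-contained example of the same NumPy call:

import numpy as np

f = np.array([1.0 + 0.0j, 2.0 + 0.0j])
f_check = f + 1e-6  # within atol=1e-5 of f

np.testing.assert_allclose(np.real(f), np.real(f_check), atol=1e-5, rtol=1e-5)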
