Skip to content

Commit 3818eda

Browse files
committed
fix docstrings and test coverage
1 parent ad072de commit 3818eda

File tree

5 files changed

+81
-40
lines changed

5 files changed

+81
-40
lines changed

s2fft/precompute_transforms/alt_construct.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ def spin_spherical_kernel(
1717
sampling: str = "mw",
1818
forward: bool = True,
1919
):
20-
r"""Precompute the wigner-d kernel for spin-spherical transform. This can be
21-
drastically faster but comes at a :math:`\mathcal{O}(L^3)` memory overhead, making
22-
it infeasible for :math:`L\geq 512`.
20+
r"""Precompute the wigner-d kernel for spin-spherical transform.
21+
22+
This implementation is typically faster than computing these elements on-the-fly but
23+
comes at a :math:`\mathcal{O}(L^3)` memory overhead, making it infeasible for large
24+
bandlimits :math:`L\geq 512`.
2325
2426
Args:
2527
L (int): Harmonic band-limit.
@@ -109,9 +111,11 @@ def spin_spherical_kernel_jax(
109111
sampling: str = "mw",
110112
forward: bool = True,
111113
):
112-
r"""Precompute the wigner-d kernel for spin-spherical transform. This can be
113-
drastically faster but comes at a :math:`\mathcal{O}(L^3)` memory overhead, making
114-
it infeasible for :math:`L\geq 512`.
114+
r"""Precompute the wigner-d kernel for spin-spherical transform.
115+
116+
This implementation is typically faster than computing these elements on-the-fly but
117+
comes at a :math:`\mathcal{O}(L^3)` memory overhead, making it infeasible for large
118+
bandlimits :math:`L\geq 512`.
115119
116120
Args:
117121
L (int): Harmonic band-limit.
@@ -136,6 +140,13 @@ def spin_spherical_kernel_jax(
136140
Fourier decomposition of Wigner D-functions. This involves (minor) additional
137141
precomputations, but is stable to effectively arbitrarily large spin numbers.
138142
"""
143+
if reality and spin != 0:
144+
reality = False
145+
warn(
146+
"Reality acceleration only supports spin 0 fields. "
147+
+ "Defering to complex transform."
148+
)
149+
139150
m_start_ind = L - 1 if reality else 0
140151
m_dim = L if reality else 2 * L - 1
141152

@@ -194,9 +205,11 @@ def wigner_kernel(
194205
sampling: str = "mw",
195206
forward: bool = False,
196207
):
197-
r"""Precompute the wigner-d kernels required for a Wigner transform. This can be
198-
drastically faster but comes at a :math:`\mathcal{O}(NL^3)` memory overhead, making
199-
it infeasible for :math:`L \geq 512`.
208+
r"""Precompute the wigner-d kernel for Wigner transform.
209+
210+
This implementation is typically faster than computing these elements on-the-fly but
211+
comes at a :math:`\mathcal{O}(NL^3)` memory overhead, making it infeasible for large
212+
bandlimits :math:`L\geq 512`.
200213
201214
Args:
202215
L (int): Harmonic band-limit.
@@ -242,7 +255,7 @@ def wigner_kernel(
242255
raise ValueError("Sampling in supported list [mw, mwss, dh]")
243256

244257
# Compute Wigner d-functions from their Fourier decomposition.
245-
if N <= int(L / np.log(L)):
258+
if N >= int(L / np.log(L)):
246259
delta = np.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64)
247260
else:
248261
delta = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
@@ -254,7 +267,7 @@ def wigner_kernel(
254267

255268
# If N <= L/LogL more efficient to manually compute over FFT
256269
for el in range(L):
257-
if N <= int(L / np.log(L)):
270+
if N >= int(L / np.log(L)):
258271
delta = recursions.risbo.compute_full_vect(delta, thetas, L, el)
259272
dl[:, :, el] = np.moveaxis(delta, -1, 0)[L - 1 + n]
260273
else:
@@ -294,9 +307,11 @@ def wigner_kernel_jax(
294307
sampling: str = "mw",
295308
forward: bool = False,
296309
):
297-
r"""Precompute the wigner-d kernels required for a Wigner transform. This can be
298-
drastically faster but comes at a :math:`\mathcal{O}(NL^3)` memory overhead, making
299-
it infeasible for :math:`L \geq 512`.
310+
r"""Precompute the wigner-d kernel for Wigner transform.
311+
312+
This implementation is typically faster than computing these elements on-the-fly but
313+
comes at a :math:`\mathcal{O}(NL^3)` memory overhead, making it infeasible for large
314+
bandlimits :math:`L\geq 512`.
300315
301316
Args:
302317
L (int): Harmonic band-limit.
@@ -343,7 +358,7 @@ def wigner_kernel_jax(
343358
raise ValueError("Sampling in supported list [mw, mwss, dh]")
344359

345360
# Compute Wigner d-functions from their Fourier decomposition.
346-
if N <= int(L / np.log(L)):
361+
if N >= int(L / np.log(L)):
347362
delta = jnp.zeros(
348363
(len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64
349364
)
@@ -360,7 +375,7 @@ def wigner_kernel_jax(
360375

361376
# If N <= L/LogL more efficient to manually compute over FFT
362377
for el in range(L):
363-
if N <= int(L / np.log(L)):
378+
if N >= int(L / np.log(L)):
364379
delta = vfunc(delta, thetas, L, el)
365380
dl = dl.at[:, :, el].set(jnp.moveaxis(delta, -1, 0)[L - 1 + n])
366381
else:

s2fft/precompute_transforms/construct.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def spin_spherical_kernel(
6565
dl = np.zeros((len(thetas), L, m_dim), dtype=np.float64)
6666
for t, theta in enumerate(thetas):
6767
for el in range(abs(spin), L):
68-
dl[t, el] = recursions.turok.compute_slice(theta, el, L, -spin, reality)[
69-
m_start_ind:
70-
]
68+
dl[t, el] = recursions.turok.compute_slice(
69+
theta, el, L, -spin, reality
70+
)[m_start_ind:]
7171
dl[t, el] *= np.sqrt((2 * el + 1) / (4 * np.pi))
7272

7373
if forward:
@@ -117,6 +117,13 @@ def spin_spherical_kernel_jax(
117117
Returns:
118118
jnp.ndarray: Transform kernel for spin-spherical harmonic transform.
119119
"""
120+
if reality and spin != 0:
121+
reality = False
122+
warn(
123+
"Reality acceleration only supports spin 0 fields. "
124+
+ "Defering to complex transform."
125+
)
126+
120127
m_start_ind = L - 1 if reality else 0
121128

122129
if forward and sampling.lower() in ["mw", "mwss"]:
@@ -135,7 +142,7 @@ def spin_spherical_kernel_jax(
135142
# North pole singularity
136143
if sampling.lower() == "mwss":
137144
dl = dl.at[0].set(0)
138-
dl = dl = dl.at[0, :, L - 1 - spin].set(1)
145+
dl = dl.at[0, :, L - 1 - spin].set(1)
139146

140147
# South pole singularity
141148
if sampling.lower() in ["mw", "mwss"]:
@@ -212,7 +219,9 @@ def wigner_kernel(
212219
for t, theta in enumerate(thetas):
213220
for el in range(abs(n), L):
214221
ind = n if reality else N - 1 + n
215-
dl[ind, t, el] = recursions.turok.compute_slice(theta, el, L, n, False)
222+
dl[ind, t, el] = recursions.turok.compute_slice(
223+
theta, el, L, n, False
224+
)
216225

217226
if forward:
218227
weights = quadrature.quad_weights_transform(L, sampling, 0, nside)
@@ -345,7 +354,7 @@ def healpix_phase_shifts(
345354
"""
346355
thetas = samples.thetas(L, "healpix", nside)
347356
phase_array = np.zeros((len(thetas), 2 * L - 1), dtype=np.complex128)
348-
for t, theta in enumerate(thetas):
357+
for t, _ in enumerate(thetas):
349358
phase_array[t] = samples.ring_phase_shift_hp(L, t, nside, forward)
350359

351360
return phase_array

s2fft/precompute_transforms/spherical.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,13 @@ def inverse(
6161
if method == "numpy":
6262
return inverse_transform(flm, kernel, L, sampling, reality, spin, nside)
6363
elif method == "jax":
64-
return inverse_transform_jax(flm, kernel, L, sampling, reality, spin, nside)
64+
return inverse_transform_jax(
65+
flm, kernel, L, sampling, reality, spin, nside
66+
)
6567
elif method == "torch":
66-
return inverse_transform_torch(flm, kernel, L, sampling, reality, spin, nside)
68+
return inverse_transform_torch(
69+
flm, kernel, L, sampling, reality, spin, nside
70+
)
6771
else:
6872
raise ValueError(f"Method {method} not recognised.")
6973

@@ -181,7 +185,9 @@ def inverse_transform_jax(
181185
if sampling.lower() == "healpix":
182186
if reality:
183187
ftm = ftm.at[:, m_offset : m_start_ind + m_offset].set(
184-
jnp.flip(jnp.conj(ftm[:, m_start_ind + m_offset + 1 :]), axis=-1)
188+
jnp.flip(
189+
jnp.conj(ftm[:, m_start_ind + m_offset + 1 :]), axis=-1
190+
)
185191
)
186192
f = hp.healpix_ifft(ftm, L, nside, "jax", reality)
187193

@@ -236,7 +242,9 @@ def inverse_transform_torch(
236242
m_offset = 1 if sampling in ["mwss", "healpix"] else 0
237243
m_start_ind = L - 1 if reality else 0
238244

239-
ftm = torch.zeros(samples.ftm_shape(L, sampling, nside), dtype=torch.complex128)
245+
ftm = torch.zeros(
246+
samples.ftm_shape(L, sampling, nside), dtype=torch.complex128
247+
)
240248
if sampling.lower() == "healpix":
241249
ftm[:, m_start_ind + m_offset :] += torch.einsum(
242250
"...tlm, ...lm -> ...tm", kernel, flm[:, m_start_ind:]
@@ -327,9 +335,13 @@ def forward(
327335
if method == "numpy":
328336
return forward_transform(f, kernel, L, sampling, reality, spin, nside)
329337
elif method == "jax":
330-
return forward_transform_jax(f, kernel, L, sampling, reality, spin, nside)
338+
return forward_transform_jax(
339+
f, kernel, L, sampling, reality, spin, nside
340+
)
331341
elif method == "torch":
332-
return forward_transform_torch(f, kernel, L, sampling, reality, spin, nside)
342+
return forward_transform_torch(
343+
f, kernel, L, sampling, reality, spin, nside
344+
)
333345
else:
334346
raise ValueError(f"Method {method} not recognised.")
335347

@@ -466,7 +478,8 @@ def forward_transform_jax(
466478
if reality:
467479
flm = flm.at[:, :m_start_ind].set(
468480
jnp.flip(
469-
(-1) ** (jnp.arange(1, L) % 2) * jnp.conj(flm[:, m_start_ind + 1 :]),
481+
(-1) ** (jnp.arange(1, L) % 2)
482+
* jnp.conj(flm[:, m_start_ind + 1 :]),
470483
axis=-1,
471484
)
472485
)
@@ -532,7 +545,9 @@ def forward_transform_torch(
532545

533546
flm = torch.zeros(samples.flm_shape(L), dtype=torch.complex128)
534547
if sampling.lower() == "healpix":
535-
flm[:, m_start_ind:] = torch.einsum("...tlm, ...tm -> ...lm", kernel, ftm)
548+
flm[:, m_start_ind:] = torch.einsum(
549+
"...tlm, ...tm -> ...lm", kernel, ftm
550+
)
536551
else:
537552
flm[:, m_start_ind:].real = torch.einsum(
538553
"...tlm, ...tm -> ...lm", kernel, ftm.real
@@ -543,7 +558,8 @@ def forward_transform_torch(
543558

544559
if reality:
545560
flm[:, :m_start_ind] = torch.flip(
546-
(-1) ** (torch.arange(1, L) % 2) * torch.conj(flm[:, m_start_ind + 1 :]),
561+
(-1) ** (torch.arange(1, L) % 2)
562+
* torch.conj(flm[:, m_start_ind + 1 :]),
547563
dims=[-1],
548564
)
549565

tests/test_spherical_precompute.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
sampling_to_test = ["mw", "mwss", "dh", "gl"]
1616
reality_to_test = [True, False]
1717
methods_to_test = ["numpy", "jax", "torch"]
18-
recursions_to_test = ["price_mcewen", "risbo"]
18+
# recursions_to_test = ["price_mcewen", "risbo"]
19+
recursions_to_test = ["risbo"]
1920

2021

2122
@pytest.mark.parametrize("L", L_to_test)
@@ -33,9 +34,9 @@ def test_transform_inverse(
3334
method: str,
3435
recursion: str,
3536
):
36-
if recursion.lower() == "risbo" and [
37+
if recursion.lower() == "risbo" and (
3738
method.lower() == "torch" or sampling.lower() == "gl"
38-
]:
39+
):
3940
pytest.skip("Fourier mode Risbo recursions have limited functionality.")
4041

4142
flm = flm_generator(L=L, spin=spin, reality=reality)
@@ -162,9 +163,9 @@ def test_transform_forward(
162163
method: str,
163164
recursion: str,
164165
):
165-
if recursion.lower() == "risbo" and [
166+
if recursion.lower() == "risbo" and (
166167
method.lower() == "torch" or sampling.lower() == "gl"
167-
]:
168+
):
168169
pytest.skip("Fourier mode Risbo recursions have limited functionality.")
169170

170171
flm = flm_generator(L=L, spin=spin, reality=reality)

tests/test_wigner_precompute.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def test_inverse_wigner_transform(
3434
method: str,
3535
recursion: str,
3636
):
37-
if recursion.lower() == "risbo" and [
37+
if recursion.lower() == "risbo" and (
3838
method.lower() == "torch" or sampling.lower() == "gl"
39-
]:
39+
):
4040
pytest.skip("Fourier mode Risbo recursions have limited functionality.")
4141

4242
flmn = flmn_generator(L=L, N=N, reality=reality)
@@ -99,9 +99,9 @@ def test_forward_wigner_transform(
9999
method: str,
100100
recursion: str,
101101
):
102-
if recursion.lower() == "risbo" and [
102+
if recursion.lower() == "risbo" and (
103103
method.lower() == "torch" or sampling.lower() == "gl"
104-
]:
104+
):
105105
pytest.skip("Fourier mode Risbo recursions have limited functionality.")
106106

107107
flmn = flmn_generator(L=L, N=N, reality=reality)

0 commit comments

Comments
 (0)