Skip to content

Commit 5d1d13f

Browse files
authored
Merge pull request #238 from astro-informatics/map/risbo-precompute-transform-memeff
add stable forward/inverse memory efficient Wigner transforms
2 parents 11f76bf + 6f64ebb commit 5d1d13f

File tree

7 files changed

+575
-3
lines changed

7 files changed

+575
-3
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:html_theme.sidebar_secondary.remove:
2+
3+
**************************
4+
Fourier-Wigner Transform
5+
**************************
6+
.. automodule:: s2fft.precompute_transforms.fourier_wigner
7+
:members:

docs/api/precompute_transforms/index.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,21 @@ Precompute Functions
5050
* - :func:`~s2fft.precompute_transforms.wigner.forward_transform_torch`
5151
- Forward Wigner transform (Torch)
5252

53+
.. list-table:: Fourier-Wigner transforms.
54+
:widths: 25 25
55+
:header-rows: 1
56+
57+
* - Function Name
58+
- Description
59+
* - :func:`~s2fft.precompute_transforms.fourier_wigner.inverse_transform`
60+
- Inverse Wigner transform with Fourier method (NumPy)
61+
* - :func:`~s2fft.precompute_transforms.fourier_wigner.inverse_transform_jax`
62+
- Inverse Wigner transform with Fourier method (JAX)
63+
* - :func:`~s2fft.precompute_transforms.fourier_wigner.forward_transform`
64+
- Forward Wigner transform with Fourier method (NumPy)
65+
* - :func:`~s2fft.precompute_transforms.fourier_wigner.forward_transform_jax`
66+
- Forward Wigner transform with Fourier method (JAX)
67+
5368
.. list-table:: Constructing Kernels for precompute transforms.
5469
:widths: 25 25
5570
:header-rows: 1
@@ -64,6 +79,10 @@ Precompute Functions
6479
- Builds a kernel including quadrature weights and Wigner-D coefficients for spherical harmonic transform (JAX).
6580
* - :func:`~s2fft.precompute_transforms.construct.wigner_kernel_jax`
6681
- Builds a kernel including quadrature weights and Wigner-D coefficients for Wigner transform (JAX).
82+
* - :func:`~s2fft.precompute_transforms.construct.fourier_wigner_kernel`
83+
- Builds a kernel including quadrature weights and Fourier coefficienfs of Wigner d-functions
84+
* - :func:`~s2fft.precompute_transforms.construct.fourier_wigner_kernel_jax`
85+
- Builds a kernel including quadrature weights and Fourier coefficienfs of Wigner d-functions (JAX).
6786
* - :func:`~s2fft.precompute_transforms.construct.healpix_phase_shifts`
6887
- Builds a vector of corresponding phase shifts for each HEALPix latitudinal ring.
6988

@@ -76,4 +95,5 @@ Precompute Functions
7695
alt_construct
7796
spin_spherical
7897
wigner
98+
fourier_wigner
7999

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from . import construct, spherical, wigner
1+
from . import construct, fourier_wigner, spherical, wigner

s2fft/precompute_transforms/construct.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Tuple
12
from warnings import warn
23

34
import jax
@@ -610,6 +611,62 @@ def wigner_kernel_jax(
610611
return dl
611612

612613

614+
def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]:
615+
"""
616+
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
617+
weights upsampled for the forward Fourier-Wigner transform.
618+
619+
Args:
620+
L (int): Harmonic band-limit.
621+
622+
Returns:
623+
Tuple[np.ndarray, np.ndarray]: Tuple of delta Fourier coefficients and weights.
624+
625+
"""
626+
# Calculate deltas (np.pi/2 Fourier coefficients of Wigner matrices)
627+
deltas = np.zeros((L, 2 * L - 1, 2 * L - 1), dtype=np.float64)
628+
d = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
629+
for el in range(L):
630+
d = recursions.risbo.compute_full(d, np.pi / 2, L, el)
631+
deltas[el] = d
632+
633+
# Calculate upsampled quadrature weights
634+
w = np.zeros(4 * L - 3, dtype=np.complex128)
635+
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
636+
w[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm)
637+
w = np.fft.ifft(np.fft.ifftshift(w), norm="forward")
638+
639+
return deltas, w
640+
641+
642+
def fourier_wigner_kernel_jax(L: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
643+
"""
644+
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
645+
weights upsampled for the forward Fourier-Wigner transform (JAX implementation).
646+
647+
Args:
648+
L (int): Harmonic band-limit.
649+
650+
Returns:
651+
Tuple[jnp.ndarray, jnp.ndarray]: Tuple of delta Fourier coefficients and weights.
652+
653+
"""
654+
# Calculate deltas (np.pi/2 Fourier coefficients of Wigner matrices)
655+
deltas = jnp.zeros((L, 2 * L - 1, 2 * L - 1), dtype=jnp.float64)
656+
d = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
657+
for el in range(L):
658+
d = recursions.risbo_jax.compute_full(d, jnp.pi / 2, L, el)
659+
deltas = deltas.at[el].set(d)
660+
661+
# Calculate upsampled quadrature weights
662+
w = jnp.zeros(4 * L - 3, dtype=jnp.complex128)
663+
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
664+
w = w.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm))
665+
w = jnp.fft.ifft(jnp.fft.ifftshift(w), norm="forward")
666+
667+
return deltas, w
668+
669+
613670
def healpix_phase_shifts(L: int, nside: int, forward: bool = False) -> np.ndarray:
614671
r"""
615672
Generates a phase shift vector for HEALPix for all :math:`\theta` rings.

0 commit comments

Comments
 (0)