Skip to content

Commit 58e8101

Browse files
committed
connect cuda implementation to top level API
1 parent a65b009 commit 58e8101

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

s2fft/transforms/spherical.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,8 +398,13 @@ def forward(
398398
return forward_numpy(f, L, spin, nside, sampling, reality, precomps, L_lower)
399399
elif method == "jax":
400400
return forward_jax(
401-
f, L, spin, nside, sampling, reality, precomps, spmd, L_lower
401+
f, L, spin, nside, sampling, reality, precomps, spmd, L_lower , use_cuda=False
402402
)
403+
elif method == "cuda":
404+
return forward_jax(
405+
f, L, spin, nside, sampling, reality, precomps, spmd, L_lower, use_cuda=True
406+
)
407+
403408
elif method == "jax_ssht":
404409
if sampling.lower() == "healpix":
405410
raise ValueError("SSHT does not support healpix sampling.")
@@ -537,7 +542,7 @@ def forward_numpy(
537542
return flm * (-1) ** spin
538543

539544

540-
@partial(jit, static_argnums=(1, 3, 4, 5, 7, 8))
545+
@partial(jit, static_argnums=(1, 3, 4, 5, 7, 8,9))
541546
def forward_jax(
542547
f: jnp.ndarray,
543548
L: int,
@@ -548,6 +553,7 @@ def forward_jax(
548553
precomps: List = None,
549554
spmd: bool = False,
550555
L_lower: int = 0,
556+
use_cuda: bool = False,
551557
) -> jnp.ndarray:
552558
r"""Compute the forward spin-spherical harmonic transform (JAX).
553559
@@ -582,6 +588,8 @@ def forward_jax(
582588
L_lower (int, optional): Harmonic lower-bound. Transform will only be computed
583589
for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0.
584590
591+
use_cuda (bool, optional): Whether to use the CUDA backend. Defaults to False.
592+
585593
Returns:
586594
jnp.ndarray: Spherical harmonic coefficients
587595
@@ -609,7 +617,10 @@ def forward_jax(
609617

610618
# Perform longitundal Fast Fourier Transforms
611619
if sampling.lower() == "healpix":
612-
ftm = hp.healpix_fft(f, L, nside, "jax", reality)
620+
if use_cuda:
621+
ftm = hp.healpix_fft(f, L, nside, "cuda", reality)
622+
else:
623+
ftm = hp.healpix_fft(f, L, nside, "jax", reality)
613624
else:
614625
if reality:
615626
t = jnp.fft.rfft(jnp.real(f), axis=1, norm="backward")

0 commit comments

Comments
 (0)