@@ -398,8 +398,13 @@ def forward(
398
398
return forward_numpy (f , L , spin , nside , sampling , reality , precomps , L_lower )
399
399
elif method == "jax" :
400
400
return forward_jax (
401
- f , L , spin , nside , sampling , reality , precomps , spmd , L_lower
401
+ f , L , spin , nside , sampling , reality , precomps , spmd , L_lower , use_cuda = False
402
402
)
403
+ elif method == "cuda" :
404
+ return forward_jax (
405
+ f , L , spin , nside , sampling , reality , precomps , spmd , L_lower , use_cuda = True
406
+ )
407
+
403
408
elif method == "jax_ssht" :
404
409
if sampling .lower () == "healpix" :
405
410
raise ValueError ("SSHT does not support healpix sampling." )
@@ -537,7 +542,7 @@ def forward_numpy(
537
542
return flm * (- 1 ) ** spin
538
543
539
544
540
- @partial (jit , static_argnums = (1 , 3 , 4 , 5 , 7 , 8 ))
545
+ @partial (jit , static_argnums = (1 , 3 , 4 , 5 , 7 , 8 , 9 ))
541
546
def forward_jax (
542
547
f : jnp .ndarray ,
543
548
L : int ,
@@ -548,6 +553,7 @@ def forward_jax(
548
553
precomps : List = None ,
549
554
spmd : bool = False ,
550
555
L_lower : int = 0 ,
556
+ use_cuda : bool = False ,
551
557
) -> jnp .ndarray :
552
558
r"""Compute the forward spin-spherical harmonic transform (JAX).
553
559
@@ -582,6 +588,8 @@ def forward_jax(
582
588
L_lower (int, optional): Harmonic lower-bound. Transform will only be computed
583
589
for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0.
584
590
591
+ use_cuda (bool, optional): Whether to use the CUDA backend. Defaults to False.
592
+
585
593
Returns:
586
594
jnp.ndarray: Spherical harmonic coefficients
587
595
@@ -609,7 +617,10 @@ def forward_jax(
609
617
610
618
# Perform longitundal Fast Fourier Transforms
611
619
if sampling .lower () == "healpix" :
612
- ftm = hp .healpix_fft (f , L , nside , "jax" , reality )
620
+ if use_cuda :
621
+ ftm = hp .healpix_fft (f , L , nside , "cuda" , reality )
622
+ else :
623
+ ftm = hp .healpix_fft (f , L , nside , "jax" , reality )
613
624
else :
614
625
if reality :
615
626
t = jnp .fft .rfft (jnp .real (f ), axis = 1 , norm = "backward" )
0 commit comments