Skip to content

Commit 21d2d0c

Browse files
Update custom_ops.py (#315)
* Update custom_ops.py Small compatibility change which disables jitting on the s2fft side, in turn enables higher level jitting in s2ai. * Update custom_ops.py Removed commented lines for linting purposes * Removing now unused imports --------- Co-authored-by: Matt Graham <matthew.m.graham@gmail.com>
1 parent b5d3eb1 commit 21d2d0c

File tree

1 file changed

+0
-6
lines changed

1 file changed

+0
-6
lines changed

s2fft/precompute_transforms/custom_ops.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
from functools import partial
2-
31
import jax.numpy as jnp
42
import numpy as np
5-
from jax import jit
63

74

85
def wigner_subset_to_s2(
@@ -86,7 +83,6 @@ def wigner_subset_to_s2(
8683
return np.fft.ifft(x, axis=-2, norm="forward")
8784

8885

89-
@partial(jit, static_argnums=(3, 4))
9086
def wigner_subset_to_s2_jax(
9187
flmn: jnp.ndarray,
9288
spins: jnp.ndarray,
@@ -209,7 +205,6 @@ def so3_to_wigner_subset(
209205
return s2_to_wigner_subset(x, spins, DW, L, sampling)
210206

211207

212-
@partial(jit, static_argnums=(3, 4, 5))
213208
def so3_to_wigner_subset_jax(
214209
f: jnp.ndarray,
215210
spins: jnp.ndarray,
@@ -338,7 +333,6 @@ def s2_to_wigner_subset(
338333
return x * (2.0 * np.pi) ** 2
339334

340335

341-
@partial(jit, static_argnums=(3, 4))
342336
def s2_to_wigner_subset_jax(
343337
fs: jnp.ndarray,
344338
spins: jnp.ndarray,

0 commit comments

Comments
 (0)