Skip to content

Commit 7048633

Browse files
authored
Update custom_ops.py
Small compatibility change which disables jitting on the s2fft side, in turn enables higher level jitting in s2ai.
1 parent d77e9cb commit 7048633

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

s2fft/precompute_transforms/custom_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def wigner_subset_to_s2(
8686
return np.fft.ifft(x, axis=-2, norm="forward")
8787

8888

89-
@partial(jit, static_argnums=(3, 4))
89+
# @partial(jit, static_argnums=(3, 4))
9090
def wigner_subset_to_s2_jax(
9191
flmn: jnp.ndarray,
9292
spins: jnp.ndarray,
@@ -209,7 +209,7 @@ def so3_to_wigner_subset(
209209
return s2_to_wigner_subset(x, spins, DW, L, sampling)
210210

211211

212-
@partial(jit, static_argnums=(3, 4, 5))
212+
# @partial(jit, static_argnums=(3, 4, 5))
213213
def so3_to_wigner_subset_jax(
214214
f: jnp.ndarray,
215215
spins: jnp.ndarray,
@@ -338,7 +338,7 @@ def s2_to_wigner_subset(
338338
return x * (2.0 * np.pi) ** 2
339339

340340

341-
@partial(jit, static_argnums=(3, 4))
341+
# @partial(jit, static_argnums=(3, 4))
342342
def s2_to_wigner_subset_jax(
343343
fs: jnp.ndarray,
344344
spins: jnp.ndarray,

0 commit comments

Comments
 (0)