Skip to content

Commit ad072de

Browse files
committed
add switch for different wigner precompute modes depending on N/L
1 parent 5c3226f commit ad072de

File tree

5 files changed

+447
-321
lines changed

5 files changed

+447
-321
lines changed

s2fft/precompute_transforms/alt_construct.py

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,7 @@ def spin_spherical_kernel(
8484
delta[:, L - 1 - spin],
8585
1j ** (-spin - m_value[m_start_ind:]),
8686
)
87-
if sampling.lower() in ["mw", "dh"]:
88-
temp = np.einsum("am,a->am", temp, np.exp(1j * m_value * thetas[0]))
89-
else:
90-
temp_new = np.zeros((L + 1, m_dim), dtype=temp.dtype)
91-
temp_new[:-1] = temp[L - 1 :]
87+
temp = np.einsum("am,a->am", temp, np.exp(1j * m_value * thetas[0]))
9288
temp = np.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")
9389

9490
dl[:, el] = temp[: len(thetas)]
@@ -173,14 +169,7 @@ def spin_spherical_kernel_jax(
173169
delta[:, L - 1 - spin],
174170
1j ** (-spin - m_value[m_start_ind:]),
175171
)
176-
if sampling.lower() in ["mw", "dh"]:
177-
temp = jnp.einsum(
178-
"am,a->am", temp, jnp.exp(1j * m_value * thetas[0])
179-
)
180-
else:
181-
temp_new = jnp.zeros((L + 1, m_dim), dtype=temp.dtype)
182-
temp_new = temp_new.at[:-1].set(temp[L - 1 :])
183-
172+
temp = jnp.einsum("am,a->am", temp, jnp.exp(1j * m_value * thetas[0]))
184173
temp = jnp.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")
185174

186175
dl = dl.at[:, el].set(temp[: len(thetas)])
@@ -253,32 +242,35 @@ def wigner_kernel(
253242
raise ValueError("Sampling in supported list [mw, mwss, dh]")
254243

255244
# Compute Wigner d-functions from their Fourier decomposition.
256-
delta = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
245+
if N <= int(L / np.log(L)):
246+
delta = np.zeros((len(thetas), 2 * L - 1, 2 * L - 1), dtype=np.float64)
247+
else:
248+
delta = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
257249
dl = np.zeros((n_dim, len(thetas), L, 2 * L - 1), dtype=np.float64)
258250

259251
# Range values which need only be defined once.
260252
m_value = np.arange(-L + 1, L)
261253
n = np.arange(n_start_ind - N + 1, N)
262254

255+
# If N <= L/LogL more efficient to manually compute over FFT
263256
for el in range(L):
264-
delta = recursions.risbo.compute_full(delta, np.pi / 2, L, el)
265-
temp = np.einsum(
266-
"am,an,m,n->amn",
267-
delta,
268-
delta[:, L - 1 + n],
269-
1j ** (-m_value),
270-
1j ** (n),
271-
)
272-
if sampling.lower() in ["mw", "dh"]:
257+
if N <= int(L / np.log(L)):
258+
delta = recursions.risbo.compute_full_vect(delta, thetas, L, el)
259+
dl[:, :, el] = np.moveaxis(delta, -1, 0)[L - 1 + n]
260+
else:
261+
delta = recursions.risbo.compute_full(delta, np.pi / 2, L, el)
262+
temp = np.einsum(
263+
"am,an,m,n->amn",
264+
delta,
265+
delta[:, L - 1 + n],
266+
1j ** (-m_value),
267+
1j ** (n),
268+
)
273269
temp = np.einsum(
274270
"amn,a->amn", temp, np.exp(1j * m_value * thetas[0])
275271
)
276-
277-
else:
278-
temp_new = np.zeros((L + 1, 2 * L - 1, len(n)), dtype=temp.dtype)
279-
temp_new[:-1] = temp[L - 1 :]
280-
temp = np.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")
281-
dl[:, :, el] = np.moveaxis(temp[: len(thetas)], -1, 0)
272+
temp = np.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")
273+
dl[:, :, el] = np.moveaxis(temp[: len(thetas)], -1, 0)
282274

283275
if forward:
284276
weights = quadrature.quad_weights_transform(L, sampling)
@@ -351,32 +343,42 @@ def wigner_kernel_jax(
351343
raise ValueError("Sampling in supported list [mw, mwss, dh]")
352344

353345
# Compute Wigner d-functions from their Fourier decomposition.
354-
delta = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
346+
if N <= int(L / np.log(L)):
347+
delta = jnp.zeros(
348+
(len(thetas), 2 * L - 1, 2 * L - 1), dtype=jnp.float64
349+
)
350+
vfunc = jax.vmap(
351+
recursions.risbo_jax.compute_full, in_axes=(0, 0, None, None)
352+
)
353+
else:
354+
delta = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
355355
dl = jnp.zeros((n_dim, len(thetas), L, 2 * L - 1), dtype=jnp.float64)
356356

357357
# Range values which need only be defined once.
358358
m_value = jnp.arange(-L + 1, L)
359359
n = jnp.arange(n_start_ind - N + 1, N)
360360

361+
# If N <= L/LogL more efficient to manually compute over FFT
361362
for el in range(L):
362-
delta = recursions.risbo_jax.compute_full(delta, jnp.pi / 2, L, el)
363-
temp = jnp.einsum(
364-
"am,an,m,n->amn",
365-
delta,
366-
delta[:, L - 1 + n],
367-
1j ** (-m_value),
368-
1j ** (n),
369-
)
370-
if sampling.lower() in ["mw", "dh"]:
363+
if N <= int(L / np.log(L)):
364+
delta = vfunc(delta, thetas, L, el)
365+
dl = dl.at[:, :, el].set(jnp.moveaxis(delta, -1, 0)[L - 1 + n])
366+
else:
367+
delta = recursions.risbo_jax.compute_full(delta, jnp.pi / 2, L, el)
368+
temp = jnp.einsum(
369+
"am,an,m,n->amn",
370+
delta,
371+
delta[:, L - 1 + n],
372+
1j ** (-m_value),
373+
1j ** (n),
374+
)
371375
temp = jnp.einsum(
372376
"amn,a->amn", temp, jnp.exp(1j * m_value * thetas[0])
373377
)
374-
375-
else:
376-
temp_new = jnp.zeros((L + 1, 2 * L - 1, len(n)), dtype=temp.dtype)
377-
temp_new = temp_new.at[:-1].set(temp[L - 1 :])
378-
temp = jnp.fft.irfft(temp[L - 1 :], n=nsamps, axis=0, norm="forward")
379-
dl = dl.at[:, :, el].set(jnp.moveaxis(temp[: len(thetas)], -1, 0))
378+
temp = jnp.fft.irfft(
379+
temp[L - 1 :], n=nsamps, axis=0, norm="forward"
380+
)
381+
dl = dl.at[:, :, el].set(jnp.moveaxis(temp[: len(thetas)], -1, 0))
380382

381383
if forward:
382384
weights = quadrature_jax.quad_weights_transform(L, sampling)

s2fft/recursions/risbo.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
8888

8989
ddj = dd[i, k] / j
9090

91-
dl[k - el + L - 1, i - el + L - 1] += sqrt_jmi * sqrt_jmk * ddj * coshb
91+
dl[k - el + L - 1, i - el + L - 1] += (
92+
sqrt_jmi * sqrt_jmk * ddj * coshb
93+
)
9294
dl[k - el + L - 1, i + 1 - el + L - 1] -= (
9395
sqrt_ip1 * sqrt_jmk * ddj * sinhb
9496
)
@@ -116,3 +118,105 @@ def _arg_checks(dl: np.ndarray, beta: float, L: int, el: int):
116118
assert 0 <= el < L # Should be < not <= once have init routine.
117119
assert dl.shape[0] == dl.shape[1] == 2 * L - 1
118120
assert 0 <= beta <= np.pi
121+
122+
123+
def compute_full_vect(
124+
dl: np.ndarray, beta: np.ndarray, L: int, el: int
125+
) -> np.ndarray:
126+
r"""Compute Wigner-d at argument :math:`\beta` for full plane using
127+
Risbo recursion.
128+
129+
The Wigner-d plane is computed by recursion over :math:`\ell`.
130+
Thus, for :math:`\ell > 0` the plane must be computed already for
131+
:math:`\ell - 1`. At present, for :math:`\ell = 0` the recusion is initialised.
132+
133+
Args:
134+
dl (np.ndarray): Wigner-d plane for :math:`\ell - 1` at :math:`\beta`.
135+
136+
beta (np.ndarray): Arguments :math:`\beta` at which to compute Wigner-d plane.
137+
138+
L (int): Harmonic band-limit.
139+
140+
el (int): Spherical harmonic degree :math:`\ell`.
141+
142+
Returns:
143+
np.ndarray: Plane of Wigner-d for :math:`\ell` and :math:`\beta`, with full plane computed.
144+
"""
145+
if el == 0:
146+
el = 0
147+
dl[:, el + L - 1, el + L - 1] = 1.0
148+
149+
elif el == 1:
150+
cosb = np.cos(beta)
151+
sinb = np.sin(beta)
152+
153+
coshb = np.cos(beta / 2.0)
154+
sinhb = np.sin(beta / 2.0)
155+
sqrt2 = np.sqrt(2.0)
156+
157+
dl[:, -1 + L - 1, -1 + L - 1] = coshb**2
158+
dl[:, -1 + L - 1, 0 + L - 1] = sinb / sqrt2
159+
dl[:, -1 + L - 1, 1 + L - 1] = sinhb**2
160+
161+
dl[:, 0 + L - 1, -1 + L - 1] = -sinb / sqrt2
162+
dl[:, 0 + L - 1, 0 + L - 1] = cosb
163+
dl[:, 0 + L - 1, 1 + L - 1] = sinb / sqrt2
164+
165+
dl[:, 1 + L - 1, -1 + L - 1] = sinhb**2
166+
dl[:, 1 + L - 1, 0 + L - 1] = -sinb / sqrt2
167+
dl[:, 1 + L - 1, 1 + L - 1] = coshb**2
168+
169+
else:
170+
coshb = -np.cos(beta / 2.0)
171+
sinhb = np.sin(beta / 2.0)
172+
173+
# Initialise the plane of the dl-matrix to 0.0 for the recursion
174+
# from l - 1 to l - 1/2.
175+
dd = np.zeros((dl.shape[0], 2 * el + 2, 2 * el + 2))
176+
j = 2 * el - 1
177+
178+
for k in range(0, j):
179+
sqrt_jmk = np.sqrt(j - k)
180+
sqrt_kp1 = np.sqrt(k + 1)
181+
182+
for i in range(0, j):
183+
sqrt_jmi = np.sqrt(j - i)
184+
sqrt_ip1 = np.sqrt(i + 1)
185+
186+
dlj = dl[:, k - (el - 1) + L - 1, i - (el - 1) + L - 1] / j
187+
188+
dd[:, i, k] += sqrt_jmi * sqrt_jmk * dlj * coshb
189+
dd[:, i + 1, k] -= sqrt_ip1 * sqrt_jmk * dlj * sinhb
190+
dd[:, i, k + 1] += sqrt_jmi * sqrt_kp1 * dlj * sinhb
191+
dd[:, i + 1, k + 1] += sqrt_ip1 * sqrt_kp1 * dlj * coshb
192+
193+
# Having constructed the d^(l+1/2) matrix in dd, do the second
194+
# half-step recursion from dd to dl. Start by initilalising
195+
# the plane of the dl-matrix to 0.0.
196+
dl[:, -el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1] = 0.0
197+
j = 2 * el
198+
199+
for k in range(0, j):
200+
sqrt_jmk = np.sqrt(j - k)
201+
sqrt_kp1 = np.sqrt(k + 1)
202+
203+
for i in range(0, j):
204+
sqrt_jmi = np.sqrt(j - i)
205+
sqrt_ip1 = np.sqrt(i + 1)
206+
207+
ddj = dd[:, i, k] / j
208+
209+
dl[:, k - el + L - 1, i - el + L - 1] += (
210+
sqrt_jmi * sqrt_jmk * ddj * coshb
211+
)
212+
dl[:, k - el + L - 1, i + 1 - el + L - 1] -= (
213+
sqrt_ip1 * sqrt_jmk * ddj * sinhb
214+
)
215+
dl[:, k + 1 - el + L - 1, i - el + L - 1] += (
216+
sqrt_jmi * sqrt_kp1 * ddj * sinhb
217+
)
218+
dl[:, k + 1 - el + L - 1, i + 1 - el + L - 1] += (
219+
sqrt_ip1 * sqrt_kp1 * ddj * coshb
220+
)
221+
222+
return dl

s2fft/recursions/risbo_jax.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from functools import partial
44

55

6-
@partial(jit, static_argnums=(1, 2, 3))
6+
@partial(jit, static_argnums=(2, 3))
77
def compute_full(dl: jnp.ndarray, beta: float, L: int, el: int) -> jnp.ndarray:
88
r"""Compute Wigner-d at argument :math:`\beta` for full plane using
99
Risbo recursion (JAX implementation)
@@ -13,7 +13,7 @@ def compute_full(dl: jnp.ndarray, beta: float, L: int, el: int) -> jnp.ndarray:
1313
:math:`\ell - 1`. At present, for :math:`\ell = 0` the recusion is initialised.
1414
1515
Args:
16-
dl (np.ndarray): Wigner-d plane for :math:`\ell - 1` at :math:`\beta`.
16+
dl (jnp.ndarray): Wigner-d plane for :math:`\ell - 1` at :math:`\beta`.
1717
1818
beta (float): Argument :math:`\beta` at which to compute Wigner-d plane.
1919
@@ -22,7 +22,7 @@ def compute_full(dl: jnp.ndarray, beta: float, L: int, el: int) -> jnp.ndarray:
2222
el (int): Spherical harmonic degree :math:`\ell`.
2323
2424
Returns:
25-
np.ndarray: Plane of Wigner-d for `el` and `beta`, with full plane computed.
25+
jnp.ndarray: Plane of Wigner-d for `el` and `beta`, with full plane computed.
2626
"""
2727

2828
if el == 0:
@@ -66,21 +66,29 @@ def compute_full(dl: jnp.ndarray, beta: float, L: int, el: int) -> jnp.ndarray:
6666
dlj = dl[k - (el - 1) + L - 1][:, i - (el - 1) + L - 1]
6767

6868
dd = dd.at[:j, :j].add(
69-
jnp.einsum("i,k->ki", sqrt_jmi, sqrt_jmk, optimize=True) * dlj * coshb
69+
jnp.einsum("i,k->ki", sqrt_jmi, sqrt_jmk, optimize=True)
70+
* dlj
71+
* coshb
7072
)
7173
dd = dd.at[:j, 1 : j + 1].add(
72-
jnp.einsum("i,k->ki", -sqrt_ip1, sqrt_jmk, optimize=True) * dlj * sinhb
74+
jnp.einsum("i,k->ki", -sqrt_ip1, sqrt_jmk, optimize=True)
75+
* dlj
76+
* sinhb
7377
)
7478
dd = dd.at[1 : j + 1, :j].add(
75-
jnp.einsum("i,k->ki", sqrt_jmi, sqrt_kp1, optimize=True) * dlj * sinhb
79+
jnp.einsum("i,k->ki", sqrt_jmi, sqrt_kp1, optimize=True)
80+
* dlj
81+
* sinhb
7682
)
7783
dd = dd.at[1 : j + 1, 1 : j + 1].add(
78-
jnp.einsum("i,k->ki", sqrt_ip1, sqrt_kp1, optimize=True) * dlj * coshb
84+
jnp.einsum("i,k->ki", sqrt_ip1, sqrt_kp1, optimize=True)
85+
* dlj
86+
* coshb
7987
)
8088

81-
dl = dl.at[-el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1].multiply(
82-
0.0
83-
)
89+
dl = dl.at[
90+
-el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1
91+
].multiply(0.0)
8492

8593
j = 2 * el
8694
i = jnp.arange(j)

0 commit comments

Comments
 (0)