Skip to content

Commit 6f64ebb

Browse files
committed
address JDM minor comments
1 parent 8e33569 commit 6f64ebb

File tree

1 file changed

+70
-24
lines changed

1 file changed

+70
-24
lines changed

s2fft/precompute_transforms/fourier_wigner.py

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
def inverse_transform(
99
flmn: np.ndarray,
10-
delta: np.ndarray,
10+
DW: np.ndarray,
1111
L: int,
1212
N: int,
1313
reality: bool = False,
@@ -18,7 +18,7 @@ def inverse_transform(
1818
1919
Args:
2020
flmn (np.ndarray): Wigner coefficients.
21-
delta (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
21+
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
2222
Wigner d-functions and the corresponding upsampled quadrature weights.
2323
L (int): Harmonic band-limit.
2424
N (int): Azimuthal band-limit.
@@ -32,6 +32,14 @@ def inverse_transform(
3232
np.ndarray: Pixel-space function sampled on the rotation group.
3333
3434
"""
35+
if sampling.lower() not in ["mw", "mwss"]:
36+
raise ValueError(
37+
f"Fourier-Wigner algorithm does not support {sampling} sampling."
38+
)
39+
40+
# EXTRACT VARIOUS PRECOMPUTES
41+
Delta, _ = DW
42+
3543
# INDEX VALUES
3644
n_start_ind = N - 1 if reality else 0
3745
n_dim = N if reality else 2 * N - 1
@@ -44,13 +52,13 @@ def inverse_transform(
4452
m = np.arange(-L + 1 - m_offset, L)
4553
n = np.arange(n_start_ind - N + 1, N)
4654

47-
# Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2)
55+
# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
4856
x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype)
4957
x[m_offset:, m_offset:] = np.einsum(
5058
"nlm,lam,lan,l->amn",
5159
flmn[n_start_ind:],
52-
delta[0],
53-
delta[0][:, :, L - 1 + n],
60+
Delta,
61+
Delta[:, :, L - 1 + n],
5462
(2 * np.arange(L) + 1) / (8 * np.pi**2),
5563
)
5664

@@ -72,7 +80,7 @@ def inverse_transform(
7280
@partial(jit, static_argnums=(2, 3, 4, 5))
7381
def inverse_transform_jax(
7482
flmn: jnp.ndarray,
75-
delta: jnp.ndarray,
83+
DW: jnp.ndarray,
7684
L: int,
7785
N: int,
7886
reality: bool = False,
@@ -83,7 +91,7 @@ def inverse_transform_jax(
8391
8492
Args:
8593
flmn (jnp.ndarray): Wigner coefficients.
86-
delta (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
94+
DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
8795
Wigner d-functions and the corresponding upsampled quadrature weights.
8896
L (int): Harmonic band-limit.
8997
N (int): Azimuthal band-limit.
@@ -97,6 +105,14 @@ def inverse_transform_jax(
97105
jnp.ndarray: Pixel-space function sampled on the rotation group.
98106
99107
"""
108+
if sampling.lower() not in ["mw", "mwss"]:
109+
raise ValueError(
110+
f"Fourier-Wigner algorithm does not support {sampling} sampling."
111+
)
112+
113+
# EXTRACT VARIOUS PRECOMPUTES
114+
Delta, _ = DW
115+
100116
# INDEX VALUES
101117
n_start_ind = N - 1 if reality else 0
102118
n_dim = N if reality else 2 * N - 1
@@ -109,17 +125,15 @@ def inverse_transform_jax(
109125
m = jnp.arange(-L + 1 - m_offset, L)
110126
n = jnp.arange(n_start_ind - N + 1, N)
111127

112-
# Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2)
113-
x = jnp.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype)
128+
# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
129+
x = jnp.zeros((xnlm_size, xnlm_size, n_dim), dtype=jnp.complex128)
130+
flmn = jnp.einsum("nlm,l->nlm", flmn, (2 * jnp.arange(L) + 1) / (8 * jnp.pi**2))
114131
x = x.at[m_offset:, m_offset:].set(
115132
jnp.einsum(
116-
"nlm,lam,lan,l->amn",
117-
flmn[n_start_ind:],
118-
delta[0],
119-
delta[0][:, :, L - 1 + n],
120-
(2 * jnp.arange(L) + 1) / (8 * jnp.pi**2),
133+
"nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n]
121134
)
122135
)
136+
123137
# APPLY SIGN FUNCTION AND PHASE SHIFT
124138
x = jnp.einsum("amn,m,n,a->nam", x, 1j ** (-m), 1j ** (n), jnp.exp(1j * m * theta0))
125139

@@ -136,14 +150,19 @@ def inverse_transform_jax(
136150

137151

138152
def forward_transform(
139-
f: np.ndarray, delta: np.ndarray, L: int, N: int, reality: bool, sampling: str
153+
f: np.ndarray,
154+
DW: np.ndarray,
155+
L: int,
156+
N: int,
157+
reality: bool = False,
158+
sampling: str = "mw",
140159
) -> np.ndarray:
141160
"""
142161
Computes the forward Wigner transform using the Fourier decomposition algorithm.
143162
144163
Args:
145164
f (np.ndarray): Function sampled on the rotation group.
146-
delta (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
165+
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
147166
Wigner d-functions and the corresponding upsampled quadrature weights.
148167
L (int): Harmonic band-limit.
149168
N (int): Azimuthal band-limit.
@@ -157,6 +176,14 @@ def forward_transform(
157176
np.ndarray: Wigner coefficients of function f.
158177
159178
"""
179+
if sampling.lower() not in ["mw", "mwss"]:
180+
raise ValueError(
181+
f"Fourier-Wigner algorithm does not support {sampling} sampling."
182+
)
183+
184+
# EXTRACT VARIOUS PRECOMPUTES
185+
Delta, Quads = DW
186+
160187
# INDEX VALUES
161188
n_start_ind = N - 1 if reality else 0
162189
m_offset = 1 if sampling.lower() == "mwss" else 0
@@ -193,14 +220,17 @@ def forward_transform(
193220
x = np.fft.ifft(x, axis=1, norm="forward")
194221

195222
# PERFORM QUADRATURE CONVOLUTION AS FFT REWEIGHTING IN REAL SPACE
196-
x = np.einsum("nbm,b->nbm", x, delta[1])
223+
# NB: Our convention here is conjugate to that of SSHT, in which
224+
# the weights are conjugate but applied flipped and therefore are
225+
# equivalent. To avoid flipping here he simply conjugate the weights.
226+
x = np.einsum("nbm,b->nbm", x, Quads)
197227

198228
# COMPUTE GMM BY FFT
199229
x = np.fft.fft(x, axis=1, norm="forward")
200230
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
201231

202-
# Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
203-
x = np.einsum("nam,lam,lan->nlm", x, delta[0], delta[0][:, :, L - 1 + n])
232+
# Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
233+
x = np.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
204234
x = np.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n))
205235

206236
# SYMMETRY REFLECT FOR N < 0
@@ -218,14 +248,19 @@ def forward_transform(
218248

219249
@partial(jit, static_argnums=(2, 3, 4, 5))
220250
def forward_transform_jax(
221-
f: jnp.ndarray, delta: jnp.ndarray, L: int, N: int, reality: bool, sampling: str
251+
f: jnp.ndarray,
252+
DW: jnp.ndarray,
253+
L: int,
254+
N: int,
255+
reality: bool = False,
256+
sampling: str = "mw",
222257
) -> jnp.ndarray:
223258
"""
224259
Computes the forward Wigner transform using the Fourier decomposition algorithm (JAX).
225260
226261
Args:
227262
f (jnp.ndarray): Function sampled on the rotation group.
228-
delta (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
263+
DW (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
229264
Wigner d-functions and the corresponding upsampled quadrature weights.
230265
L (int): Harmonic band-limit.
231266
N (int): Azimuthal band-limit.
@@ -239,6 +274,14 @@ def forward_transform_jax(
239274
jnp.ndarray: Wigner coefficients of function f.
240275
241276
"""
277+
if sampling.lower() not in ["mw", "mwss"]:
278+
raise ValueError(
279+
f"Fourier-Wigner algorithm does not support {sampling} sampling."
280+
)
281+
282+
# EXTRACT VARIOUS PRECOMPUTES
283+
Delta, Quads = DW
284+
242285
# INDEX VALUES
243286
n_start_ind = N - 1 if reality else 0
244287
m_offset = 1 if sampling.lower() == "mwss" else 0
@@ -275,14 +318,17 @@ def forward_transform_jax(
275318
x = jnp.fft.ifft(x, axis=1, norm="forward")
276319

277320
# PERFORM QUADRATURE CONVOLUTION AS FFT REWEIGHTING IN REAL SPACE
278-
x = jnp.einsum("nbm,b->nbm", x, delta[1])
321+
# NB: Our convention here is conjugate to that of SSHT, in which
322+
# the weights are conjugate but applied flipped and therefore are
323+
# equivalent. To avoid flipping here he simply conjugate the weights.
324+
x = jnp.einsum("nbm,b->nbm", x, Quads)
279325

280326
# COMPUTE GMM BY FFT
281327
x = jnp.fft.fft(x, axis=1, norm="forward")
282328
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]
283329

284-
# Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
285-
x = jnp.einsum("nam,lam,lan->nlm", x, delta[0], delta[0][:, :, L - 1 + n])
330+
# Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
331+
x = jnp.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
286332
x = jnp.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n))
287333

288334
# SYMMETRY REFLECT FOR N < 0

0 commit comments

Comments
 (0)