Skip to content

Commit 4665eec

Browse files
committed
address JDM review, switch from FFT to DFT
1 parent 0a6828d commit 4665eec

File tree

2 files changed

+55
-44
lines changed

2 files changed

+55
-44
lines changed

s2fft/precompute_transforms/custom_ops.py

Lines changed: 54 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@ def wigner_subset_to_s2(
1717
Transforms an arbitrary subset of Wigner coefficients onto a subset of spin signals
1818
on the sphere.
1919
20-
This function takes a collection of spin spherical harmonic coefficients each with
21-
a different (though not necessarily unique) spin and maps them back to their
22-
corresponding pixel-space representations. Following this operation one may
23-
liftn this collection of spin signals to a signal on SO(3) by exploiting the
24-
correct Mackey functions.
20+
This function takes a collection of spin spherical harmonic coefficients each with
21+
a different (though not necessarily unique) spin and maps them back to their
22+
corresponding pixel-space representations.
2523
2624
Args:
2725
flmn (np.ndarray): Collection of spin spherical harmonic coefficients
@@ -33,15 +31,25 @@ def wigner_subset_to_s2(
3331
sampling (str, optional): Sampling scheme. Supported sampling schemes include
3432
{"mw", "mwss"}. Defaults to "mw".
3533
34+
Raises:
35+
ValueError: If sampling scheme is not recognised.
36+
ValueError: If the number of spins does not match the number of Wigner coefficients.
37+
3638
Returns:
37-
np.ndarray: A collection of spin signals with shape :math:`[batch, n_s, n_\theta, n_\phi, channels]`.
39+
np.ndarray: A collection of spin signals with
40+
shape :math:`[batch, n_s, n_\theta, n_\phi, channels]`.
3841
3942
"""
4043
if sampling.lower() not in ["mw", "mwss"]:
4144
raise ValueError(
4245
f"Fourier-Wigner algorithm does not support {sampling} sampling."
4346
)
4447

48+
if flmn.shape[1] != spins.shape[0]:
49+
raise ValueError(
50+
f"Number of spins specified {spins.shape[0]} does not match the number of Wigner coefficients {flmn.shape[1]}"
51+
)
52+
4553
# EXTRACT VARIOUS PRECOMPUTES
4654
Delta, _ = DW
4755

@@ -93,9 +101,7 @@ def wigner_subset_to_s2_jax(
93101
94102
This function takes a collection of spin spherical harmonic coefficients each with
95103
a different (though not necessarily unique) spin and maps them back to their
96-
corresponding pixel-space representations. Following this operation one may
97-
liftn this collection of spin signals to a signal on SO(3) by exploiting the
98-
correct Mackey functions.
104+
corresponding pixel-space representations.
99105
100106
Args:
101107
flmn (jnp.ndarray): Collection of spin spherical harmonic coefficients
@@ -107,6 +113,10 @@ def wigner_subset_to_s2_jax(
107113
sampling (str, optional): Sampling scheme. Supported sampling schemes include
108114
{"mw", "mwss"}. Defaults to "mw".
109115
116+
Raises:
117+
ValueError: If sampling scheme is not recognised.
118+
ValueError: If the number of spins does not match the number of Wigner coefficients.
119+
110120
Returns:
111121
jnp.ndarray: A collection of spin signals with shape :math:`[batch, n_s, n_\theta, n_\phi, channels]`.
112122
@@ -116,6 +126,11 @@ def wigner_subset_to_s2_jax(
116126
f"Fourier-Wigner algorithm does not support {sampling} sampling."
117127
)
118128

129+
if flmn.shape[1] != spins.shape[0]:
130+
raise ValueError(
131+
f"Number of spins specified {spins.shape[0]} does not match the number of Wigner coefficients {flmn.shape[1]}"
132+
)
133+
119134
# EXTRACT VARIOUS PRECOMPUTES
120135
Delta, _ = DW
121136

@@ -167,10 +182,10 @@ def so3_to_wigner_subset(
167182
Transforms a signal on the rotation group to an arbitrary subset of its Wigner
168183
coefficients.
169184
170-
This function takes a signal on the rotation group SO(3) and computes a subset of
171-
spin spherical harmonic coefficients corresponding to slices across the requested
172-
spin numbers. These spin numbers can be arbitrarily chosen such that their absolute
173-
value is less than or equal to the azimuthal band-limit :math:`N\leq L`.
185+
This function takes a signal on the rotation group SO(3) and computes a subset of
186+
spin spherical harmonic coefficients corresponding to slices across the requested
187+
spin numbers. These spin numbers can be arbitrarily chosen such that their absolute
188+
value is less than or equal to the azimuthal band-limit :math:`N\leq L`.
174189
175190
Args:
176191
f (np.ndarray): Signal on the rotation group with shape :math:`[batch, n_\gamma, n_\theta,n_\phi, channels]`.
@@ -186,12 +201,11 @@ def so3_to_wigner_subset(
186201
np.ndarray: Collection of spin spherical harmonic coefficients with shape :math:`[batch, n_s, L, 2L-1, channels]`.
187202
188203
"""
189-
# COMPUTE FFT OVER GAMMA
190-
x = np.fft.fft(f, axis=-4, norm="forward")
191-
x = np.fft.fftshift(x, axes=-4)
192-
193-
# EXTRACT REQUESTED SPIN COMPONENTS
194-
x = x[:, N - 1 - spins]
204+
# COMPUTE DFT OVER GAMMA SUBSET
205+
e = np.exp(
206+
-2j * np.pi * np.einsum("g,n->gn", np.arange(f.shape[1]) / f.shape[1], -spins)
207+
)
208+
x = np.einsum("bgtpc,gn->bntpc", f, e) / f.shape[1]
195209

196210
return s2_to_wigner_subset(x, spins, DW, L, sampling)
197211

@@ -209,10 +223,10 @@ def so3_to_wigner_subset_jax(
209223
Transforms a signal on the rotation group to an arbitrary subset of its Wigner
210224
coefficients (JAX).
211225
212-
This function takes a signal on the rotation group SO(3) and computes a subset of
213-
spin spherical harmonic coefficients corresponding to slices across the requested
214-
spin numbers. These spin numbers can be arbitrarily chosen such that their absolute
215-
value is less than or equal to the azimuthal band-limit :math:`N\leq L`.
226+
This function takes a signal on the rotation group SO(3) and computes a subset of
227+
spin spherical harmonic coefficients corresponding to slices across the requested
228+
spin numbers. These spin numbers can be arbitrarily chosen such that their absolute
229+
value is less than or equal to the azimuthal band-limit :math:`N\leq L`.
216230
217231
Args:
218232
f (jnp.ndarray): Signal on the rotation group with shape :math:`[batch, n_\gamma, n_\theta,n_\phi, channels]`.
@@ -229,12 +243,13 @@ def so3_to_wigner_subset_jax(
229243
with shape :math:`[batch, n_s, L, 2L-1, channels]`.
230244
231245
"""
232-
# COMPUTE FFT OVER GAMMA
233-
x = jnp.fft.fft(f, axis=-4, norm="forward")
234-
x = jnp.fft.fftshift(x, axes=-4)
235-
236-
# EXTRACT REQUESTED SPIN COMPONENTS
237-
x = x[:, N - 1 - spins]
246+
# COMPUTE DFT OVER GAMMA SUBSET
247+
e = jnp.exp(
248+
-2j
249+
* jnp.pi
250+
* jnp.einsum("g,n->gn", jnp.arange(f.shape[1]) / f.shape[1], -spins)
251+
)
252+
x = jnp.einsum("bgtpc,gn->bntpc", f, e) / f.shape[1]
238253

239254
return s2_to_wigner_subset_jax(x, spins, DW, L, sampling)
240255

@@ -250,11 +265,11 @@ def s2_to_wigner_subset(
250265
Transforms from a collection of arbitrary spin signals on the sphere to the
251266
corresponding collection of their harmonic coefficients.
252267
253-
This function takes a multimodal collection of spin spherical harmonic signals
254-
on the sphere and transforms them into their spin spherical harmonic coefficients.
255-
These cofficients may then be combined into a subset of Wigner coefficients for
256-
downstream analysis. In this way one may combine input features across a variety
257-
of spins into a unified representation.
268+
This function takes a multimodal collection of spin spherical harmonic signals
269+
on the sphere and transforms them into their spin spherical harmonic coefficients.
270+
These coefficients may then be combined into a subset of Wigner coefficients for
271+
downstream analysis. In this way one may combine input features across a variety
272+
of spins into a unified representation.
258273
259274
Args:
260275
fs (np.ndarray): Collection of spin signal maps on the sphere with shape :math:`[batch, n_s, n_\theta,n_\phi, channels]`.
@@ -336,11 +351,11 @@ def s2_to_wigner_subset_jax(
336351
Transforms from a collection of arbitrary spin signals on the sphere to the
337352
corresponding collection of their harmonic coefficients (JAX).
338353
339-
This function takes a multimodal collection of spin spherical harmonic signals
340-
on the sphere and transforms them into their spin spherical harmonic coefficients.
341-
These cofficients may then be combined into a subset of Wigner coefficients for
342-
downstream analysis. In this way one may combine input features across a variety
343-
of spins into a unified representation.
354+
This function takes a multimodal collection of spin spherical harmonic signals
355+
on the sphere and transforms them into their spin spherical harmonic coefficients.
356+
These coefficients may then be combined into a subset of Wigner coefficients for
357+
downstream analysis. In this way one may combine input features across a variety
358+
of spins into a unified representation.
344359
345360
Args:
346361
fs (jnp.ndarray): Collection of spin signal maps on the sphere with shape :math:`[batch, n_s, n_\theta,n_\phi, channels]`.

tests/test_lifting_transforms.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,7 @@ def test_custom_forward_from_so3(
8181
spins = -np.arange(-N + 1, N)
8282

8383
# FUNCTION SWITCH
84-
func = (
85-
ops.so3_to_wigner_subset_jax
86-
if method == "jax"
87-
else ops.so3_to_wigner_subset_jax
88-
)
84+
func = ops.so3_to_wigner_subset_jax if method == "jax" else ops.so3_to_wigner_subset
8985

9086
# CREATE CORRECT SHAPE (BATCH: 1, CHANNELS: 1)
9187
f = f.reshape((1,) + f.shape + (1,))

0 commit comments

Comments
 (0)