7
7
8
8
def inverse_transform (
9
9
flmn : np .ndarray ,
10
- delta : np .ndarray ,
10
+ DW : np .ndarray ,
11
11
L : int ,
12
12
N : int ,
13
13
reality : bool = False ,
@@ -18,7 +18,7 @@ def inverse_transform(
18
18
19
19
Args:
20
20
flmn (np.ndarray): Wigner coefficients.
21
- delta (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
21
+ DW (Tuple[np.ndarray, np.ndarray], optional ): Fourier coefficients of the reduced
22
22
Wigner d-functions and the corresponding upsampled quadrature weights.
23
23
L (int): Harmonic band-limit.
24
24
N (int): Azimuthal band-limit.
@@ -32,6 +32,14 @@ def inverse_transform(
32
32
np.ndarray: Pixel-space function sampled on the rotation group.
33
33
34
34
"""
35
+ if sampling .lower () not in ["mw" , "mwss" ]:
36
+ raise ValueError (
37
+ f"Fourier-Wigner algorithm does not support { sampling } sampling."
38
+ )
39
+
40
+ # EXTRACT VARIOUS PRECOMPUTES
41
+ Delta , _ = DW
42
+
35
43
# INDEX VALUES
36
44
n_start_ind = N - 1 if reality else 0
37
45
n_dim = N if reality else 2 * N - 1
@@ -44,13 +52,13 @@ def inverse_transform(
44
52
m = np .arange (- L + 1 - m_offset , L )
45
53
n = np .arange (n_start_ind - N + 1 , N )
46
54
47
- # Calculate fmna = i^(n-m)\sum_L delta ^l_am delta ^l_an f^l_mn(2l+1)/(8pi^2)
55
+ # Calculate fmna = i^(n-m)\sum_L Delta ^l_am Delta ^l_an f^l_mn(2l+1)/(8pi^2)
48
56
x = np .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
49
57
x [m_offset :, m_offset :] = np .einsum (
50
58
"nlm,lam,lan,l->amn" ,
51
59
flmn [n_start_ind :],
52
- delta [ 0 ] ,
53
- delta [ 0 ] [:, :, L - 1 + n ],
60
+ Delta ,
61
+ Delta [:, :, L - 1 + n ],
54
62
(2 * np .arange (L ) + 1 ) / (8 * np .pi ** 2 ),
55
63
)
56
64
@@ -72,7 +80,7 @@ def inverse_transform(
72
80
@partial (jit , static_argnums = (2 , 3 , 4 , 5 ))
73
81
def inverse_transform_jax (
74
82
flmn : jnp .ndarray ,
75
- delta : jnp .ndarray ,
83
+ DW : jnp .ndarray ,
76
84
L : int ,
77
85
N : int ,
78
86
reality : bool = False ,
@@ -83,7 +91,7 @@ def inverse_transform_jax(
83
91
84
92
Args:
85
93
flmn (jnp.ndarray): Wigner coefficients.
86
- delta (Tuple[jnp .ndarray, jnp .ndarray]): Fourier coefficients of the reduced
94
+ DW (Tuple[np .ndarray, np .ndarray]): Fourier coefficients of the reduced
87
95
Wigner d-functions and the corresponding upsampled quadrature weights.
88
96
L (int): Harmonic band-limit.
89
97
N (int): Azimuthal band-limit.
@@ -97,6 +105,14 @@ def inverse_transform_jax(
97
105
jnp.ndarray: Pixel-space function sampled on the rotation group.
98
106
99
107
"""
108
+ if sampling .lower () not in ["mw" , "mwss" ]:
109
+ raise ValueError (
110
+ f"Fourier-Wigner algorithm does not support { sampling } sampling."
111
+ )
112
+
113
+ # EXTRACT VARIOUS PRECOMPUTES
114
+ Delta , _ = DW
115
+
100
116
# INDEX VALUES
101
117
n_start_ind = N - 1 if reality else 0
102
118
n_dim = N if reality else 2 * N - 1
@@ -109,17 +125,15 @@ def inverse_transform_jax(
109
125
m = jnp .arange (- L + 1 - m_offset , L )
110
126
n = jnp .arange (n_start_ind - N + 1 , N )
111
127
112
- # Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2)
113
- x = jnp .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
128
+ # Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
129
+ x = jnp .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = jnp .complex128 )
130
+ flmn = jnp .einsum ("nlm,l->nlm" , flmn , (2 * jnp .arange (L ) + 1 ) / (8 * jnp .pi ** 2 ))
114
131
x = x .at [m_offset :, m_offset :].set (
115
132
jnp .einsum (
116
- "nlm,lam,lan,l->amn" ,
117
- flmn [n_start_ind :],
118
- delta [0 ],
119
- delta [0 ][:, :, L - 1 + n ],
120
- (2 * jnp .arange (L ) + 1 ) / (8 * jnp .pi ** 2 ),
133
+ "nlm,lam,lan->amn" , flmn [n_start_ind :], Delta , Delta [:, :, L - 1 + n ]
121
134
)
122
135
)
136
+
123
137
# APPLY SIGN FUNCTION AND PHASE SHIFT
124
138
x = jnp .einsum ("amn,m,n,a->nam" , x , 1j ** (- m ), 1j ** (n ), jnp .exp (1j * m * theta0 ))
125
139
@@ -136,14 +150,19 @@ def inverse_transform_jax(
136
150
137
151
138
152
def forward_transform (
139
- f : np .ndarray , delta : np .ndarray , L : int , N : int , reality : bool , sampling : str
153
+ f : np .ndarray ,
154
+ DW : np .ndarray ,
155
+ L : int ,
156
+ N : int ,
157
+ reality : bool = False ,
158
+ sampling : str = "mw" ,
140
159
) -> np .ndarray :
141
160
"""
142
161
Computes the forward Wigner transform using the Fourier decomposition algorithm.
143
162
144
163
Args:
145
164
f (np.ndarray): Function sampled on the rotation group.
146
- delta (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
165
+ DW (Tuple[np.ndarray, np.ndarray], optional ): Fourier coefficients of the reduced
147
166
Wigner d-functions and the corresponding upsampled quadrature weights.
148
167
L (int): Harmonic band-limit.
149
168
N (int): Azimuthal band-limit.
@@ -157,6 +176,14 @@ def forward_transform(
157
176
np.ndarray: Wigner coefficients of function f.
158
177
159
178
"""
179
+ if sampling .lower () not in ["mw" , "mwss" ]:
180
+ raise ValueError (
181
+ f"Fourier-Wigner algorithm does not support { sampling } sampling."
182
+ )
183
+
184
+ # EXTRACT VARIOUS PRECOMPUTES
185
+ Delta , Quads = DW
186
+
160
187
# INDEX VALUES
161
188
n_start_ind = N - 1 if reality else 0
162
189
m_offset = 1 if sampling .lower () == "mwss" else 0
@@ -193,14 +220,17 @@ def forward_transform(
193
220
x = np .fft .ifft (x , axis = 1 , norm = "forward" )
194
221
195
222
# PERFORM QUADRATURE CONVOLUTION AS FFT REWEIGHTING IN REAL SPACE
196
- x = np .einsum ("nbm,b->nbm" , x , delta [1 ])
223
+ # NB: Our convention here is conjugate to that of SSHT, in which
224
+ # the weights are conjugate but applied flipped and therefore are
225
+ # equivalent. To avoid flipping here he simply conjugate the weights.
226
+ x = np .einsum ("nbm,b->nbm" , x , Quads )
197
227
198
228
# COMPUTE GMM BY FFT
199
229
x = np .fft .fft (x , axis = 1 , norm = "forward" )
200
230
x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
201
231
202
- # Calculate flmn = i^(n-m)\sum_t delta ^l_tm delta ^l_tn G_mnt
203
- x = np .einsum ("nam,lam,lan->nlm" , x , delta [ 0 ], delta [ 0 ] [:, :, L - 1 + n ])
232
+ # Calculate flmn = i^(n-m)\sum_t Delta ^l_tm Delta ^l_tn G_mnt
233
+ x = np .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
204
234
x = np .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
205
235
206
236
# SYMMETRY REFLECT FOR N < 0
@@ -218,14 +248,19 @@ def forward_transform(
218
248
219
249
@partial (jit , static_argnums = (2 , 3 , 4 , 5 ))
220
250
def forward_transform_jax (
221
- f : jnp .ndarray , delta : jnp .ndarray , L : int , N : int , reality : bool , sampling : str
251
+ f : jnp .ndarray ,
252
+ DW : jnp .ndarray ,
253
+ L : int ,
254
+ N : int ,
255
+ reality : bool = False ,
256
+ sampling : str = "mw" ,
222
257
) -> jnp .ndarray :
223
258
"""
224
259
Computes the forward Wigner transform using the Fourier decomposition algorithm (JAX).
225
260
226
261
Args:
227
262
f (jnp.ndarray): Function sampled on the rotation group.
228
- delta (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
263
+ DW (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
229
264
Wigner d-functions and the corresponding upsampled quadrature weights.
230
265
L (int): Harmonic band-limit.
231
266
N (int): Azimuthal band-limit.
@@ -239,6 +274,14 @@ def forward_transform_jax(
239
274
jnp.ndarray: Wigner coefficients of function f.
240
275
241
276
"""
277
+ if sampling .lower () not in ["mw" , "mwss" ]:
278
+ raise ValueError (
279
+ f"Fourier-Wigner algorithm does not support { sampling } sampling."
280
+ )
281
+
282
+ # EXTRACT VARIOUS PRECOMPUTES
283
+ Delta , Quads = DW
284
+
242
285
# INDEX VALUES
243
286
n_start_ind = N - 1 if reality else 0
244
287
m_offset = 1 if sampling .lower () == "mwss" else 0
@@ -275,14 +318,17 @@ def forward_transform_jax(
275
318
x = jnp .fft .ifft (x , axis = 1 , norm = "forward" )
276
319
277
320
# PERFORM QUADRATURE CONVOLUTION AS FFT REWEIGHTING IN REAL SPACE
278
- x = jnp .einsum ("nbm,b->nbm" , x , delta [1 ])
321
+ # NB: Our convention here is conjugate to that of SSHT, in which
322
+ # the weights are conjugate but applied flipped and therefore are
323
+ # equivalent. To avoid flipping here he simply conjugate the weights.
324
+ x = jnp .einsum ("nbm,b->nbm" , x , Quads )
279
325
280
326
# COMPUTE GMM BY FFT
281
327
x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
282
328
x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
283
329
284
- # Calculate flmn = i^(n-m)\sum_t delta ^l_tm delta ^l_tn G_mnt
285
- x = jnp .einsum ("nam,lam,lan->nlm" , x , delta [ 0 ], delta [ 0 ] [:, :, L - 1 + n ])
330
+ # Calculate flmn = i^(n-m)\sum_t Delta ^l_tm Delta ^l_tn G_mnt
331
+ x = jnp .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
286
332
x = jnp .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
287
333
288
334
# SYMMETRY REFLECT FOR N < 0
0 commit comments