@@ -74,9 +74,7 @@ def spin_spherical_kernel(
74
74
if recursion .lower () == "auto" :
75
75
# This mode automatically determines which recursion is best suited for the
76
76
# current parameter configuration.
77
- recursion = (
78
- "risbo" if abs (spin ) >= PM_MAX_STABLE_SPIN else "price-mcewen"
79
- )
77
+ recursion = "risbo" if abs (spin ) >= PM_MAX_STABLE_SPIN else "price-mcewen"
80
78
81
79
dl = []
82
80
m_start_ind = L - 1 if reality else 0
@@ -111,13 +109,9 @@ def spin_spherical_kernel(
111
109
# - The complexity of this approach is O(L^4).
112
110
# - This approach is stable for arbitrary abs(spins) <= L.
113
111
if sampling .lower () in ["healpix" , "gl" ]:
114
- delta = np .zeros (
115
- (len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = np .float64
116
- )
112
+ delta = np .zeros ((len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
117
113
for el in range (L ):
118
- delta = recursions .risbo .compute_full_vectorised (
119
- delta , thetas , L , el
120
- )
114
+ delta = recursions .risbo .compute_full_vectorised (delta , thetas , L , el )
121
115
dl [:, el ] = delta [:, m_start_ind :, L - 1 - spin ]
122
116
123
117
# MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated
@@ -144,19 +138,13 @@ def spin_spherical_kernel(
144
138
delta [:, L - 1 - spin ],
145
139
1j ** (- spin - m_value [m_start_ind :]),
146
140
)
147
- temp = np .einsum (
148
- "am,a->am" , temp , np .exp (1j * m_value * thetas [0 ])
149
- )
150
- temp = np .fft .irfft (
151
- temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward"
152
- )
141
+ temp = np .einsum ("am,a->am" , temp , np .exp (1j * m_value * thetas [0 ]))
142
+ temp = np .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
153
143
154
144
dl [:, el ] = temp [: len (thetas )]
155
145
156
146
# Fold in normalisation to avoid recomputation at run-time.
157
- dl = np .einsum (
158
- "tlm,l->tlm" , dl , np .sqrt ((2 * np .arange (L ) + 1 ) / (4 * np .pi ))
159
- )
147
+ dl = np .einsum ("tlm,l->tlm" , dl , np .sqrt ((2 * np .arange (L ) + 1 ) / (4 * np .pi )))
160
148
161
149
else :
162
150
raise ValueError (f"Recursion method { recursion } not recognised." )
@@ -234,9 +222,7 @@ def spin_spherical_kernel_jax(
234
222
if recursion .lower () == "auto" :
235
223
# This mode automatically determines which recursion is best suited for the
236
224
# current parameter configuration.
237
- recursion = (
238
- "risbo" if abs (spin ) >= PM_MAX_STABLE_SPIN else "price-mcewen"
239
- )
225
+ recursion = "risbo" if abs (spin ) >= PM_MAX_STABLE_SPIN else "price-mcewen"
240
226
241
227
dl = []
242
228
m_start_ind = L - 1 if reality else 0
@@ -283,9 +269,7 @@ def spin_spherical_kernel_jax(
283
269
# - The complexity of this approach is O(L^4).
284
270
# - This approach is stable for arbitrary abs(spins) <= L.
285
271
if sampling .lower () in ["healpix" , "gl" ]:
286
- delta = jnp .zeros (
287
- (len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64
288
- )
272
+ delta = jnp .zeros ((len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
289
273
vfunc = jax .vmap (
290
274
recursions .risbo_jax .compute_full , in_axes = (0 , 0 , None , None )
291
275
)
@@ -309,32 +293,24 @@ def spin_spherical_kernel_jax(
309
293
310
294
# Calculate the Fourier coefficients of the Wigner d-functions, delta(pi/2).
311
295
for el in range (L ):
312
- delta = recursions .risbo_jax .compute_full (
313
- delta , jnp .pi / 2 , L , el
314
- )
296
+ delta = recursions .risbo_jax .compute_full (delta , jnp .pi / 2 , L , el )
315
297
m_value = jnp .arange (- L + 1 , L )
316
298
temp = jnp .einsum (
317
299
"am,a,m->am" ,
318
300
delta [:, m_start_ind :],
319
301
delta [:, L - 1 - spin ],
320
302
1j ** (- spin - m_value [m_start_ind :]),
321
303
)
322
- temp = jnp .einsum (
323
- "am,a->am" , temp , jnp .exp (1j * m_value * thetas [0 ])
324
- )
325
- temp = jnp .fft .irfft (
326
- temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward"
327
- )
304
+ temp = jnp .einsum ("am,a->am" , temp , jnp .exp (1j * m_value * thetas [0 ]))
305
+ temp = jnp .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
328
306
329
307
dl = dl .at [:, el ].set (temp [: len (thetas )])
330
308
331
309
else :
332
310
raise ValueError (f"Recursion method { recursion } not recognised." )
333
311
334
312
# Fold in normalisation to avoid recomputation at run-time.
335
- dl = jnp .einsum (
336
- "tlm,l->tlm" , dl , jnp .sqrt ((2 * jnp .arange (L ) + 1 ) / (4 * jnp .pi ))
337
- )
313
+ dl = jnp .einsum ("tlm,l->tlm" , dl , jnp .sqrt ((2 * jnp .arange (L ) + 1 ) / (4 * jnp .pi )))
338
314
339
315
# Fold in quadrature to avoid recomputation at run-time.
340
316
if forward :
@@ -433,9 +409,7 @@ def wigner_kernel(
433
409
if mode .lower () == "direct" :
434
410
delta = np .zeros ((len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
435
411
for el in range (L ):
436
- delta = recursions .risbo .compute_full_vectorised (
437
- delta , thetas , L , el
438
- )
412
+ delta = recursions .risbo .compute_full_vectorised (delta , thetas , L , el )
439
413
dl [:, :, el ] = np .moveaxis (delta , - 1 , 0 )[L - 1 + n ]
440
414
441
415
# MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated
@@ -464,9 +438,7 @@ def wigner_kernel(
464
438
1j ** (- m_value ),
465
439
1j ** (n ),
466
440
)
467
- temp = np .einsum (
468
- "amn,a->amn" , temp , np .exp (1j * m_value * thetas [0 ])
469
- )
441
+ temp = np .einsum ("amn,a->amn" , temp , np .exp (1j * m_value * thetas [0 ]))
470
442
temp = np .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
471
443
dl [:, :, el ] = np .moveaxis (temp [: len (thetas )], - 1 , 0 )
472
444
@@ -574,12 +546,8 @@ def wigner_kernel_jax(
574
546
# - The complexity of this approach is ALWAYS O(L^4).
575
547
# - This approach is stable for arbitrary abs(spins) <= L.
576
548
if mode .lower () == "direct" :
577
- delta = jnp .zeros (
578
- (len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64
579
- )
580
- vfunc = jax .vmap (
581
- recursions .risbo_jax .compute_full , in_axes = (0 , 0 , None , None )
582
- )
549
+ delta = jnp .zeros ((len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
550
+ vfunc = jax .vmap (recursions .risbo_jax .compute_full , in_axes = (0 , 0 , None , None ))
583
551
for el in range (L ):
584
552
delta = vfunc (delta , thetas , L , el )
585
553
dl = dl .at [:, :, el ].set (jnp .moveaxis (delta , - 1 , 0 )[L - 1 + n ])
@@ -610,12 +578,8 @@ def wigner_kernel_jax(
610
578
1j ** (- m_value ),
611
579
1j ** (n ),
612
580
)
613
- temp = jnp .einsum (
614
- "amn,a->amn" , temp , jnp .exp (1j * m_value * thetas [0 ])
615
- )
616
- temp = jnp .fft .irfft (
617
- temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward"
618
- )
581
+ temp = jnp .einsum ("amn,a->amn" , temp , jnp .exp (1j * m_value * thetas [0 ]))
582
+ temp = jnp .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
619
583
dl = dl .at [:, :, el ].set (jnp .moveaxis (temp [: len (thetas )], - 1 , 0 ))
620
584
621
585
else :
@@ -646,9 +610,7 @@ def wigner_kernel_jax(
646
610
return dl
647
611
648
612
649
- def healpix_phase_shifts (
650
- L : int , nside : int , forward : bool = False
651
- ) -> np .ndarray :
613
+ def healpix_phase_shifts (L : int , nside : int , forward : bool = False ) -> np .ndarray :
652
614
r"""
653
615
Generates a phase shift vector for HEALPix for all :math:`\theta` rings.
654
616
0 commit comments