@@ -84,11 +84,7 @@ def spin_spherical_kernel(
84
84
delta [:, L - 1 - spin ],
85
85
1j ** (- spin - m_value [m_start_ind :]),
86
86
)
87
- if sampling .lower () in ["mw" , "dh" ]:
88
- temp = np .einsum ("am,a->am" , temp , np .exp (1j * m_value * thetas [0 ]))
89
- else :
90
- temp_new = np .zeros ((L + 1 , m_dim ), dtype = temp .dtype )
91
- temp_new [:- 1 ] = temp [L - 1 :]
87
+ temp = np .einsum ("am,a->am" , temp , np .exp (1j * m_value * thetas [0 ]))
92
88
temp = np .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
93
89
94
90
dl [:, el ] = temp [: len (thetas )]
@@ -173,14 +169,7 @@ def spin_spherical_kernel_jax(
173
169
delta [:, L - 1 - spin ],
174
170
1j ** (- spin - m_value [m_start_ind :]),
175
171
)
176
- if sampling .lower () in ["mw" , "dh" ]:
177
- temp = jnp .einsum (
178
- "am,a->am" , temp , jnp .exp (1j * m_value * thetas [0 ])
179
- )
180
- else :
181
- temp_new = jnp .zeros ((L + 1 , m_dim ), dtype = temp .dtype )
182
- temp_new = temp_new .at [:- 1 ].set (temp [L - 1 :])
183
-
172
+ temp = jnp .einsum ("am,a->am" , temp , jnp .exp (1j * m_value * thetas [0 ]))
184
173
temp = jnp .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
185
174
186
175
dl = dl .at [:, el ].set (temp [: len (thetas )])
@@ -253,32 +242,35 @@ def wigner_kernel(
253
242
raise ValueError ("Sampling in supported list [mw, mwss, dh]" )
254
243
255
244
# Compute Wigner d-functions from their Fourier decomposition.
256
- delta = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
245
+ if N <= int (L / np .log (L )):
246
+ delta = np .zeros ((len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
247
+ else :
248
+ delta = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
257
249
dl = np .zeros ((n_dim , len (thetas ), L , 2 * L - 1 ), dtype = np .float64 )
258
250
259
251
# Range values which need only be defined once.
260
252
m_value = np .arange (- L + 1 , L )
261
253
n = np .arange (n_start_ind - N + 1 , N )
262
254
255
+ # If N <= L/LogL more efficient to manually compute over FFT
263
256
for el in range (L ):
264
- delta = recursions .risbo .compute_full (delta , np .pi / 2 , L , el )
265
- temp = np .einsum (
266
- "am,an,m,n->amn" ,
267
- delta ,
268
- delta [:, L - 1 + n ],
269
- 1j ** (- m_value ),
270
- 1j ** (n ),
271
- )
272
- if sampling .lower () in ["mw" , "dh" ]:
257
+ if N <= int (L / np .log (L )):
258
+ delta = recursions .risbo .compute_full_vect (delta , thetas , L , el )
259
+ dl [:, :, el ] = np .moveaxis (delta , - 1 , 0 )[L - 1 + n ]
260
+ else :
261
+ delta = recursions .risbo .compute_full (delta , np .pi / 2 , L , el )
262
+ temp = np .einsum (
263
+ "am,an,m,n->amn" ,
264
+ delta ,
265
+ delta [:, L - 1 + n ],
266
+ 1j ** (- m_value ),
267
+ 1j ** (n ),
268
+ )
273
269
temp = np .einsum (
274
270
"amn,a->amn" , temp , np .exp (1j * m_value * thetas [0 ])
275
271
)
276
-
277
- else :
278
- temp_new = np .zeros ((L + 1 , 2 * L - 1 , len (n )), dtype = temp .dtype )
279
- temp_new [:- 1 ] = temp [L - 1 :]
280
- temp = np .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
281
- dl [:, :, el ] = np .moveaxis (temp [: len (thetas )], - 1 , 0 )
272
+ temp = np .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
273
+ dl [:, :, el ] = np .moveaxis (temp [: len (thetas )], - 1 , 0 )
282
274
283
275
if forward :
284
276
weights = quadrature .quad_weights_transform (L , sampling )
@@ -351,32 +343,42 @@ def wigner_kernel_jax(
351
343
raise ValueError ("Sampling in supported list [mw, mwss, dh]" )
352
344
353
345
# Compute Wigner d-functions from their Fourier decomposition.
354
- delta = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
346
+ if N <= int (L / np .log (L )):
347
+ delta = jnp .zeros (
348
+ (len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64
349
+ )
350
+ vfunc = jax .vmap (
351
+ recursions .risbo_jax .compute_full , in_axes = (0 , 0 , None , None )
352
+ )
353
+ else :
354
+ delta = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
355
355
dl = jnp .zeros ((n_dim , len (thetas ), L , 2 * L - 1 ), dtype = jnp .float64 )
356
356
357
357
# Range values which need only be defined once.
358
358
m_value = jnp .arange (- L + 1 , L )
359
359
n = jnp .arange (n_start_ind - N + 1 , N )
360
360
361
+ # If N <= L/LogL more efficient to manually compute over FFT
361
362
for el in range (L ):
362
- delta = recursions .risbo_jax .compute_full (delta , jnp .pi / 2 , L , el )
363
- temp = jnp .einsum (
364
- "am,an,m,n->amn" ,
365
- delta ,
366
- delta [:, L - 1 + n ],
367
- 1j ** (- m_value ),
368
- 1j ** (n ),
369
- )
370
- if sampling .lower () in ["mw" , "dh" ]:
363
+ if N <= int (L / np .log (L )):
364
+ delta = vfunc (delta , thetas , L , el )
365
+ dl = dl .at [:, :, el ].set (jnp .moveaxis (delta , - 1 , 0 )[L - 1 + n ])
366
+ else :
367
+ delta = recursions .risbo_jax .compute_full (delta , jnp .pi / 2 , L , el )
368
+ temp = jnp .einsum (
369
+ "am,an,m,n->amn" ,
370
+ delta ,
371
+ delta [:, L - 1 + n ],
372
+ 1j ** (- m_value ),
373
+ 1j ** (n ),
374
+ )
371
375
temp = jnp .einsum (
372
376
"amn,a->amn" , temp , jnp .exp (1j * m_value * thetas [0 ])
373
377
)
374
-
375
- else :
376
- temp_new = jnp .zeros ((L + 1 , 2 * L - 1 , len (n )), dtype = temp .dtype )
377
- temp_new = temp_new .at [:- 1 ].set (temp [L - 1 :])
378
- temp = jnp .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
379
- dl = dl .at [:, :, el ].set (jnp .moveaxis (temp [: len (thetas )], - 1 , 0 ))
378
+ temp = jnp .fft .irfft (
379
+ temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward"
380
+ )
381
+ dl = dl .at [:, :, el ].set (jnp .moveaxis (temp [: len (thetas )], - 1 , 0 ))
380
382
381
383
if forward :
382
384
weights = quadrature_jax .quad_weights_transform (L , sampling )
0 commit comments