1
+ #![ feature( portable_simd) ]
2
+
1
3
use numpy:: PyReadonlyArray1 ;
2
4
use pyo3:: prelude:: * ;
3
- use wide:: * ;
5
+ use pyo3:: exceptions:: PyValueError ;
6
+ use pyo3:: Bound ;
4
7
// use pyo3::types::{PyBool, PyAny};
8
+ use wide:: * ;
9
+ // use std::simd::Simd;
10
+ // use std::simd::cmp::SimdPartialEq;
11
+
12
+ use numpy:: { IntoPyArray , PyArray2 , PyReadonlyArray2 } ;
13
+ use numpy:: ToPyArray ;
14
+ use numpy:: PyArrayMethods ;
15
+ use numpy:: PyUntypedArrayMethods ;
16
+
5
17
6
18
#[ pyfunction]
7
19
fn first_true_1d_a ( array : PyReadonlyArray1 < bool > ) -> isize {
@@ -190,51 +202,6 @@ fn first_true_1d_e(array: PyReadonlyArray1<bool>) -> isize {
190
202
}
191
203
}
192
204
193
- #[ pyfunction]
194
- fn first_true_1d_f ( py : Python , array : PyReadonlyArray1 < bool > ) -> isize {
195
- if let Ok ( slice) = array. as_slice ( ) {
196
- py. allow_threads ( || {
197
- let len = slice. len ( ) ;
198
- let ptr = slice. as_ptr ( ) as * const u8 ;
199
- let mut i = 0 ;
200
-
201
- let ones = u8x32:: splat ( 1 ) ;
202
- unsafe {
203
- // Process 32 bytes at a time with SIMD
204
- while i + 32 <= len {
205
- // Cast pointer to array reference
206
- let bytes = & * ( ptr. add ( i) as * const [ u8 ; 32 ] ) ;
207
-
208
- // Convert to SIMD vector
209
- let chunk = u8x32:: from ( * bytes) ;
210
- let equal_one = chunk. cmp_eq ( ones) ;
211
- if equal_one. any ( ) {
212
- break ;
213
- }
214
-
215
- i += 32 ;
216
- }
217
- // // Handle final remainder
218
- while i < len. min ( i + 32 ) {
219
- if * ptr. add ( i) != 0 {
220
- return i as isize ;
221
- }
222
- i += 1 ;
223
- }
224
- -1
225
- }
226
- } )
227
- } else {
228
- let array_view = array. as_array ( ) ;
229
- py. allow_threads ( || {
230
- array_view
231
- . iter ( )
232
- . position ( |& v| v)
233
- . map ( |i| i as isize )
234
- . unwrap_or ( -1 )
235
- } )
236
- }
237
- }
238
205
239
206
#[ pyfunction]
240
207
#[ pyo3( signature = ( array, forward=true ) ) ]
@@ -243,6 +210,8 @@ fn first_true_1d(py: Python,
243
210
forward : bool ,
244
211
) -> isize {
245
212
if let Ok ( slice) = array. as_slice ( ) {
213
+ const LANES : usize = 32 ;
214
+
246
215
py. allow_threads ( || {
247
216
let len = slice. len ( ) ;
248
217
let ptr = slice. as_ptr ( ) as * const u8 ;
@@ -252,17 +221,17 @@ fn first_true_1d(py: Python,
252
221
let mut i = 0 ;
253
222
unsafe {
254
223
// Process 32 bytes at a time with SIMD
255
- while i + 32 <= len {
256
- let bytes = & * ( ptr. add ( i) as * const [ u8 ; 32 ] ) ;
224
+ while i + LANES <= len {
225
+ let bytes = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
257
226
let chunk = u8x32:: from ( * bytes) ;
258
227
let equal_one = chunk. cmp_eq ( ones) ;
259
228
if equal_one. any ( ) {
260
229
break ;
261
230
}
262
- i += 32 ;
231
+ i += LANES ;
263
232
}
264
233
// Handle final remainder
265
- while i < len. min ( i + 32 ) {
234
+ while i < len. min ( i + LANES ) {
266
235
if * ptr. add ( i) != 0 {
267
236
return i as isize ;
268
237
}
@@ -273,15 +242,15 @@ fn first_true_1d(py: Python,
273
242
// Backward search
274
243
let mut i = len;
275
244
unsafe {
276
- // Process 32 bytes at a time with SIMD (backwards)
277
- while i >= 32 {
278
- i -= 32 ;
279
- let bytes = & * ( ptr. add ( i) as * const [ u8 ; 32 ] ) ;
245
+ // Process LANES bytes at a time with SIMD (backwards)
246
+ while i >= LANES {
247
+ i -= LANES ;
248
+ let bytes = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
280
249
let chunk = u8x32:: from ( * bytes) ;
281
250
let equal_one = chunk. cmp_eq ( ones) ;
282
251
if equal_one. any ( ) {
283
252
// Found a true in this chunk, search backwards within it
284
- for j in ( i..i + 32 ) . rev ( ) {
253
+ for j in ( i..i + LANES ) . rev ( ) {
285
254
if * ptr. add ( j) != 0 {
286
255
return j as isize ;
287
256
}
@@ -320,14 +289,111 @@ fn first_true_1d(py: Python,
320
289
}
321
290
}
322
291
292
+
293
+
294
+ // #[pyfunction]
295
+ // fn first_true_1d_g(py: Python, array: PyReadonlyArray1<bool>) -> isize {
296
+ // if let Ok(slice) = array.as_slice() {
297
+ // py.allow_threads(|| {
298
+ // let len = slice.len();
299
+ // let ptr = slice.as_ptr() as *const u8;
300
+ // let mut i = 0;
301
+
302
+ // type Lane = u8;
303
+ // const LANES: usize = 64;
304
+ // let ones = Simd::<Lane, LANES>::splat(1);
305
+
306
+ // unsafe {
307
+ // while i + LANES <= len {
308
+ // let chunk_ptr = ptr.add(i) as *const [u8; LANES];
309
+ // let chunk = Simd::from(*chunk_ptr);
310
+ // let mask = chunk.simd_eq(ones).to_bitmask();
311
+
312
+ // if mask != 0 {
313
+ // let offset = mask.trailing_zeros() as usize;
314
+ // return (i + offset) as isize;
315
+ // }
316
+
317
+ // i += LANES;
318
+ // }
319
+
320
+ // // Remainder (non-SIMD tail)
321
+ // while i < len {
322
+ // if *ptr.add(i) != 0 {
323
+ // return i as isize;
324
+ // }
325
+ // i += 1;
326
+ // }
327
+ // }
328
+
329
+ // -1
330
+ // })
331
+ // } else {
332
+ // // Fallback for non-contiguous arrays
333
+ // let view = array.as_array();
334
+ // py.allow_threads(|| {
335
+ // view.iter()
336
+ // .position(|&v| v)
337
+ // .map(|i| i as isize)
338
+ // .unwrap_or(-1)
339
+ // })
340
+ // }
341
+ // }
342
+
343
+
344
+ //------------------------------------------------------------------------------
345
+
346
+
347
+ fn prepare_array_for_axis < ' py > (
348
+ py : Python < ' py > ,
349
+ array : PyReadonlyArray2 < ' py , bool > ,
350
+ axis : usize ,
351
+ ) -> PyResult < Bound < ' py , PyArray2 < bool > > > {
352
+ if axis != 0 && axis != 1 {
353
+ return Err ( PyValueError :: new_err ( "axis must be 0 or 1" ) ) ;
354
+ }
355
+
356
+ let is_c = array. is_c_contiguous ( ) ;
357
+ let is_f = array. is_fortran_contiguous ( ) ;
358
+
359
+ match ( is_c, is_f, axis) {
360
+ ( true , _, 0 ) => {
361
+ let transposed = array. as_array ( ) . reversed_axes ( ) . to_owned ( ) ;
362
+ Ok ( transposed. into_pyarray ( py) )
363
+ }
364
+ ( true , _, 1 ) => Ok ( array. as_array ( ) . to_owned ( ) . into_pyarray ( py) ) , // copy to get full ownership
365
+ ( _, true , 0 ) => {
366
+ let transposed = array. as_array ( ) . reversed_axes ( ) ;
367
+ Ok ( transposed. to_owned ( ) . into_pyarray ( py) )
368
+ }
369
+ ( _, true , 1 ) => {
370
+ let owned = array. as_array ( ) . to_owned ( ) ;
371
+ Ok ( owned. into_pyarray ( py) )
372
+ }
373
+ ( false , false , 0 ) => {
374
+ let transposed = array. as_array ( ) . reversed_axes ( ) . to_owned ( ) ;
375
+ Ok ( transposed. into_pyarray ( py) )
376
+ }
377
+ ( false , false , 1 ) => {
378
+ let owned = array. as_array ( ) . to_owned ( ) ;
379
+ Ok ( owned. into_pyarray ( py) )
380
+ }
381
+ _ => unreachable ! ( ) ,
382
+ }
383
+ }
384
+
385
+
386
+ //------------------------------------------------------------------------------
387
+
388
+
323
389
#[ pymodule]
324
390
fn arrayredox ( m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
325
391
m. add_function ( wrap_pyfunction ! ( first_true_1d_a, m) ?) ?;
326
392
m. add_function ( wrap_pyfunction ! ( first_true_1d_b, m) ?) ?;
327
393
m. add_function ( wrap_pyfunction ! ( first_true_1d_c, m) ?) ?;
328
394
m. add_function ( wrap_pyfunction ! ( first_true_1d_d, m) ?) ?;
329
395
m. add_function ( wrap_pyfunction ! ( first_true_1d_e, m) ?) ?;
330
- m. add_function ( wrap_pyfunction ! ( first_true_1d_f , m) ?) ?;
396
+ // m.add_function(wrap_pyfunction!(first_true_1d_g , m)?)?;
331
397
m. add_function ( wrap_pyfunction ! ( first_true_1d, m) ?) ?;
332
398
Ok ( ( ) )
333
399
}
0 commit comments