@@ -7,15 +7,17 @@ use wide::*;
7
7
// use std::simd::Simd;
8
8
// use std::simd::cmp::SimdPartialEq;
9
9
10
+ use numpy:: ndarray:: { Array2 , ArrayView2 } ;
11
+ use numpy:: IntoPyArray ;
10
12
use numpy:: PyArray1 ;
11
13
use numpy:: PyArrayMethods ;
12
14
use numpy:: PyUntypedArrayMethods ;
13
15
use numpy:: ToPyArray ;
14
16
use numpy:: { PyArray2 , PyReadonlyArray2 } ;
15
- use numpy:: IntoPyArray ;
16
17
17
18
use rayon:: prelude:: * ;
18
19
use rayon:: ThreadPoolBuilder ;
20
+ use std:: sync:: Arc ;
19
21
20
22
#[ pyfunction]
21
23
fn first_true_1d_a ( array : PyReadonlyArray1 < bool > ) -> isize {
@@ -393,32 +395,17 @@ fn first_true_1d(py: Python, array: PyReadonlyArray1<bool>, forward: bool) -> is
393
395
// }
394
396
// }
395
397
396
-
397
- // use numpy::{PyReadonlyArray2, IntoPyArray, PyArray2};
398
- // use pyo3::prelude::*;
399
-
400
398
pub struct PreparedBool2D < ' py > {
401
- pub data : & ' py [ u8 ] , // flat contiguous buffer
402
- pub nrows : usize , // number of logical rows
399
+ pub data : & ' py [ u8 ] , // contiguous byte slice (bool as u8)
400
+ pub nrows : usize ,
403
401
pub ncols : usize ,
404
- _keepalive : Option < Bound < ' py , PyAny > > , // holds any copied/transposed buffer
402
+ _keepalive : Option < Arc < Array2 < bool > > > , // holds owned data if needed
405
403
}
406
404
407
405
pub fn prepare_array_for_axis < ' py > (
408
- py : Python < ' py > ,
409
406
array : PyReadonlyArray2 < ' py , bool > ,
410
407
axis : isize ,
411
408
) -> PyResult < PreparedBool2D < ' py > > {
412
-
413
- // let shape = array.shape();
414
- // let slice = array.as_slice().unwrap();
415
- // return Ok(PreparedBool2D {
416
- // data: unsafe { std::mem::transmute(slice) }, // &[bool] → &[u8]
417
- // nrows: shape[0],
418
- // ncols: shape[1],
419
- // _keepalive: None,
420
- // });
421
-
422
409
if axis != 0 && axis != 1 {
423
410
return Err ( PyValueError :: new_err ( "axis must be 0 or 1" ) ) ;
424
411
}
@@ -459,117 +446,122 @@ pub fn prepare_array_for_axis<'py>(
459
446
}
460
447
}
461
448
462
- // Case 3: fallback — create a new C-contiguous owned array
463
- let prepared_array : Bound < ' py , PyArray2 < bool > > = if axis == 0 {
464
- array_view. reversed_axes ( ) . as_standard_layout ( ) . to_owned ( ) . to_pyarray ( py )
449
+ // Case 3: fallback — make ndarray owned copy, but no PyArray!
450
+ let array_owned : Array2 < bool > = if axis == 0 {
451
+ array_view. reversed_axes ( ) . as_standard_layout ( ) . to_owned ( )
465
452
} else {
466
- array_view. as_standard_layout ( ) . to_owned ( ) . to_pyarray ( py )
453
+ array_view. as_standard_layout ( ) . to_owned ( )
467
454
} ;
468
455
469
- let array_view = unsafe { prepared_array. as_array ( ) } ;
470
- let prepared_slice = array_view
471
- . as_slice_memory_order ( )
472
- . expect ( "Newly allocated array must be contiguous" ) ;
456
+ let slice = array_owned
457
+ . as_slice_memory_order ( )
458
+ . expect ( "newly allocated Array2 must be contiguous" ) ;
473
459
474
460
Ok ( PreparedBool2D {
475
- data : unsafe { std:: mem:: transmute ( prepared_slice ) } ,
461
+ data : unsafe { std:: mem:: transmute ( slice ) } ,
476
462
nrows,
477
463
ncols,
478
- _keepalive : Some ( prepared_array . into_any ( ) ) ,
464
+ _keepalive : Some ( Arc :: new ( array_owned ) ) ,
479
465
} )
480
466
}
481
467
482
-
483
-
484
468
#[ pyfunction]
485
469
#[ pyo3( signature = ( array, * , forward=true , axis) ) ]
486
- pub fn first_true_2d_a < ' py > (
470
+ pub fn first_true_2d < ' py > (
487
471
py : Python < ' py > ,
488
472
array : PyReadonlyArray2 < ' py , bool > ,
489
473
forward : bool ,
490
474
axis : isize ,
491
475
) -> PyResult < Bound < ' py , PyArray1 < isize > > > {
492
-
493
- let prepared = prepare_array_for_axis ( py, array, axis) ?;
476
+ let prepared = prepare_array_for_axis ( array, axis) ?;
494
477
let data = prepared. data ;
495
478
let rows = prepared. nrows ;
496
479
let row_len = prepared. ncols ;
497
480
498
- let mut result = vec ! [ -1isize ; rows] ;
481
+ let pyarray = unsafe { PyArray1 :: < isize > :: new ( py, [ rows] , false ) } ;
482
+ let result = unsafe { pyarray. as_slice_mut ( ) . unwrap ( ) } ;
483
+ result. fill ( -1 ) ;
499
484
500
- py. allow_threads ( || {
501
- const LANES : usize = 32 ;
502
- let ones = u8x32:: splat ( 1 ) ;
503
- let base_ptr = data. as_ptr ( ) ;
485
+ // let mut result = vec![-1isize; rows];
486
+
487
+ // py.allow_threads(|| {
488
+ const LANES : usize = 32 ;
489
+ let ones = u8x32:: splat ( 1 ) ;
490
+ let base_ptr = data. as_ptr ( ) ;
491
+ let mut i;
504
492
493
+ if forward {
505
494
for row in 0 ..rows {
506
495
let ptr = unsafe { base_ptr. add ( row * row_len) } ;
507
- if forward {
508
- // Forward search
509
- let mut i = 0 ;
510
- unsafe {
511
- while i + LANES <= row_len {
512
- let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
513
- let vec = u8x32:: from ( * chunk) ;
514
- if vec. cmp_eq ( ones) . any ( ) {
515
- break ;
516
- }
517
- i += LANES ;
496
+ i = 0 ;
497
+ unsafe {
498
+ while i + LANES <= row_len {
499
+ let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
500
+ let vec = u8x32:: from ( * chunk) ;
501
+ if vec. cmp_eq ( ones) . any ( ) {
502
+ break ;
518
503
}
519
- while i < row_len {
520
- if * ptr . add ( i ) != 0 {
521
- result [ row ] = i as isize ;
522
- break ;
523
- }
524
- i += 1 ;
504
+ i += LANES ;
505
+ }
506
+ while i < row_len {
507
+ if * ptr . add ( i ) != 0 {
508
+ result [ row ] = i as isize ;
509
+ break ;
525
510
}
511
+ i += 1 ;
526
512
}
527
- } else {
528
- // Backward search
529
- let mut i = row_len;
530
- unsafe {
531
- // Process LANES bytes at a time with SIMD (backwards)
532
- while i >= LANES {
533
- i -= LANES ;
513
+ }
514
+ }
515
+ } else {
516
+ // Backward search
517
+ for row in 0 ..rows {
518
+ let ptr = unsafe { base_ptr. add ( row * row_len) } ;
534
519
535
- let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
536
- let vec = u8x32:: from ( * chunk) ;
537
- if vec. cmp_eq ( ones) . any ( ) {
538
- // Found a true in this chunk, search backwards within it
539
- for j in ( i..i + LANES ) . rev ( ) {
540
- if * ptr. add ( j) != 0 {
541
- result[ row] = j as isize ;
542
- break ;
543
- }
544
- }
545
- break ;
546
- }
547
- }
548
- // Handle remaining bytes at the beginning
549
- if i > 0 && i < LANES {
550
- for j in ( 0 ..i) . rev ( ) {
520
+ i = row_len;
521
+ unsafe {
522
+ // Process LANES bytes at a time with SIMD (backwards)
523
+ while i >= LANES {
524
+ i -= LANES ;
525
+
526
+ let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
527
+ let vec = u8x32:: from ( * chunk) ;
528
+ if vec. cmp_eq ( ones) . any ( ) {
529
+ // Found a true in this chunk, search backwards within it
530
+ for j in ( i..i + LANES ) . rev ( ) {
551
531
if * ptr. add ( j) != 0 {
552
532
result[ row] = j as isize ;
553
533
break ;
554
534
}
555
535
}
536
+ break ;
537
+ }
538
+ }
539
+ // Handle remaining bytes at the beginning
540
+ if i > 0 && i < LANES {
541
+ for j in ( 0 ..i) . rev ( ) {
542
+ if * ptr. add ( j) != 0 {
543
+ result[ row] = j as isize ;
544
+ break ;
545
+ }
556
546
}
557
547
}
558
548
}
559
549
}
560
- } ) ;
561
- Ok ( PyArray1 :: from_vec ( py, result) . to_owned ( ) )
550
+ }
551
+ // });
552
+ // Ok(PyArray1::from_vec(py, result).to_owned())
553
+ Ok ( pyarray)
562
554
}
563
555
564
556
#[ pyfunction]
565
557
#[ pyo3( signature = ( array, * , forward=true , axis) ) ]
566
- pub fn first_true_2d < ' py > (
558
+ pub fn first_true_2d_b < ' py > (
567
559
py : Python < ' py > ,
568
560
array : PyReadonlyArray2 < ' py , bool > ,
569
561
forward : bool ,
570
562
axis : isize ,
571
563
) -> PyResult < Bound < ' py , PyArray1 < isize > > > {
572
- let prepared = prepare_array_for_axis ( py , array, axis) ?;
564
+ let prepared = prepare_array_for_axis ( array, axis) ?;
573
565
let data = prepared. data ;
574
566
let rows = prepared. nrows ;
575
567
let row_len = prepared. ncols ;
@@ -580,9 +572,9 @@ pub fn first_true_2d<'py>(
580
572
let max_threads = if rows < 100 {
581
573
1
582
574
} else if rows < 1000 {
583
- 2
575
+ 1
584
576
} else if rows < 10000 {
585
- 4
577
+ 1
586
578
} else {
587
579
16
588
580
} ;
@@ -667,8 +659,6 @@ pub fn first_true_2d<'py>(
667
659
Ok ( PyArray1 :: from_vec ( py, result) )
668
660
}
669
661
670
-
671
-
672
662
//------------------------------------------------------------------------------
673
663
674
664
#[ pymodule]
0 commit comments