@@ -14,6 +14,9 @@ use numpy::ToPyArray;
14
14
use numpy:: { PyArray2 , PyReadonlyArray2 } ;
15
15
use numpy:: IntoPyArray ;
16
16
17
+ use rayon:: prelude:: * ;
18
+ use rayon:: ThreadPoolBuilder ;
19
+
17
20
#[ pyfunction]
18
21
fn first_true_1d_a ( array : PyReadonlyArray1 < bool > ) -> isize {
19
22
match array. as_slice ( ) {
@@ -480,7 +483,7 @@ pub fn prepare_array_for_axis<'py>(
480
483
481
484
#[ pyfunction]
482
485
#[ pyo3( signature = ( array, * , forward=true , axis) ) ]
483
- pub fn first_true_2d < ' py > (
486
+ pub fn first_true_2d_a < ' py > (
484
487
py : Python < ' py > ,
485
488
array : PyReadonlyArray2 < ' py , bool > ,
486
489
forward : bool ,
@@ -558,6 +561,114 @@ pub fn first_true_2d<'py>(
558
561
Ok ( PyArray1 :: from_vec ( py, result) . to_owned ( ) )
559
562
}
560
563
564
+ #[ pyfunction]
565
+ #[ pyo3( signature = ( array, * , forward=true , axis) ) ]
566
+ pub fn first_true_2d < ' py > (
567
+ py : Python < ' py > ,
568
+ array : PyReadonlyArray2 < ' py , bool > ,
569
+ forward : bool ,
570
+ axis : isize ,
571
+ ) -> PyResult < Bound < ' py , PyArray1 < isize > > > {
572
+ let prepared = prepare_array_for_axis ( py, array, axis) ?;
573
+ let data = prepared. data ;
574
+ let rows = prepared. nrows ;
575
+ let row_len = prepared. ncols ;
576
+
577
+ let mut result = vec ! [ -1isize ; rows] ;
578
+
579
+ // Dynamically select thread count
580
+ let max_threads = if rows < 100 {
581
+ 1
582
+ } else if rows < 1000 {
583
+ 2
584
+ } else if rows < 10000 {
585
+ 4
586
+ } else {
587
+ 16
588
+ } ;
589
+
590
+ py. allow_threads ( || {
591
+ let base_ptr = data. as_ptr ( ) as usize ;
592
+ const LANES : usize = 32 ;
593
+ let ones = u8x32:: splat ( 1 ) ;
594
+
595
+ let process_row = |row : usize | -> isize {
596
+ let ptr = ( base_ptr + row * row_len) as * const u8 ;
597
+ let mut found = -1isize ;
598
+
599
+ unsafe {
600
+ if forward {
601
+ let mut i = 0 ;
602
+ while i + LANES <= row_len {
603
+ let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
604
+ let vec = u8x32:: from ( * chunk) ;
605
+ if vec. cmp_eq ( ones) . any ( ) {
606
+ break ;
607
+ }
608
+ i += LANES ;
609
+ }
610
+ while i < row_len {
611
+ if * ptr. add ( i) != 0 {
612
+ found = i as isize ;
613
+ break ;
614
+ }
615
+ i += 1 ;
616
+ }
617
+ } else {
618
+ let mut i = row_len;
619
+ while i >= LANES {
620
+ i -= LANES ;
621
+ let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
622
+ let vec = u8x32:: from ( * chunk) ;
623
+ if vec. cmp_eq ( ones) . any ( ) {
624
+ for j in ( i..i + LANES ) . rev ( ) {
625
+ if * ptr. add ( j) != 0 {
626
+ found = j as isize ;
627
+ break ;
628
+ }
629
+ }
630
+ break ;
631
+ }
632
+ }
633
+ if i > 0 && i < LANES {
634
+ for j in ( 0 ..i) . rev ( ) {
635
+ if * ptr. add ( j) != 0 {
636
+ found = j as isize ;
637
+ break ;
638
+ }
639
+ }
640
+ }
641
+ }
642
+ }
643
+
644
+ found
645
+ } ;
646
+
647
+ if max_threads == 1 {
648
+ // Single-threaded path
649
+ for row in 0 ..rows {
650
+ result[ row] = process_row ( row) ;
651
+ }
652
+ } else {
653
+ // Multi-threaded path with Rayon
654
+ let pool = rayon:: ThreadPoolBuilder :: new ( )
655
+ . num_threads ( max_threads)
656
+ . build ( )
657
+ . unwrap ( ) ;
658
+
659
+ pool. install ( || {
660
+ result. par_iter_mut ( ) . enumerate ( ) . for_each ( |( row, out) | {
661
+ * out = process_row ( row) ;
662
+ } ) ;
663
+ } ) ;
664
+ }
665
+ } ) ;
666
+
667
+ Ok ( PyArray1 :: from_vec ( py, result) )
668
+ }
669
+
670
+
671
+
561
672
//------------------------------------------------------------------------------
562
673
563
674
#[ pymodule]
0 commit comments