@@ -12,6 +12,7 @@ use numpy::PyArrayMethods;
12
12
use numpy:: PyUntypedArrayMethods ;
13
13
use numpy:: ToPyArray ;
14
14
use numpy:: { PyArray2 , PyReadonlyArray2 } ;
15
+ use numpy:: IntoPyArray ;
15
16
16
17
#[ pyfunction]
17
18
fn first_true_1d_a ( array : PyReadonlyArray1 < bool > ) -> isize {
@@ -346,49 +347,127 @@ fn first_true_1d(py: Python, array: PyReadonlyArray1<bool>, forward: bool) -> is
346
347
// axis == 0: transpose, copy to C
347
348
// axis == 1: copy to C
348
349
349
- fn prepare_array_for_axis < ' py > (
350
+ // fn prepare_array_for_axis<'py>(
351
+ // py: Python<'py>,
352
+ // array: PyReadonlyArray2<'py, bool>,
353
+ // axis: isize,
354
+ // ) -> PyResult<Bound<'py, PyArray2<bool>>> {
355
+ // if axis != 0 && axis != 1 {
356
+ // return Err(PyValueError::new_err("axis must be 0 or 1"));
357
+ // }
358
+
359
+ // let is_c = array.is_c_contiguous();
360
+ // let is_f = array.is_fortran_contiguous();
361
+ // let array_view = array.as_array();
362
+
363
+ // match (is_c, is_f, axis) {
364
+ // (true, _, 1) => {
365
+ // // Already C-contiguous, no copy needed
366
+ // Ok(array_view.to_pyarray(py).to_owned())
367
+ // }
368
+ // (_, true, 0) => {
369
+ // // F-contiguous original -> transposed will be C-contiguous, no copy needed
370
+ // Ok(array_view.reversed_axes().to_pyarray(py).to_owned())
371
+ // }
372
+ // (_, true, 1) => {
373
+ // // F-contiguous, need to copy to C-contiguous
374
+ // let contiguous = array_view.as_standard_layout();
375
+ // Ok(contiguous.to_pyarray(py).to_owned())
376
+ // }
377
+ // (_, _, 1) => {
378
+ // // Neither C nor F contiguous, need to copy
379
+ // let contiguous = array_view.as_standard_layout();
380
+ // Ok(contiguous.to_pyarray(py).to_owned())
381
+ // }
382
+
383
+ // (true, _, 0) | (_, _, 0) => {
384
+ // // C-contiguous or neither -> transposed won't be C-contiguous, need copy
385
+ // let transposed = array_view.reversed_axes();
386
+ // let contiguous = transposed.as_standard_layout();
387
+ // Ok(contiguous.to_pyarray(py).to_owned())
388
+ // }
389
+ // _ => unreachable!(),
390
+ // }
391
+ // }
392
+
393
+
394
+ // use numpy::{PyReadonlyArray2, IntoPyArray, PyArray2};
395
+ // use pyo3::prelude::*;
396
+
397
+ pub struct PreparedBool2D < ' py > {
398
+ pub data : & ' py [ u8 ] , // flat contiguous buffer
399
+ pub nrows : usize , // number of logical rows
400
+ pub ncols : usize ,
401
+ _keepalive : Option < Bound < ' py , PyAny > > , // holds any copied/transposed buffer
402
+ }
403
+
404
+ pub fn prepare_array_for_axis < ' py > (
350
405
py : Python < ' py > ,
351
406
array : PyReadonlyArray2 < ' py , bool > ,
352
407
axis : isize ,
353
- ) -> PyResult < Bound < ' py , PyArray2 < bool > > > {
408
+ ) -> PyResult < PreparedBool2D < ' py > > {
354
409
if axis != 0 && axis != 1 {
355
410
return Err ( PyValueError :: new_err ( "axis must be 0 or 1" ) ) ;
356
411
}
357
412
413
+ let shape = array. shape ( ) ;
414
+ let ( nrows, ncols) = if axis == 0 {
415
+ ( shape[ 1 ] , shape[ 0 ] ) // transposed
416
+ } else {
417
+ ( shape[ 0 ] , shape[ 1 ] ) // as-is
418
+ } ;
419
+
358
420
let is_c = array. is_c_contiguous ( ) ;
359
421
let is_f = array. is_fortran_contiguous ( ) ;
360
422
let array_view = array. as_array ( ) ;
361
423
362
- match ( is_c, is_f, axis) {
363
- ( true , _, 1 ) => {
364
- // Already C-contiguous, no copy needed
365
- Ok ( array_view. to_pyarray ( py) . to_owned ( ) )
366
- }
367
- ( _, true , 0 ) => {
368
- // F-contiguous original -> transposed will be C-contiguous, no copy needed
369
- Ok ( array_view. reversed_axes ( ) . to_pyarray ( py) . to_owned ( ) )
370
- }
371
- ( _, true , 1 ) => {
372
- // F-contiguous, need to copy to C-contiguous
373
- let contiguous = array_view. as_standard_layout ( ) ;
374
- Ok ( contiguous. to_pyarray ( py) . to_owned ( ) )
375
- }
376
- ( _, _, 1 ) => {
377
- // Neither C nor F contiguous, need to copy
378
- let contiguous = array_view. as_standard_layout ( ) ;
379
- Ok ( contiguous. to_pyarray ( py) . to_owned ( ) )
424
+ // Case 1: C-contiguous + axis=1 → zero-copy slice
425
+ if is_c && axis == 1 {
426
+ if let Ok ( slice) = array. as_slice ( ) {
427
+ return Ok ( PreparedBool2D {
428
+ data : unsafe { std:: mem:: transmute ( slice) } , // &[bool] → &[u8]
429
+ nrows,
430
+ ncols,
431
+ _keepalive : None ,
432
+ } ) ;
380
433
}
434
+ }
381
435
382
- ( true , _, 0 ) | ( _, _, 0 ) => {
383
- // C-contiguous or neither -> transposed won't be C-contiguous, need copy
384
- let transposed = array_view. reversed_axes ( ) ;
385
- let contiguous = transposed. as_standard_layout ( ) ;
386
- Ok ( contiguous. to_pyarray ( py) . to_owned ( ) )
436
+ // Case 2: F-contiguous + axis=0 → transpose, check if sliceable
437
+ if is_f && axis == 0 {
438
+ let transposed = array_view. reversed_axes ( ) ;
439
+ if let Some ( slice) = transposed. as_standard_layout ( ) . as_slice_memory_order ( ) {
440
+ return Ok ( PreparedBool2D {
441
+ data : unsafe { std:: mem:: transmute ( slice) } ,
442
+ nrows,
443
+ ncols,
444
+ _keepalive : None ,
445
+ } ) ;
387
446
}
388
- _ => unreachable ! ( ) ,
389
447
}
448
+
449
+ // Case 3: fallback — create a new C-contiguous owned array
450
+ let prepared_array: Bound < ' py , PyArray2 < bool > > = if axis == 0 {
451
+ array_view. reversed_axes ( ) . as_standard_layout ( ) . to_owned ( ) . to_pyarray ( py)
452
+ } else {
453
+ array_view. as_standard_layout ( ) . to_owned ( ) . to_pyarray ( py)
454
+ } ;
455
+
456
+ let array_view = unsafe { prepared_array. as_array ( ) } ;
457
+ let prepared_slice = array_view
458
+ . as_slice_memory_order ( )
459
+ . expect ( "Newly allocated array must be contiguous" ) ;
460
+
461
+ Ok ( PreparedBool2D {
462
+ data : unsafe { std:: mem:: transmute ( prepared_slice) } ,
463
+ nrows,
464
+ ncols,
465
+ _keepalive : Some ( prepared_array. into_any ( ) ) ,
466
+ } )
390
467
}
391
468
469
+
470
+
392
471
#[ pyfunction]
393
472
#[ pyo3( signature = ( array, * , forward=true , axis) ) ]
394
473
pub fn first_true_2d < ' py > (
@@ -397,47 +476,58 @@ pub fn first_true_2d<'py>(
397
476
forward : bool ,
398
477
axis : isize ,
399
478
) -> PyResult < Bound < ' py , PyArray1 < isize > > > {
400
- let prepped = prepare_array_for_axis ( py, array, axis) ?;
401
- let view = unsafe { prepped. as_array ( ) } ;
402
479
403
- // let view = array.as_array();
404
- // NOTE: these are rows in the view, not always the same as rows
405
- let rows = view. nrows ( ) ;
406
- let mut result = Vec :: with_capacity ( rows) ;
480
+ // let prepped = prepare_array_for_axis(py, array, axis)?;
481
+ // let view = unsafe { prepped.as_array() };
482
+ // // NOTE: these are rows in the view, not always the same as rows
483
+ // let rows = view.nrows();
484
+
485
+ let prepared = prepare_array_for_axis ( py, array, axis) ?;
486
+ let data = prepared. data ;
487
+ let rows = prepared. nrows ;
488
+ let row_len = prepared. ncols ;
489
+
490
+ let mut result = vec ! [ -1isize ; rows] ;
407
491
408
492
py. allow_threads ( || {
409
493
const LANES : usize = 32 ;
410
494
let ones = u8x32:: splat ( 1 ) ;
411
495
496
+ let base_ptr = data. as_ptr ( ) ;
497
+
412
498
for row in 0 ..rows {
413
- let mut found = -1 ;
414
- let row_slice = & view. row ( row) ;
415
- let ptr = row_slice. as_ptr ( ) as * const u8 ;
416
- let len = row_slice. len ( ) ;
499
+
500
+ let ptr = unsafe { base_ptr. add ( row * row_len) } ;
501
+
502
+ // let mut found = -1;
503
+ // let row_slice = &view.row(row);
504
+ // let ptr = row_slice.as_ptr() as *const u8;
505
+ // let len = row_slice.len();
417
506
418
507
if forward {
419
508
// Forward search
420
509
let mut i = 0 ;
421
510
unsafe {
422
- while i + LANES <= len {
511
+ while i + LANES <= row_len {
423
512
let chunk = & * ( ptr. add ( i) as * const [ u8 ; LANES ] ) ;
424
513
let vec = u8x32:: from ( * chunk) ;
425
514
if vec. cmp_eq ( ones) . any ( ) {
426
515
break ;
427
516
}
428
517
i += LANES ;
429
518
}
430
- while i < len {
519
+ while i < row_len {
431
520
if * ptr. add ( i) != 0 {
432
- found = i as isize ;
521
+ // found = i as isize;
522
+ result[ row] = i as isize ;
433
523
break ;
434
524
}
435
525
i += 1 ;
436
526
}
437
527
}
438
528
} else {
439
529
// Backward search
440
- let mut i = len ;
530
+ let mut i = row_len ;
441
531
unsafe {
442
532
// Process LANES bytes at a time with SIMD (backwards)
443
533
while i >= LANES {
@@ -448,25 +538,26 @@ pub fn first_true_2d<'py>(
448
538
// Found a true in this chunk, search backwards within it
449
539
for j in ( i..i + LANES ) . rev ( ) {
450
540
if * ptr. add ( j) != 0 {
451
- found = j as isize ;
541
+ // found = j as isize;
542
+ result[ row] = j as isize ;
452
543
break ;
453
544
}
454
545
}
455
546
break ;
456
547
}
457
548
}
458
549
// Handle remaining bytes at the beginning
459
- if found == - 1 && i > 0 {
550
+ if i > 0 && i < LANES {
460
551
for j in ( 0 ..i) . rev ( ) {
461
552
if * ptr. add ( j) != 0 {
462
- found = j as isize ;
553
+ // found = j as isize;
554
+ result[ row] = j as isize ;
463
555
break ;
464
556
}
465
557
}
466
558
}
467
559
}
468
560
}
469
- result. push ( found) ;
470
561
}
471
562
} ) ;
472
563
0 commit comments