Skip to content

Commit d8e28ae

Browse files
committed
updated implementation of prepare_array_for_axis
1 parent 2037595 commit d8e28ae

File tree

1 file changed

+77
-87
lines changed

1 file changed

+77
-87
lines changed

src/lib.rs

Lines changed: 77 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@ use wide::*;
77
// use std::simd::Simd;
88
// use std::simd::cmp::SimdPartialEq;
99

10+
use numpy::ndarray::{Array2, ArrayView2};
11+
use numpy::IntoPyArray;
1012
use numpy::PyArray1;
1113
use numpy::PyArrayMethods;
1214
use numpy::PyUntypedArrayMethods;
1315
use numpy::ToPyArray;
1416
use numpy::{PyArray2, PyReadonlyArray2};
15-
use numpy::IntoPyArray;
1617

1718
use rayon::prelude::*;
1819
use rayon::ThreadPoolBuilder;
20+
use std::sync::Arc;
1921

2022
#[pyfunction]
2123
fn first_true_1d_a(array: PyReadonlyArray1<bool>) -> isize {
@@ -393,32 +395,17 @@ fn first_true_1d(py: Python, array: PyReadonlyArray1<bool>, forward: bool) -> is
393395
// }
394396
// }
395397

396-
397-
// use numpy::{PyReadonlyArray2, IntoPyArray, PyArray2};
398-
// use pyo3::prelude::*;
399-
400398
pub struct PreparedBool2D<'py> {
401-
pub data: &'py [u8], // flat contiguous buffer
402-
pub nrows: usize, // number of logical rows
399+
pub data: &'py [u8], // contiguous byte slice (bool as u8)
400+
pub nrows: usize,
403401
pub ncols: usize,
404-
_keepalive: Option<Bound<'py, PyAny>>, // holds any copied/transposed buffer
402+
_keepalive: Option<Arc<Array2<bool>>>, // holds owned data if needed
405403
}
406404

407405
pub fn prepare_array_for_axis<'py>(
408-
py: Python<'py>,
409406
array: PyReadonlyArray2<'py, bool>,
410407
axis: isize,
411408
) -> PyResult<PreparedBool2D<'py>> {
412-
413-
// let shape = array.shape();
414-
// let slice = array.as_slice().unwrap();
415-
// return Ok(PreparedBool2D {
416-
// data: unsafe { std::mem::transmute(slice) }, // &[bool] → &[u8]
417-
// nrows: shape[0],
418-
// ncols: shape[1],
419-
// _keepalive: None,
420-
// });
421-
422409
if axis != 0 && axis != 1 {
423410
return Err(PyValueError::new_err("axis must be 0 or 1"));
424411
}
@@ -459,117 +446,122 @@ pub fn prepare_array_for_axis<'py>(
459446
}
460447
}
461448

462-
// Case 3: fallback — create a new C-contiguous owned array
463-
let prepared_array: Bound<'py, PyArray2<bool>> = if axis == 0 {
464-
array_view.reversed_axes().as_standard_layout().to_owned().to_pyarray(py)
449+
// Case 3: fallback — make ndarray owned copy, but no PyArray!
450+
let array_owned: Array2<bool> = if axis == 0 {
451+
array_view.reversed_axes().as_standard_layout().to_owned()
465452
} else {
466-
array_view.as_standard_layout().to_owned().to_pyarray(py)
453+
array_view.as_standard_layout().to_owned()
467454
};
468455

469-
let array_view = unsafe { prepared_array.as_array() };
470-
let prepared_slice = array_view
471-
.as_slice_memory_order()
472-
.expect("Newly allocated array must be contiguous");
456+
let slice = array_owned
457+
.as_slice_memory_order()
458+
.expect("newly allocated Array2 must be contiguous");
473459

474460
Ok(PreparedBool2D {
475-
data: unsafe { std::mem::transmute(prepared_slice) },
461+
data: unsafe { std::mem::transmute(slice) },
476462
nrows,
477463
ncols,
478-
_keepalive: Some(prepared_array.into_any()),
464+
_keepalive: Some(Arc::new(array_owned)),
479465
})
480466
}
481467

482-
483-
484468
#[pyfunction]
485469
#[pyo3(signature = (array, *, forward=true, axis))]
486-
pub fn first_true_2d_a<'py>(
470+
pub fn first_true_2d<'py>(
487471
py: Python<'py>,
488472
array: PyReadonlyArray2<'py, bool>,
489473
forward: bool,
490474
axis: isize,
491475
) -> PyResult<Bound<'py, PyArray1<isize>>> {
492-
493-
let prepared = prepare_array_for_axis(py, array, axis)?;
476+
let prepared = prepare_array_for_axis(array, axis)?;
494477
let data = prepared.data;
495478
let rows = prepared.nrows;
496479
let row_len = prepared.ncols;
497480

498-
let mut result = vec![-1isize; rows];
481+
let pyarray = unsafe { PyArray1::<isize>::new(py, [rows], false) };
482+
let result = unsafe { pyarray.as_slice_mut().unwrap() };
483+
result.fill(-1);
499484

500-
py.allow_threads(|| {
501-
const LANES: usize = 32;
502-
let ones = u8x32::splat(1);
503-
let base_ptr = data.as_ptr();
485+
// let mut result = vec![-1isize; rows];
486+
487+
// py.allow_threads(|| {
488+
const LANES: usize = 32;
489+
let ones = u8x32::splat(1);
490+
let base_ptr = data.as_ptr();
491+
let mut i;
504492

493+
if forward {
505494
for row in 0..rows {
506495
let ptr = unsafe { base_ptr.add(row * row_len) };
507-
if forward {
508-
// Forward search
509-
let mut i = 0;
510-
unsafe {
511-
while i + LANES <= row_len {
512-
let chunk = &*(ptr.add(i) as *const [u8; LANES]);
513-
let vec = u8x32::from(*chunk);
514-
if vec.cmp_eq(ones).any() {
515-
break;
516-
}
517-
i += LANES;
496+
i = 0;
497+
unsafe {
498+
while i + LANES <= row_len {
499+
let chunk = &*(ptr.add(i) as *const [u8; LANES]);
500+
let vec = u8x32::from(*chunk);
501+
if vec.cmp_eq(ones).any() {
502+
break;
518503
}
519-
while i < row_len {
520-
if *ptr.add(i) != 0 {
521-
result[row] = i as isize;
522-
break;
523-
}
524-
i += 1;
504+
i += LANES;
505+
}
506+
while i < row_len {
507+
if *ptr.add(i) != 0 {
508+
result[row] = i as isize;
509+
break;
525510
}
511+
i += 1;
526512
}
527-
} else {
528-
// Backward search
529-
let mut i = row_len;
530-
unsafe {
531-
// Process LANES bytes at a time with SIMD (backwards)
532-
while i >= LANES {
533-
i -= LANES;
513+
}
514+
}
515+
} else {
516+
// Backward search
517+
for row in 0..rows {
518+
let ptr = unsafe { base_ptr.add(row * row_len) };
534519

535-
let chunk = &*(ptr.add(i) as *const [u8; LANES]);
536-
let vec = u8x32::from(*chunk);
537-
if vec.cmp_eq(ones).any() {
538-
// Found a true in this chunk, search backwards within it
539-
for j in (i..i + LANES).rev() {
540-
if *ptr.add(j) != 0 {
541-
result[row] = j as isize;
542-
break;
543-
}
544-
}
545-
break;
546-
}
547-
}
548-
// Handle remaining bytes at the beginning
549-
if i > 0 && i < LANES {
550-
for j in (0..i).rev() {
520+
i = row_len;
521+
unsafe {
522+
// Process LANES bytes at a time with SIMD (backwards)
523+
while i >= LANES {
524+
i -= LANES;
525+
526+
let chunk = &*(ptr.add(i) as *const [u8; LANES]);
527+
let vec = u8x32::from(*chunk);
528+
if vec.cmp_eq(ones).any() {
529+
// Found a true in this chunk, search backwards within it
530+
for j in (i..i + LANES).rev() {
551531
if *ptr.add(j) != 0 {
552532
result[row] = j as isize;
553533
break;
554534
}
555535
}
536+
break;
537+
}
538+
}
539+
// Handle remaining bytes at the beginning
540+
if i > 0 && i < LANES {
541+
for j in (0..i).rev() {
542+
if *ptr.add(j) != 0 {
543+
result[row] = j as isize;
544+
break;
545+
}
556546
}
557547
}
558548
}
559549
}
560-
});
561-
Ok(PyArray1::from_vec(py, result).to_owned())
550+
}
551+
// });
552+
// Ok(PyArray1::from_vec(py, result).to_owned())
553+
Ok(pyarray)
562554
}
563555

564556
#[pyfunction]
565557
#[pyo3(signature = (array, *, forward=true, axis))]
566-
pub fn first_true_2d<'py>(
558+
pub fn first_true_2d_b<'py>(
567559
py: Python<'py>,
568560
array: PyReadonlyArray2<'py, bool>,
569561
forward: bool,
570562
axis: isize,
571563
) -> PyResult<Bound<'py, PyArray1<isize>>> {
572-
let prepared = prepare_array_for_axis(py, array, axis)?;
564+
let prepared = prepare_array_for_axis(array, axis)?;
573565
let data = prepared.data;
574566
let rows = prepared.nrows;
575567
let row_len = prepared.ncols;
@@ -580,9 +572,9 @@ pub fn first_true_2d<'py>(
580572
let max_threads = if rows < 100 {
581573
1
582574
} else if rows < 1000 {
583-
2
575+
1
584576
} else if rows < 10000 {
585-
4
577+
1
586578
} else {
587579
16
588580
};
@@ -667,8 +659,6 @@ pub fn first_true_2d<'py>(
667659
Ok(PyArray1::from_vec(py, result))
668660
}
669661

670-
671-
672662
//------------------------------------------------------------------------------
673663

674664
#[pymodule]

0 commit comments

Comments
 (0)