Skip to content

Commit 8ad546e

Browse files
committed
progress on first_true_2d
1 parent c643946 commit 8ad546e

File tree

1 file changed

+135
-44
lines changed

1 file changed

+135
-44
lines changed

src/lib.rs

Lines changed: 135 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use numpy::PyArrayMethods;
1212
use numpy::PyUntypedArrayMethods;
1313
use numpy::ToPyArray;
1414
use numpy::{PyArray2, PyReadonlyArray2};
15+
use numpy::IntoPyArray;
1516

1617
#[pyfunction]
1718
fn first_true_1d_a(array: PyReadonlyArray1<bool>) -> isize {
@@ -346,49 +347,127 @@ fn first_true_1d(py: Python, array: PyReadonlyArray1<bool>, forward: bool) -> is
346347
// axis == 0: transpose, copy to C
347348
// axis == 1: copy to C
348349

349-
fn prepare_array_for_axis<'py>(
350+
// fn prepare_array_for_axis<'py>(
351+
// py: Python<'py>,
352+
// array: PyReadonlyArray2<'py, bool>,
353+
// axis: isize,
354+
// ) -> PyResult<Bound<'py, PyArray2<bool>>> {
355+
// if axis != 0 && axis != 1 {
356+
// return Err(PyValueError::new_err("axis must be 0 or 1"));
357+
// }
358+
359+
// let is_c = array.is_c_contiguous();
360+
// let is_f = array.is_fortran_contiguous();
361+
// let array_view = array.as_array();
362+
363+
// match (is_c, is_f, axis) {
364+
// (true, _, 1) => {
365+
// // Already C-contiguous, no copy needed
366+
// Ok(array_view.to_pyarray(py).to_owned())
367+
// }
368+
// (_, true, 0) => {
369+
// // F-contiguous original -> transposed will be C-contiguous, no copy needed
370+
// Ok(array_view.reversed_axes().to_pyarray(py).to_owned())
371+
// }
372+
// (_, true, 1) => {
373+
// // F-contiguous, need to copy to C-contiguous
374+
// let contiguous = array_view.as_standard_layout();
375+
// Ok(contiguous.to_pyarray(py).to_owned())
376+
// }
377+
// (_, _, 1) => {
378+
// // Neither C nor F contiguous, need to copy
379+
// let contiguous = array_view.as_standard_layout();
380+
// Ok(contiguous.to_pyarray(py).to_owned())
381+
// }
382+
383+
// (true, _, 0) | (_, _, 0) => {
384+
// // C-contiguous or neither -> transposed won't be C-contiguous, need copy
385+
// let transposed = array_view.reversed_axes();
386+
// let contiguous = transposed.as_standard_layout();
387+
// Ok(contiguous.to_pyarray(py).to_owned())
388+
// }
389+
// _ => unreachable!(),
390+
// }
391+
// }
392+
393+
394+
// use numpy::{PyReadonlyArray2, IntoPyArray, PyArray2};
395+
// use pyo3::prelude::*;
396+
397+
pub struct PreparedBool2D<'py> {
398+
pub data: &'py [u8], // flat contiguous buffer
399+
pub nrows: usize, // number of logical rows
400+
pub ncols: usize,
401+
_keepalive: Option<Bound<'py, PyAny>>, // holds any copied/transposed buffer
402+
}
403+
404+
pub fn prepare_array_for_axis<'py>(
350405
py: Python<'py>,
351406
array: PyReadonlyArray2<'py, bool>,
352407
axis: isize,
353-
) -> PyResult<Bound<'py, PyArray2<bool>>> {
408+
) -> PyResult<PreparedBool2D<'py>> {
354409
if axis != 0 && axis != 1 {
355410
return Err(PyValueError::new_err("axis must be 0 or 1"));
356411
}
357412

413+
let shape = array.shape();
414+
let (nrows, ncols) = if axis == 0 {
415+
(shape[1], shape[0]) // transposed
416+
} else {
417+
(shape[0], shape[1]) // as-is
418+
};
419+
358420
let is_c = array.is_c_contiguous();
359421
let is_f = array.is_fortran_contiguous();
360422
let array_view = array.as_array();
361423

362-
match (is_c, is_f, axis) {
363-
(true, _, 1) => {
364-
// Already C-contiguous, no copy needed
365-
Ok(array_view.to_pyarray(py).to_owned())
366-
}
367-
(_, true, 0) => {
368-
// F-contiguous original -> transposed will be C-contiguous, no copy needed
369-
Ok(array_view.reversed_axes().to_pyarray(py).to_owned())
370-
}
371-
(_, true, 1) => {
372-
// F-contiguous, need to copy to C-contiguous
373-
let contiguous = array_view.as_standard_layout();
374-
Ok(contiguous.to_pyarray(py).to_owned())
375-
}
376-
(_, _, 1) => {
377-
// Neither C nor F contiguous, need to copy
378-
let contiguous = array_view.as_standard_layout();
379-
Ok(contiguous.to_pyarray(py).to_owned())
424+
// Case 1: C-contiguous + axis=1 → zero-copy slice
425+
if is_c && axis == 1 {
426+
if let Ok(slice) = array.as_slice() {
427+
return Ok(PreparedBool2D {
428+
data: unsafe { std::mem::transmute(slice) }, // &[bool] → &[u8]
429+
nrows,
430+
ncols,
431+
_keepalive: None,
432+
});
380433
}
434+
}
381435

382-
(true, _, 0) | (_, _, 0) => {
383-
// C-contiguous or neither -> transposed won't be C-contiguous, need copy
384-
let transposed = array_view.reversed_axes();
385-
let contiguous = transposed.as_standard_layout();
386-
Ok(contiguous.to_pyarray(py).to_owned())
436+
// Case 2: F-contiguous + axis=0 → transpose, check if sliceable
437+
if is_f && axis == 0 {
438+
let transposed = array_view.reversed_axes();
439+
if let Some(slice) = transposed.as_standard_layout().as_slice_memory_order() {
440+
return Ok(PreparedBool2D {
441+
data: unsafe { std::mem::transmute(slice) },
442+
nrows,
443+
ncols,
444+
_keepalive: None,
445+
});
387446
}
388-
_ => unreachable!(),
389447
}
448+
449+
// Case 3: fallback — create a new C-contiguous owned array
450+
let prepared_array: Bound<'py, PyArray2<bool>> = if axis == 0 {
451+
array_view.reversed_axes().as_standard_layout().to_owned().to_pyarray(py)
452+
} else {
453+
array_view.as_standard_layout().to_owned().to_pyarray(py)
454+
};
455+
456+
let array_view = unsafe { prepared_array.as_array() };
457+
let prepared_slice = array_view
458+
.as_slice_memory_order()
459+
.expect("Newly allocated array must be contiguous");
460+
461+
Ok(PreparedBool2D {
462+
data: unsafe { std::mem::transmute(prepared_slice) },
463+
nrows,
464+
ncols,
465+
_keepalive: Some(prepared_array.into_any()),
466+
})
390467
}
391468

469+
470+
392471
#[pyfunction]
393472
#[pyo3(signature = (array, *, forward=true, axis))]
394473
pub fn first_true_2d<'py>(
@@ -397,47 +476,58 @@ pub fn first_true_2d<'py>(
397476
forward: bool,
398477
axis: isize,
399478
) -> PyResult<Bound<'py, PyArray1<isize>>> {
400-
let prepped = prepare_array_for_axis(py, array, axis)?;
401-
let view = unsafe { prepped.as_array() };
402479

403-
// let view = array.as_array();
404-
// NOTE: these are rows in the view, not always the same as rows
405-
let rows = view.nrows();
406-
let mut result = Vec::with_capacity(rows);
480+
// let prepped = prepare_array_for_axis(py, array, axis)?;
481+
// let view = unsafe { prepped.as_array() };
482+
// // NOTE: these are rows in the view, not always the same as rows
483+
// let rows = view.nrows();
484+
485+
let prepared = prepare_array_for_axis(py, array, axis)?;
486+
let data = prepared.data;
487+
let rows = prepared.nrows;
488+
let row_len = prepared.ncols;
489+
490+
let mut result = vec![-1isize; rows];
407491

408492
py.allow_threads(|| {
409493
const LANES: usize = 32;
410494
let ones = u8x32::splat(1);
411495

496+
let base_ptr = data.as_ptr();
497+
412498
for row in 0..rows {
413-
let mut found = -1;
414-
let row_slice = &view.row(row);
415-
let ptr = row_slice.as_ptr() as *const u8;
416-
let len = row_slice.len();
499+
500+
let ptr = unsafe { base_ptr.add(row * row_len) };
501+
502+
// let mut found = -1;
503+
// let row_slice = &view.row(row);
504+
// let ptr = row_slice.as_ptr() as *const u8;
505+
// let len = row_slice.len();
417506

418507
if forward {
419508
// Forward search
420509
let mut i = 0;
421510
unsafe {
422-
while i + LANES <= len {
511+
while i + LANES <= row_len {
423512
let chunk = &*(ptr.add(i) as *const [u8; LANES]);
424513
let vec = u8x32::from(*chunk);
425514
if vec.cmp_eq(ones).any() {
426515
break;
427516
}
428517
i += LANES;
429518
}
430-
while i < len {
519+
while i < row_len {
431520
if *ptr.add(i) != 0 {
432-
found = i as isize;
521+
// found = i as isize;
522+
result[row] = i as isize;
433523
break;
434524
}
435525
i += 1;
436526
}
437527
}
438528
} else {
439529
// Backward search
440-
let mut i = len;
530+
let mut i = row_len;
441531
unsafe {
442532
// Process LANES bytes at a time with SIMD (backwards)
443533
while i >= LANES {
@@ -448,25 +538,26 @@ pub fn first_true_2d<'py>(
448538
// Found a true in this chunk, search backwards within it
449539
for j in (i..i + LANES).rev() {
450540
if *ptr.add(j) != 0 {
451-
found = j as isize;
541+
// found = j as isize;
542+
result[row] = j as isize;
452543
break;
453544
}
454545
}
455546
break;
456547
}
457548
}
458549
// Handle remaining bytes at the beginning
459-
if found == -1 && i > 0 {
550+
if i > 0 && i < LANES {
460551
for j in (0..i).rev() {
461552
if *ptr.add(j) != 0 {
462-
found = j as isize;
553+
// found = j as isize;
554+
result[row] = j as isize;
463555
break;
464556
}
465557
}
466558
}
467559
}
468560
}
469-
result.push(found);
470561
}
471562
});
472563

0 commit comments

Comments
 (0)