Skip to content

Commit 2037595

Browse files
committed
progress on multithreaded implementation
1 parent 59295ee commit 2037595

File tree

3 files changed

+165
-1
lines changed

3 files changed

+165
-1
lines changed

Cargo.lock

Lines changed: 52 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ crate-type = ["cdylib"]
1212
pyo3 = "0.25.0"
1313
numpy = "0.25.0"
1414
wide = "0.7.33"
15+
rayon = "1.10.0"
1516

1617

1718
[profile.release]

src/lib.rs

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ use numpy::ToPyArray;
1414
use numpy::{PyArray2, PyReadonlyArray2};
1515
use numpy::IntoPyArray;
1616

17+
use rayon::prelude::*;
18+
use rayon::ThreadPoolBuilder;
19+
1720
#[pyfunction]
1821
fn first_true_1d_a(array: PyReadonlyArray1<bool>) -> isize {
1922
match array.as_slice() {
@@ -480,7 +483,7 @@ pub fn prepare_array_for_axis<'py>(
480483

481484
#[pyfunction]
482485
#[pyo3(signature = (array, *, forward=true, axis))]
483-
pub fn first_true_2d<'py>(
486+
pub fn first_true_2d_a<'py>(
484487
py: Python<'py>,
485488
array: PyReadonlyArray2<'py, bool>,
486489
forward: bool,
@@ -558,6 +561,114 @@ pub fn first_true_2d<'py>(
558561
Ok(PyArray1::from_vec(py, result).to_owned())
559562
}
560563

564+
#[pyfunction]
565+
#[pyo3(signature = (array, *, forward=true, axis))]
566+
pub fn first_true_2d<'py>(
567+
py: Python<'py>,
568+
array: PyReadonlyArray2<'py, bool>,
569+
forward: bool,
570+
axis: isize,
571+
) -> PyResult<Bound<'py, PyArray1<isize>>> {
572+
let prepared = prepare_array_for_axis(py, array, axis)?;
573+
let data = prepared.data;
574+
let rows = prepared.nrows;
575+
let row_len = prepared.ncols;
576+
577+
let mut result = vec![-1isize; rows];
578+
579+
// Dynamically select thread count
580+
let max_threads = if rows < 100 {
581+
1
582+
} else if rows < 1000 {
583+
2
584+
} else if rows < 10000 {
585+
4
586+
} else {
587+
16
588+
};
589+
590+
py.allow_threads(|| {
591+
let base_ptr = data.as_ptr() as usize;
592+
const LANES: usize = 32;
593+
let ones = u8x32::splat(1);
594+
595+
let process_row = |row: usize| -> isize {
596+
let ptr = (base_ptr + row * row_len) as *const u8;
597+
let mut found = -1isize;
598+
599+
unsafe {
600+
if forward {
601+
let mut i = 0;
602+
while i + LANES <= row_len {
603+
let chunk = &*(ptr.add(i) as *const [u8; LANES]);
604+
let vec = u8x32::from(*chunk);
605+
if vec.cmp_eq(ones).any() {
606+
break;
607+
}
608+
i += LANES;
609+
}
610+
while i < row_len {
611+
if *ptr.add(i) != 0 {
612+
found = i as isize;
613+
break;
614+
}
615+
i += 1;
616+
}
617+
} else {
618+
let mut i = row_len;
619+
while i >= LANES {
620+
i -= LANES;
621+
let chunk = &*(ptr.add(i) as *const [u8; LANES]);
622+
let vec = u8x32::from(*chunk);
623+
if vec.cmp_eq(ones).any() {
624+
for j in (i..i + LANES).rev() {
625+
if *ptr.add(j) != 0 {
626+
found = j as isize;
627+
break;
628+
}
629+
}
630+
break;
631+
}
632+
}
633+
if i > 0 && i < LANES {
634+
for j in (0..i).rev() {
635+
if *ptr.add(j) != 0 {
636+
found = j as isize;
637+
break;
638+
}
639+
}
640+
}
641+
}
642+
}
643+
644+
found
645+
};
646+
647+
if max_threads == 1 {
648+
// Single-threaded path
649+
for row in 0..rows {
650+
result[row] = process_row(row);
651+
}
652+
} else {
653+
// Multi-threaded path with Rayon
654+
let pool = rayon::ThreadPoolBuilder::new()
655+
.num_threads(max_threads)
656+
.build()
657+
.unwrap();
658+
659+
pool.install(|| {
660+
result.par_iter_mut().enumerate().for_each(|(row, out)| {
661+
*out = process_row(row);
662+
});
663+
});
664+
}
665+
});
666+
667+
Ok(PyArray1::from_vec(py, result))
668+
}
669+
670+
671+
561672
//------------------------------------------------------------------------------
562673

563674
#[pymodule]

0 commit comments

Comments
 (0)