Skip to content

Commit 903a16a

Browse files
committed
progress on 2d version
1 parent a5c35d6 commit 903a16a

File tree

2 files changed

+131
-60
lines changed

2 files changed

+131
-60
lines changed

doc/articles/first_true_1d.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from arrayredox import first_true_1d_c as ar_first_true_1d_c
1616
from arrayredox import first_true_1d_d as ar_first_true_1d_d
1717
from arrayredox import first_true_1d_e as ar_first_true_1d_e
18-
from arrayredox import first_true_1d_f as ar_first_true_1d_f
18+
from arrayredox import first_true_1d as ar_first_true_1d_f
19+
# from arrayredox import first_true_1d_g as ar_first_true_1d_g
1920

2021
import matplotlib.pyplot as plt
2122
import numpy as np
@@ -69,25 +70,28 @@ class ARFirstTrueD(ArrayProcessor):
6970
SORT = 0
7071

7172
def __call__(self):
72-
# _ = ar_first_true_1d(self.array, forward=True)
7373
_ = ar_first_true_1d_d(self.array)
7474

7575
class ARFirstTrueE(ArrayProcessor):
7676
NAME = 'ar.first_true_1d_e()'
7777
SORT = 0
7878

7979
def __call__(self):
80-
# _ = ar_first_true_1d(self.array, forward=True)
8180
_ = ar_first_true_1d_e(self.array)
8281

8382
class ARFirstTrueF(ArrayProcessor):
8483
NAME = 'ar.first_true_1d_f()'
8584
SORT = 0
8685

8786
def __call__(self):
88-
# _ = ar_first_true_1d(self.array, forward=True)
8987
_ = ar_first_true_1d_f(self.array)
9088

89+
# class ARFirstTrueG(ArrayProcessor):
90+
# NAME = 'ar.first_true_1d_g()'
91+
# SORT = 0
92+
93+
# def __call__(self):
94+
# _ = ar_first_true_1d_g(self.array)
9195

9296

9397

@@ -311,6 +315,7 @@ def get_versions() -> str:
311315
# ARFirstTrueD,
312316
ARFirstTrueE,
313317
ARFirstTrueF,
318+
# ARFirstTrueG,
314319
NPArgMax,
315320
# NPNonZero,
316321
# NPNotAnyArgMax,

src/lib.rs

Lines changed: 122 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
1+
#![feature(portable_simd)]
2+
13
use numpy::PyReadonlyArray1;
24
use pyo3::prelude::*;
3-
use wide::*;
5+
use pyo3::exceptions::PyValueError;
6+
use pyo3::Bound;
47
// use pyo3::types::{PyBool, PyAny};
8+
use wide::*;
9+
// use std::simd::Simd;
10+
// use std::simd::cmp::SimdPartialEq;
11+
12+
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
13+
use numpy::ToPyArray;
14+
use numpy::PyArrayMethods;
15+
use numpy::PyUntypedArrayMethods;
16+
517

618
#[pyfunction]
719
fn first_true_1d_a(array: PyReadonlyArray1<bool>) -> isize {
@@ -190,51 +202,6 @@ fn first_true_1d_e(array: PyReadonlyArray1<bool>) -> isize {
190202
}
191203
}
192204

193-
#[pyfunction]
194-
fn first_true_1d_f(py: Python, array: PyReadonlyArray1<bool>) -> isize {
195-
if let Ok(slice) = array.as_slice() {
196-
py.allow_threads(|| {
197-
let len = slice.len();
198-
let ptr = slice.as_ptr() as *const u8;
199-
let mut i = 0;
200-
201-
let ones = u8x32::splat(1);
202-
unsafe {
203-
// Process 32 bytes at a time with SIMD
204-
while i + 32 <= len {
205-
// Cast pointer to array reference
206-
let bytes = &*(ptr.add(i) as *const [u8; 32]);
207-
208-
// Convert to SIMD vector
209-
let chunk = u8x32::from(*bytes);
210-
let equal_one = chunk.cmp_eq(ones);
211-
if equal_one.any() {
212-
break;
213-
}
214-
215-
i += 32;
216-
}
217-
// // Handle final remainder
218-
while i < len.min(i + 32) {
219-
if *ptr.add(i) != 0 {
220-
return i as isize;
221-
}
222-
i += 1;
223-
}
224-
-1
225-
}
226-
})
227-
} else {
228-
let array_view = array.as_array();
229-
py.allow_threads(|| {
230-
array_view
231-
.iter()
232-
.position(|&v| v)
233-
.map(|i| i as isize)
234-
.unwrap_or(-1)
235-
})
236-
}
237-
}
238205

239206
#[pyfunction]
240207
#[pyo3(signature = (array, forward=true))]
@@ -243,6 +210,8 @@ fn first_true_1d(py: Python,
243210
forward: bool,
244211
) -> isize {
245212
if let Ok(slice) = array.as_slice() {
213+
const LANES: usize = 32;
214+
246215
py.allow_threads(|| {
247216
let len = slice.len();
248217
let ptr = slice.as_ptr() as *const u8;
@@ -252,17 +221,17 @@ fn first_true_1d(py: Python,
252221
let mut i = 0;
253222
unsafe {
254223
// Process 32 bytes at a time with SIMD
255-
while i + 32 <= len {
256-
let bytes = &*(ptr.add(i) as *const [u8; 32]);
224+
while i + LANES <= len {
225+
let bytes = &*(ptr.add(i) as *const [u8; LANES]);
257226
let chunk = u8x32::from(*bytes);
258227
let equal_one = chunk.cmp_eq(ones);
259228
if equal_one.any() {
260229
break;
261230
}
262-
i += 32;
231+
i += LANES;
263232
}
264233
// Handle final remainder
265-
while i < len.min(i + 32) {
234+
while i < len.min(i + LANES) {
266235
if *ptr.add(i) != 0 {
267236
return i as isize;
268237
}
@@ -273,15 +242,15 @@ fn first_true_1d(py: Python,
273242
// Backward search
274243
let mut i = len;
275244
unsafe {
276-
// Process 32 bytes at a time with SIMD (backwards)
277-
while i >= 32 {
278-
i -= 32;
279-
let bytes = &*(ptr.add(i) as *const [u8; 32]);
245+
// Process LANES bytes at a time with SIMD (backwards)
246+
while i >= LANES {
247+
i -= LANES;
248+
let bytes = &*(ptr.add(i) as *const [u8; LANES]);
280249
let chunk = u8x32::from(*bytes);
281250
let equal_one = chunk.cmp_eq(ones);
282251
if equal_one.any() {
283252
// Found a true in this chunk, search backwards within it
284-
for j in (i..i + 32).rev() {
253+
for j in (i..i + LANES).rev() {
285254
if *ptr.add(j) != 0 {
286255
return j as isize;
287256
}
@@ -320,14 +289,111 @@ fn first_true_1d(py: Python,
320289
}
321290
}
322291

292+
293+
294+
// #[pyfunction]
295+
// fn first_true_1d_g(py: Python, array: PyReadonlyArray1<bool>) -> isize {
296+
// if let Ok(slice) = array.as_slice() {
297+
// py.allow_threads(|| {
298+
// let len = slice.len();
299+
// let ptr = slice.as_ptr() as *const u8;
300+
// let mut i = 0;
301+
302+
// type Lane = u8;
303+
// const LANES: usize = 64;
304+
// let ones = Simd::<Lane, LANES>::splat(1);
305+
306+
// unsafe {
307+
// while i + LANES <= len {
308+
// let chunk_ptr = ptr.add(i) as *const [u8; LANES];
309+
// let chunk = Simd::from(*chunk_ptr);
310+
// let mask = chunk.simd_eq(ones).to_bitmask();
311+
312+
// if mask != 0 {
313+
// let offset = mask.trailing_zeros() as usize;
314+
// return (i + offset) as isize;
315+
// }
316+
317+
// i += LANES;
318+
// }
319+
320+
// // Remainder (non-SIMD tail)
321+
// while i < len {
322+
// if *ptr.add(i) != 0 {
323+
// return i as isize;
324+
// }
325+
// i += 1;
326+
// }
327+
// }
328+
329+
// -1
330+
// })
331+
// } else {
332+
// // Fallback for non-contiguous arrays
333+
// let view = array.as_array();
334+
// py.allow_threads(|| {
335+
// view.iter()
336+
// .position(|&v| v)
337+
// .map(|i| i as isize)
338+
// .unwrap_or(-1)
339+
// })
340+
// }
341+
// }
342+
343+
344+
//------------------------------------------------------------------------------
345+
346+
347+
fn prepare_array_for_axis<'py>(
348+
py: Python<'py>,
349+
array: PyReadonlyArray2<'py, bool>,
350+
axis: usize,
351+
) -> PyResult<Bound<'py, PyArray2<bool>>> {
352+
if axis != 0 && axis != 1 {
353+
return Err(PyValueError::new_err("axis must be 0 or 1"));
354+
}
355+
356+
let is_c = array.is_c_contiguous();
357+
let is_f = array.is_fortran_contiguous();
358+
359+
match (is_c, is_f, axis) {
360+
(true, _, 0) => {
361+
let transposed = array.as_array().reversed_axes().to_owned();
362+
Ok(transposed.into_pyarray(py))
363+
}
364+
(true, _, 1) => Ok(array.as_array().to_owned().into_pyarray(py)), // copy to get full ownership
365+
(_, true, 0) => {
366+
let transposed = array.as_array().reversed_axes();
367+
Ok(transposed.to_owned().into_pyarray(py))
368+
}
369+
(_, true, 1) => {
370+
let owned = array.as_array().to_owned();
371+
Ok(owned.into_pyarray(py))
372+
}
373+
(false, false, 0) => {
374+
let transposed = array.as_array().reversed_axes().to_owned();
375+
Ok(transposed.into_pyarray(py))
376+
}
377+
(false, false, 1) => {
378+
let owned = array.as_array().to_owned();
379+
Ok(owned.into_pyarray(py))
380+
}
381+
_ => unreachable!(),
382+
}
383+
}
384+
385+
386+
//------------------------------------------------------------------------------
387+
388+
323389
#[pymodule]
324390
fn arrayredox(m: &Bound<'_, PyModule>) -> PyResult<()> {
325391
m.add_function(wrap_pyfunction!(first_true_1d_a, m)?)?;
326392
m.add_function(wrap_pyfunction!(first_true_1d_b, m)?)?;
327393
m.add_function(wrap_pyfunction!(first_true_1d_c, m)?)?;
328394
m.add_function(wrap_pyfunction!(first_true_1d_d, m)?)?;
329395
m.add_function(wrap_pyfunction!(first_true_1d_e, m)?)?;
330-
m.add_function(wrap_pyfunction!(first_true_1d_f, m)?)?;
396+
// m.add_function(wrap_pyfunction!(first_true_1d_g, m)?)?;
331397
m.add_function(wrap_pyfunction!(first_true_1d, m)?)?;
332398
Ok(())
333399
}

0 commit comments

Comments
 (0)