Skip to content

Fix the crash when sending with the AmProto::Eager protocol #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 29, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 135 additions & 113 deletions src/ucp/endpoint/am.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl<'a> AmMsg<'a> {
AmMsg { worker, msg }
}

/// Get the message ID.
/// Get the ActiveStream id
#[inline]
pub fn id(&self) -> u16 {
self.msg.id
Expand All @@ -121,10 +121,10 @@ impl<'a> AmMsg<'a> {
self.msg.header.as_ref()
}

/// Get the message header length.
/// Returns `true` if the message contains data. Otherwise, `false`.
#[inline]
pub fn contains_data(&self) -> bool {
self.data_type().is_some()
self.msg.data.is_some()
}

/// Get the message data type.
Expand All @@ -133,10 +133,14 @@ impl<'a> AmMsg<'a> {
}

/// Get the message data.
/// Returns `None` if the message doesn't contain data.
/// Returns `None` if the data still needs to be received.
/// Returns `Some(slice)` if the message contains concrete data.
#[inline]
pub fn get_data(&self) -> Option<&[u8]> {
self.msg.data.as_ref().and_then(|data| data.data())
match self.msg.data {
Some(ref amdata) => amdata.data(),
None => Some(&[]),
}
}

/// Get the message data length.
Expand All @@ -151,6 +155,11 @@ impl<'a> AmMsg<'a> {
match self.msg.data.take() {
None => Ok(Vec::new()),
Some(AmData::Eager(vec)) => Ok(vec),
Some(AmData::Data(data)) => {
let v = data.to_vec();
self.drop_msg(AmData::Data(data));
Ok(v)
}
Some(data) => {
self.msg.data = Some(data);
let mut buf = Vec::with_capacity(self.data_len());
Expand Down Expand Up @@ -181,104 +190,110 @@ impl<'a> AmMsg<'a> {

/// Receive the message data.
pub async fn recv_data_vectored(&mut self, iov: &[IoSliceMut<'_>]) -> Result<usize, Error> {
let data = self.msg.data.take();
if let Some(data) = data {
if let AmData::Eager(data) = data {
// return error if buffer size < data length, same with ucx
let cap = iov.iter().fold(0_usize, |cap, buf| cap + buf.len());
if cap < data.len() {
return Err(Error::MessageTruncated);
}
fn copy_data_to_iov(data: &[u8], iov: &[IoSliceMut<'_>]) -> Result<usize, Error> {
// return error if buffer size < data length, same with ucx
let cap = iov.iter().fold(0_usize, |cap, buf| cap + buf.len());
if cap < data.len() {
return Err(Error::MessageTruncated);
}

let mut copied = 0_usize;
for buf in iov {
let len = std::cmp::min(data.len() - copied, buf.len());
if len == 0 {
break;
}
let mut copied = 0_usize;
for buf in iov {
let len = std::cmp::min(data.len() - copied, buf.len());
if len == 0 {
break;
}

let buf = &buf[..len];
unsafe {
std::ptr::copy_nonoverlapping(
data[copied..].as_ptr(),
buf.as_ptr() as _,
len,
)
}
copied += len;
let buf = &buf[..len];
unsafe {
std::ptr::copy_nonoverlapping(data[copied..].as_ptr(), buf.as_ptr() as _, len)
}
return Ok(copied);
copied += len;
}
Ok(copied)
}
let data = self.msg.data.take();

let (data_desc, data_len) = match data {
AmData::Data(data) => (data.as_ptr(), data.len()),
AmData::Rndv(data) => (data.as_ptr(), data.len()),
_ => unreachable!(),
};

unsafe extern "C" fn callback(
request: *mut c_void,
status: ucs_status_t,
_length: usize,
_data: *mut c_void,
) {
// todo: handle error & fix real data length
match data {
Some(AmData::Eager(data)) => {
// eager message, no need to receive
copy_data_to_iov(&data, iov)
}
Some(AmData::Data(data)) => {
// data message, no need to receive
let size = copy_data_to_iov(&data, iov)?;
self.drop_msg(AmData::Data(data));
Ok(size)
}
Some(AmData::Rndv(desc)) => {
// rndv message, need to receive
let (data_desc, data_len) = (desc.as_ptr(), desc.len());

unsafe extern "C" fn callback(
request: *mut c_void,
status: ucs_status_t,
_length: usize,
_data: *mut c_void,
) {
// todo: handle error & fix real data length
trace!(
"recv_data_vectored: complete, req={:?}, status={:?}",
request,
status
);
let request = &mut *(request as *mut Request);
request.waker.wake();
}
trace!(
"recv_data_vectored: complete, req={:?}, status={:?}",
request,
status
"recv_data_vectored: worker={:?} iov.len={}",
self.worker.handle,
iov.len()
);
let request = &mut *(request as *mut Request);
request.waker.wake();
}
trace!(
"recv_data_vectored: worker={:?} iov.len={}",
self.worker.handle,
iov.len()
);
let mut param = MaybeUninit::<ucp_request_param_t>::uninit();
let (buffer, count) = unsafe {
let param = &mut *param.as_mut_ptr();
param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32
| ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32;
param.cb = ucp_request_param_t__bindgen_ty_1 {
recv_am: Some(callback),
let mut param = MaybeUninit::<ucp_request_param_t>::uninit();
let (buffer, count) = unsafe {
let param = &mut *param.as_mut_ptr();
param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32
| ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32;
param.cb = ucp_request_param_t__bindgen_ty_1 {
recv_am: Some(callback),
};

if iov.len() == 1 {
param.datatype = ucp_dt_make_contig(1);
(iov[0].as_ptr(), iov[0].len())
} else {
param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _;
(iov.as_ptr() as _, iov.len())
}
};

if iov.len() == 1 {
param.datatype = ucp_dt_make_contig(1);
(iov[0].as_ptr(), iov[0].len())
let status = unsafe {
ucp_am_recv_data_nbx(
self.worker.handle,
data_desc as _,
buffer as _,
count as _,
param.as_ptr(),
)
};
if status.is_null() {
trace!("recv_data_vectored: complete");
Ok(data_len)
} else if UCS_PTR_IS_PTR(status) {
RequestHandle {
ptr: status,
poll_fn: poll_recv,
}
.await;
Ok(data_len)
} else {
param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _;
(iov.as_ptr() as _, iov.len())
Err(Error::from_ptr(status).unwrap_err())
}
};

let status = unsafe {
ucp_am_recv_data_nbx(
self.worker.handle,
data_desc as _,
buffer as _,
count as _,
param.as_ptr(),
)
};
if status.is_null() {
trace!("recv_data_vectored: complete");
Ok(data_len)
} else if UCS_PTR_IS_PTR(status) {
RequestHandle {
ptr: status,
poll_fn: poll_recv,
}
.await;
Ok(data_len)
} else {
Err(Error::from_ptr(status).unwrap_err())
}
} else {
// no data
Ok(0)
None => {
// no data
Ok(0)
}
}
}

Expand Down Expand Up @@ -321,18 +336,24 @@ impl<'a> AmMsg<'a> {
assert!(self.need_reply());
am_send(self.msg.reply_ep, id, header, data, need_reply, proto).await
}

/// Dispose of a message payload that has been taken out of the message.
///
/// `AmData::Eager` payloads need no release call and are simply dropped.
/// For `AmData::Data` and `AmData::Rndv`, the payload pointer is passed to
/// `ucp_am_data_release` together with this message's worker handle.
fn drop_msg(&mut self, data: AmData) {
    let desc = match data {
        // Nothing is handed back to UCX for an eager payload.
        AmData::Eager(_) => return,
        AmData::Data(payload) => payload.as_ptr(),
        AmData::Rndv(payload) => payload.as_ptr(),
    };
    // SAFETY: NOTE(review) — assumes `desc` is the data descriptor UCX handed
    // to the receive callback for this worker and that it has not been
    // released yet; confirm against the code that constructs `AmData`.
    unsafe {
        ucp_am_data_release(self.worker.handle, desc as _);
    }
}
}

impl<'a> Drop for AmMsg<'a> {
fn drop(&mut self) {
match self.msg.data.take() {
Some(AmData::Data(desc)) => unsafe {
ucp_am_data_release(self.worker.handle, desc.as_ptr() as _);
},
Some(AmData::Rndv(desc)) => unsafe {
ucp_am_data_release(self.worker.handle, desc.as_ptr() as _);
},
_ => (),
if let Some(data) = self.msg.data.take() {
self.drop_msg(data);
}
}
}
Expand Down Expand Up @@ -502,6 +523,8 @@ impl Endpoint {
}

/// Active message protocol
#[derive(Debug, Clone, Copy)]
#[repr(u32)]
pub enum AmProto {
/// Eager protocol
Eager,
Expand Down Expand Up @@ -594,12 +617,20 @@ mod tests {

#[test_log::test]
fn am() {
for i in 0..20_usize {
spawn_thread!(send_recv(4 << i)).join().unwrap();
let protos = vec![None, Some(AmProto::Eager), Some(AmProto::Rndv)];
for block_size_shift in 0..20_usize {
for p in protos.iter() {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_time()
.build()
.unwrap();
let local = tokio::task::LocalSet::new();
local.block_on(&rt, send_recv(4 << block_size_shift, *p));
}
}
}

async fn send_recv(data_size: usize) {
async fn send_recv(data_size: usize, proto: Option<AmProto>) {
let context1 = Context::new().unwrap();
let worker1 = context1.create_worker().unwrap();
let context2 = Context::new().unwrap();
Expand Down Expand Up @@ -631,13 +662,7 @@ mod tests {
async {
// send msg
let result = endpoint2
.am_send(
16,
header.as_slice(),
data.as_slice(),
true,
Some(AmProto::Rndv),
)
.am_send(16, header.as_slice(), data.as_slice(), true, proto)
.await;
assert!(result.is_ok());
},
Expand All @@ -662,10 +687,7 @@ mod tests {
tokio::join!(
async {
// send reply
let result = unsafe {
msg.reply(12, &header, &data, false, Some(AmProto::Rndv))
.await
};
let result = unsafe { msg.reply(12, &header, &data, false, proto).await };
assert!(result.is_ok());
},
async {
Expand Down
Loading