From f9e88e64db00a6a1cfd9a59ada4bf773e93e173f Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 27 Apr 2025 21:09:33 +0800 Subject: [PATCH 1/2] Manually drop the case AmData::Data According to the document of UCP_AM_RECV_ATTR_FLAG_DATA, it should call ucp_am_data_release after using AmData::Data. --- src/ucp/endpoint/am.rs | 248 ++++++++++++++++++++++------------------- 1 file changed, 135 insertions(+), 113 deletions(-) diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 3dc8285..82f91d1 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -109,7 +109,7 @@ impl<'a> AmMsg<'a> { AmMsg { worker, msg } } - /// Get the message ID. + /// Get the ActiveStream id #[inline] pub fn id(&self) -> u16 { self.msg.id @@ -121,10 +121,10 @@ impl<'a> AmMsg<'a> { self.msg.header.as_ref() } - /// Get the message header length. + /// Returns `true` if the message contains data. Otherwise, `false`. #[inline] pub fn contains_data(&self) -> bool { - self.data_type().is_some() + self.msg.data.is_some() } /// Get the message data type. @@ -133,10 +133,14 @@ impl<'a> AmMsg<'a> { } /// Get the message data. - /// Returns `None` if the message doesn't contain data. + /// Returns `None` if needs to receive data. + /// Returns `Some(slice)` if the message contains concrete data. #[inline] pub fn get_data(&self) -> Option<&[u8]> { - self.msg.data.as_ref().and_then(|data| data.data()) + match self.msg.data { + Some(ref amdata) => amdata.data(), + None => Some(&[]), + } } /// Get the message data length. @@ -151,6 +155,11 @@ impl<'a> AmMsg<'a> { match self.msg.data.take() { None => Ok(Vec::new()), Some(AmData::Eager(vec)) => Ok(vec), + Some(AmData::Data(data)) => { + let v = data.to_vec(); + self.drop_msg(AmData::Data(data)); + Ok(v) + } Some(data) => { self.msg.data = Some(data); let mut buf = Vec::with_capacity(self.data_len()); @@ -181,104 +190,110 @@ impl<'a> AmMsg<'a> { /// Receive the message data. pub async fn recv_data_vectored(&mut self, iov: &[IoSliceMut<'_>]) -> Result { - let data = self.msg.data.take(); - if let Some(data) = data { - if let AmData::Eager(data) = data { - // return error if buffer size < data length, same with ucx - let cap = iov.iter().fold(0_usize, |cap, buf| cap + buf.len()); - if cap < data.len() { - return Err(Error::MessageTruncated); - } + fn copy_data_to_iov(data: &[u8], iov: &[IoSliceMut<'_>]) -> Result { + // return error if buffer size < data length, same with ucx + let cap = iov.iter().fold(0_usize, |cap, buf| cap + buf.len()); + if cap < data.len() { + return Err(Error::MessageTruncated); + } - let mut copied = 0_usize; - for buf in iov { - let len = std::cmp::min(data.len() - copied, buf.len()); - if len == 0 { - break; - } + let mut copied = 0_usize; + for buf in iov { + let len = std::cmp::min(data.len() - copied, buf.len()); + if len == 0 { + break; + } - let buf = &buf[..len]; - unsafe { - std::ptr::copy_nonoverlapping( - data[copied..].as_ptr(), - buf.as_ptr() as _, - len, - ) - } - copied += len; + let buf = &buf[..len]; + unsafe { + std::ptr::copy_nonoverlapping(data[copied..].as_ptr(), buf.as_ptr() as _, len) } - return Ok(copied); + copied += len; } + Ok(copied) + } + let data = self.msg.data.take(); - let (data_desc, data_len) = match data { - AmData::Data(data) => (data.as_ptr(), data.len()), - AmData::Rndv(data) => (data.as_ptr(), data.len()), - _ => unreachable!(), - }; - - unsafe extern "C" fn callback( - request: *mut c_void, - status: ucs_status_t, - _length: usize, - _data: *mut c_void, - ) { - // todo: handle error & fix real data length + match data { + Some(AmData::Eager(data)) => { + // eager message, no need to receive + copy_data_to_iov(&data, iov) + } + Some(AmData::Data(data)) => { + // data message, no need to receive + let size = copy_data_to_iov(&data, iov)?; + self.drop_msg(AmData::Data(data)); + Ok(size) + } + Some(AmData::Rndv(desc)) => { + // rndv message, need to receive + let (data_desc, data_len) = (desc.as_ptr(), desc.len()); + + unsafe extern "C" fn callback( + request: *mut c_void, + status: ucs_status_t, + _length: usize, + _data: *mut c_void, + ) { + // todo: handle error & fix real data length + trace!( + "recv_data_vectored: complete, req={:?}, status={:?}", + request, + status + ); + let request = &mut *(request as *mut Request); + request.waker.wake(); + } trace!( - "recv_data_vectored: complete, req={:?}, status={:?}", - request, - status + "recv_data_vectored: worker={:?} iov.len={}", + self.worker.handle, + iov.len() ); - let request = &mut *(request as *mut Request); - request.waker.wake(); - } - trace!( - "recv_data_vectored: worker={:?} iov.len={}", - self.worker.handle, - iov.len() - ); - let mut param = MaybeUninit::::uninit(); - let (buffer, count) = unsafe { - let param = &mut *param.as_mut_ptr(); - param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 - | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32; - param.cb = ucp_request_param_t__bindgen_ty_1 { - recv_am: Some(callback), + let mut param = MaybeUninit::::uninit(); + let (buffer, count) = unsafe { + let param = &mut *param.as_mut_ptr(); + param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32 + | ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32; + param.cb = ucp_request_param_t__bindgen_ty_1 { + recv_am: Some(callback), + }; + + if iov.len() == 1 { + param.datatype = ucp_dt_make_contig(1); + (iov[0].as_ptr(), iov[0].len()) + } else { + param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; + (iov.as_ptr() as _, iov.len()) + } }; - if iov.len() == 1 { - param.datatype = ucp_dt_make_contig(1); - (iov[0].as_ptr(), iov[0].len()) + let status = unsafe { + ucp_am_recv_data_nbx( + self.worker.handle, + data_desc as _, + buffer as _, + count as _, + param.as_ptr(), + ) + }; + if status.is_null() { + trace!("recv_data_vectored: complete"); + Ok(data_len) + } else if UCS_PTR_IS_PTR(status) { + RequestHandle { + ptr: status, + poll_fn: poll_recv, + } + .await; + Ok(data_len) } else { - param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _; - (iov.as_ptr() as _, iov.len()) + Err(Error::from_ptr(status).unwrap_err()) } - }; - - let status = unsafe { - ucp_am_recv_data_nbx( - self.worker.handle, - data_desc as _, - buffer as _, - count as _, - param.as_ptr(), - ) - }; - if status.is_null() { - trace!("recv_data_vectored: complete"); - Ok(data_len) - } else if UCS_PTR_IS_PTR(status) { - RequestHandle { - ptr: status, - poll_fn: poll_recv, - } - .await; - Ok(data_len) - } else { - Err(Error::from_ptr(status).unwrap_err()) } - } else { - // no data - Ok(0) + None => { + // no data + Ok(0) + } } } @@ -321,18 +336,24 @@ impl<'a> AmMsg<'a> { assert!(self.need_reply()); am_send(self.msg.reply_ep, id, header, data, need_reply, proto).await } + + fn drop_msg(&mut self, data: AmData) { + match data { + AmData::Eager(_) => (), + AmData::Data(data) => unsafe { + ucp_am_data_release(self.worker.handle, data.as_ptr() as _); + }, + AmData::Rndv(data) => unsafe { + ucp_am_data_release(self.worker.handle, data.as_ptr() as _); + }, + } + } } impl<'a> Drop for AmMsg<'a> { fn drop(&mut self) { - match self.msg.data.take() { - Some(AmData::Data(desc)) => unsafe { - ucp_am_data_release(self.worker.handle, desc.as_ptr() as _); - }, - Some(AmData::Rndv(desc)) => unsafe { - ucp_am_data_release(self.worker.handle, desc.as_ptr() as _); - }, - _ => (), + if let Some(data) = self.msg.data.take() { + self.drop_msg(data); } } } @@ -502,6 +523,8 @@ impl Endpoint { } /// Active message protocol +#[derive(Debug, Clone, Copy)] +#[repr(u32)] pub enum AmProto { /// Eager protocol Eager, @@ -594,12 +617,20 @@ mod tests { #[test_log::test] fn am() { - for i in 0..20_usize { - spawn_thread!(send_recv(4 << i)).join().unwrap(); + let protos = vec![None, Some(AmProto::Eager), Some(AmProto::Rndv)]; + for block_size_shift in 0..20_usize { + for p in protos.iter() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_time() + .build() + .unwrap(); + let local = tokio::task::LocalSet::new(); + local.block_on(&rt, send_recv(4 << block_size_shift, *p)); + } } } - async fn send_recv(data_size: usize) { + async fn send_recv(data_size: usize, proto: Option) { let context1 = Context::new().unwrap(); let worker1 = context1.create_worker().unwrap(); let context2 = Context::new().unwrap(); @@ -631,13 +662,7 @@ mod tests { async { // send msg let result = endpoint2 - .am_send( - 16, - header.as_slice(), - data.as_slice(), - true, - Some(AmProto::Rndv), - ) + .am_send(16, header.as_slice(), data.as_slice(), true, proto) .await; assert!(result.is_ok()); }, @@ -662,10 +687,7 @@ mod tests { tokio::join!( async { // send reply - let result = unsafe { - msg.reply(12, &header, &data, false, Some(AmProto::Rndv)) - .await - }; + let result = unsafe { msg.reply(12, &header, &data, false, proto).await }; assert!(result.is_ok()); }, async { From 8a28434ad0cc2a9272dc4839c30bdd98ad26d425 Mon Sep 17 00:00:00 2001 From: Li Kaiwei Date: Sun, 27 Apr 2025 21:26:04 +0800 Subject: [PATCH 2/2] format --- src/ucp/endpoint/am.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index 82f91d1..2002020 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -620,12 +620,12 @@ mod tests { let protos = vec![None, Some(AmProto::Eager), Some(AmProto::Rndv)]; for block_size_shift in 0..20_usize { for p in protos.iter() { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_time() - .build() - .unwrap(); - let local = tokio::task::LocalSet::new(); - local.block_on(&rt, send_recv(4 << block_size_shift, *p)); + let rt = tokio::runtime::Builder::new_current_thread() + .enable_time() + .build() + .unwrap(); + let local = tokio::task::LocalSet::new(); + local.block_on(&rt, send_recv(4 << block_size_shift, *p)); } } }