diff --git a/examples/bench-multi-thread.rs b/examples/bench-multi-thread.rs index dcbc469..e6b1ef2 100644 --- a/examples/bench-multi-thread.rs +++ b/examples/bench-multi-thread.rs @@ -17,7 +17,6 @@ async fn main() -> Result<()> { } else { local.run_until(server()).await; } - Ok(()) } async fn client(server_addr: String) -> ! { @@ -72,7 +71,7 @@ async fn server() -> ! { loop { ep.worker().tag_recv(tag, &mut buf).await.unwrap(); // ep.tag_send(tag, &[0]).await; - unsafe { *(&*counter as *const AtomicUsize as *mut usize) += 1 }; + counter.fetch_add(1, Ordering::Relaxed); } }); }); diff --git a/examples/bench.rs b/examples/bench.rs index 38b19f4..4f84a37 100644 --- a/examples/bench.rs +++ b/examples/bench.rs @@ -12,7 +12,6 @@ async fn main() -> Result<()> { } else { local.run_until(server()).await; } - Ok(()) } async fn client(server_addr: String) -> ! { diff --git a/examples/rma.rs b/examples/rma.rs index 1a227c2..b110d0d 100644 --- a/examples/rma.rs +++ b/examples/rma.rs @@ -18,16 +18,16 @@ async fn client(server_addr: String) -> Result<()> { println!("client: connect to {:?}", server_addr); let context = Context::new().unwrap(); let worker = context.create_worker().unwrap(); - let endpoint = worker - .connect_socket(server_addr.parse().unwrap()) - .await - .unwrap(); - endpoint.print_to_stderr(); #[cfg(not(feature = "event"))] tokio::task::spawn_local(worker.clone().polling()); #[cfg(feature = "event")] tokio::task::spawn_local(worker.clone().event_poll()); + let endpoint = worker + .connect_socket(server_addr.parse().unwrap()) + .await + .unwrap(); + endpoint.print_to_stderr(); // register memory region let mut buf: Vec = (0..0x1000).map(|x| x as u8).collect(); let mem = MemoryHandle::register(&context, &mut buf); @@ -62,13 +62,13 @@ async fn server() -> Result<()> { println!("accept"); endpoint.print_to_stderr(); - let mut vaddr_buf = [MaybeUninit::uninit(); 8]; + let mut vaddr_buf = [MaybeUninit::::uninit(); 8]; let len = endpoint.stream_recv(&mut vaddr_buf).await.unwrap(); assert_eq!(len, 8); let vaddr = u64::from_ne_bytes(unsafe { transmute(vaddr_buf) }); println!("recv: vaddr={:#x}", vaddr); - let mut rkey_buf = [MaybeUninit::uninit(); 100]; + let mut rkey_buf = [MaybeUninit::::uninit(); 100]; let len = endpoint.stream_recv(&mut rkey_buf).await.unwrap(); println!("recv rkey: len={}", len); diff --git a/examples/stream.rs b/examples/stream.rs index cd8edb5..bd7e3a3 100644 --- a/examples/stream.rs +++ b/examples/stream.rs @@ -53,7 +53,7 @@ async fn server() -> Result<()> { println!("accept"); endpoint.print_to_stderr(); - let mut buf = [MaybeUninit::uninit(); 10]; + let mut buf = [MaybeUninit::::uninit(); 10]; let len = endpoint.stream_recv(&mut buf).await.unwrap(); let msg = std::str::from_utf8(unsafe { transmute(&buf[..len]) }); println!("recv: {:?}", msg); diff --git a/examples/tag.rs b/examples/tag.rs index bb127f1..eda81de 100644 --- a/examples/tag.rs +++ b/examples/tag.rs @@ -55,7 +55,7 @@ async fn server() -> Result<()> { let _endpoint = worker.accept(connection).await.unwrap(); println!("accept"); - let mut buf = [MaybeUninit::uninit(); 0x1005]; + let mut buf = [MaybeUninit::::uninit(); 0x1005]; let len = worker.tag_recv(100, &mut buf).await.unwrap(); let msg = std::str::from_utf8(unsafe { transmute(&buf[..len]) }).unwrap(); println!("recv: {:?}", msg); diff --git a/src/ucp/endpoint/am.rs b/src/ucp/endpoint/am.rs index d5b22ad..3dc8285 100644 --- a/src/ucp/endpoint/am.rs +++ b/src/ucp/endpoint/am.rs @@ -8,10 +8,20 @@ use std::{ sync::atomic::AtomicBool, }; +//// Active message protocol. +/// Active message protocol is a mechanism for sending and receiving messages +/// between processes in a distributed system. +/// It allows a process to send a message to another process, which can then +/// handle the message and perform some action based on its contents. +/// Active messages are typically used in high-performance computing (HPC) +/// applications, where low-latency communication is critical. #[derive(Debug, PartialEq, Eq)] pub enum AmDataType { + /// Eager message Eager, + /// Data message Data, + /// Rendezvous message Rndv, } @@ -88,6 +98,7 @@ impl RawMsg { } } +/// Active message message. pub struct AmMsg<'a> { worker: &'a Worker, msg: RawMsg, @@ -98,35 +109,44 @@ impl<'a> AmMsg<'a> { AmMsg { worker, msg } } + /// Get the message ID. #[inline] pub fn id(&self) -> u16 { self.msg.id } + /// Get the message header. #[inline] pub fn header(&self) -> &[u8] { self.msg.header.as_ref() } + /// Get the message header length. #[inline] pub fn contains_data(&self) -> bool { self.data_type().is_some() } + /// Get the message data type. pub fn data_type(&self) -> Option { self.msg.data.as_ref().map(|data| data.data_type()) } + /// Get the message data. + /// Returns `None` if the message doesn't contain data. #[inline] pub fn get_data(&self) -> Option<&[u8]> { self.msg.data.as_ref().and_then(|data| data.data()) } + /// Get the message data length. + /// Returns `0` if the message doesn't contain data. #[inline] pub fn data_len(&self) -> usize { self.msg.data.as_ref().map_or(0, |data| data.len()) } + /// Receive the message data. pub async fn recv_data(&mut self) -> Result, Error> { match self.msg.data.take() { None => Ok(Vec::new()), @@ -144,6 +164,12 @@ impl<'a> AmMsg<'a> { } } + /// Receive the message data. + /// Returns `0` if the message doesn't contain data. + /// Returns the number of bytes received. + /// # Safety + /// User needs to ensure that the buffer is large enough to hold the data. + /// Otherwise, it will cause memory corruption. pub async fn recv_data_single(&mut self, buf: &mut [u8]) -> Result { if !self.contains_data() { Ok(0) @@ -153,6 +179,7 @@ impl<'a> AmMsg<'a> { } } + /// Receive the message data. pub async fn recv_data_vectored(&mut self, iov: &[IoSliceMut<'_>]) -> Result { let data = self.msg.data.take(); if let Some(data) = data { @@ -192,7 +219,7 @@ impl<'a> AmMsg<'a> { unsafe extern "C" fn callback( request: *mut c_void, status: ucs_status_t, - _length: u64, + _length: usize, _data: *mut c_void, ) { // todo: handle error & fix real data length @@ -255,6 +282,7 @@ impl<'a> AmMsg<'a> { } } + /// Check if the message needs a reply. #[inline] pub fn need_reply(&self) -> bool { self.msg.attr & (ucp_am_recv_attr_t::UCP_AM_RECV_ATTR_FIELD_REPLY_EP as u64) != 0 @@ -309,6 +337,7 @@ impl<'a> Drop for AmMsg<'a> { } } +/// Active message stream. #[derive(Clone)] pub struct AmStream<'a> { worker: &'a Worker, @@ -383,9 +412,9 @@ impl Worker { unsafe extern "C" fn callback( arg: *mut c_void, header: *const c_void, - header_len: u64, + header_len: usize, data: *mut c_void, - data_len: u64, + data_len: usize, param: *const ucp_am_recv_param_t, ) -> ucs_status_t { let handler = &*(arg as *const AmStreamInner); @@ -442,7 +471,9 @@ impl Worker { } } +/// Active message endpoint. impl Endpoint { + /// Send active message. pub async fn am_send( &self, id: u32, @@ -456,6 +487,7 @@ impl Endpoint { .await } + /// Send active message. pub async fn am_send_vectorized( &self, id: u32, @@ -469,8 +501,11 @@ impl Endpoint { } } +/// Active message protocol pub enum AmProto { + /// Eager protocol Eager, + /// Rendezvous protocol Rndv, } @@ -601,7 +636,7 @@ mod tests { header.as_slice(), data.as_slice(), true, - Some(AmProto::Eager), + Some(AmProto::Rndv), ) .await; assert!(result.is_ok()); @@ -627,7 +662,10 @@ mod tests { tokio::join!( async { // send reply - let result = unsafe { msg.reply(12, &header, &data, false, None).await }; + let result = unsafe { + msg.reply(12, &header, &data, false, Some(AmProto::Rndv)) + .await + }; assert!(result.is_ok()); }, async { diff --git a/src/ucp/endpoint/mod.rs b/src/ucp/endpoint/mod.rs index 60000b1..b477468 100644 --- a/src/ucp/endpoint/mod.rs +++ b/src/ucp/endpoint/mod.rs @@ -16,8 +16,6 @@ mod tag; #[cfg(feature = "am")] pub use self::am::*; pub use self::rma::*; -pub use self::stream::*; -pub use self::tag::*; // State associate with ucp_ep_h // todo: Add a `get_user_data` to UCX @@ -111,7 +109,7 @@ impl Endpoint { arg: std::ptr::null_mut(), // override by user_data }; - let mut handle = MaybeUninit::uninit(); + let mut handle = MaybeUninit::<*mut ucp_ep>::uninit(); let status = unsafe { ucp_ep_create(worker.handle, ¶ms, handle.as_mut_ptr()) }; if let Err(err) = Error::from_status(status) { // error happened, drop reference @@ -142,7 +140,7 @@ impl Endpoint { addrlen: sockaddr.len(), }, err_mode: ucp_err_handling_mode_t::UCP_ERR_HANDLING_MODE_PEER, - ..unsafe { MaybeUninit::uninit().assume_init() } + ..unsafe { MaybeUninit::zeroed().assume_init() } }; let endpoint = Endpoint::create(worker, params)?; @@ -157,15 +155,13 @@ impl Endpoint { worker: &Rc, addr: *const ucp_address_t, ) -> Result { - #[allow(invalid_value)] - #[allow(clippy::uninit_assumed_init)] let params = ucp_ep_params { field_mask: (ucp_ep_params_field::UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | ucp_ep_params_field::UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE) .0 as u64, address: addr, err_mode: ucp_err_handling_mode_t::UCP_ERR_HANDLING_MODE_PEER, - ..unsafe { MaybeUninit::uninit().assume_init() } + ..unsafe { MaybeUninit::zeroed().assume_init() } }; Endpoint::create(worker, params) } @@ -174,12 +170,10 @@ impl Endpoint { worker: &Rc, connection: ConnectionRequest, ) -> Result { - #[allow(invalid_value)] - #[allow(clippy::uninit_assumed_init)] let params = ucp_ep_params { field_mask: ucp_ep_params_field::UCP_EP_PARAM_FIELD_CONN_REQUEST.0 as u64, conn_request: connection.handle, - ..unsafe { MaybeUninit::uninit().assume_init() } + ..unsafe { MaybeUninit::zeroed().assume_init() } }; let endpoint = Endpoint::create(worker, params)?; diff --git a/src/ucp/endpoint/rma.rs b/src/ucp/endpoint/rma.rs index 57c64f0..e9cd41e 100644 --- a/src/ucp/endpoint/rma.rs +++ b/src/ucp/endpoint/rma.rs @@ -19,9 +19,9 @@ impl MemoryHandle { .0 as u64, address: region.as_ptr() as _, length: region.len() as _, - ..unsafe { MaybeUninit::uninit().assume_init() } + ..unsafe { MaybeUninit::zeroed().assume_init() } }; - let mut handle = MaybeUninit::uninit(); + let mut handle = MaybeUninit::<*mut ucp_mem>::uninit(); let status = unsafe { ucp_mem_map(context.handle, ¶ms, handle.as_mut_ptr()) }; assert_eq!(status, ucs_status_t::UCS_OK); MemoryHandle { @@ -32,8 +32,8 @@ impl MemoryHandle { /// Packs into the buffer a remote access key (RKEY) object. pub fn pack(&self) -> RKeyBuffer { - let mut buf = MaybeUninit::uninit(); - let mut len = MaybeUninit::uninit(); + let mut buf = MaybeUninit::<*mut c_void>::uninit(); + let mut len = MaybeUninit::::uninit(); let status = unsafe { ucp_rkey_pack( self.context.handle, @@ -60,7 +60,7 @@ impl Drop for MemoryHandle { #[derive(Debug)] pub struct RKeyBuffer { buf: *mut c_void, - len: u64, + len: usize, } impl AsRef<[u8]> for RKeyBuffer { @@ -87,7 +87,7 @@ unsafe impl Sync for RKey {} impl RKey { /// Create remote access key from packed buffer. pub fn unpack(endpoint: &Endpoint, rkey_buffer: &[u8]) -> Self { - let mut handle = MaybeUninit::uninit(); + let mut handle = MaybeUninit::<*mut ucp_rkey>::uninit(); let status = unsafe { ucp_ep_rkey_unpack( endpoint.handle, diff --git a/src/ucp/endpoint/stream.rs b/src/ucp/endpoint/stream.rs index 13d3274..411a80b 100644 --- a/src/ucp/endpoint/stream.rs +++ b/src/ucp/endpoint/stream.rs @@ -40,7 +40,7 @@ impl Endpoint { /// Receives data from stream. pub async fn stream_recv(&self, buf: &mut [MaybeUninit]) -> Result { trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len()); - unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, length: u64) { + unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, length: usize) { trace!( "stream_recv: complete. req={:?}, status={:?}, len={}", request, @@ -50,7 +50,7 @@ impl Endpoint { let request = &mut *(request as *mut Request); request.waker.wake(); } - let mut length = MaybeUninit::uninit(); + let mut length = MaybeUninit::::uninit(); let status = unsafe { ucp_stream_recv_nb( self.get_handle()?, diff --git a/src/ucp/endpoint/tag.rs b/src/ucp/endpoint/tag.rs index 285974d..6f04b4d 100644 --- a/src/ucp/endpoint/tag.rs +++ b/src/ucp/endpoint/tag.rs @@ -238,7 +238,7 @@ mod tests { }, async { // recv - let mut buf = vec![MaybeUninit::uninit(); msg_size]; + let mut buf = vec![MaybeUninit::::uninit(); msg_size]; worker1.tag_recv(1, &mut buf).await.unwrap(); println!("tag recved"); } diff --git a/src/ucp/listener.rs b/src/ucp/listener.rs index 049053e..1ad0aa5 100644 --- a/src/ucp/listener.rs +++ b/src/ucp/listener.rs @@ -30,11 +30,10 @@ unsafe impl Send for ConnectionRequest {} impl ConnectionRequest { /// The address of the remote client that sent the connection request to the server. pub fn remote_addr(&self) -> Result { - #[allow(clippy::uninit_assumed_init)] let mut attr = ucp_conn_request_attr { field_mask: ucp_conn_request_attr_field::UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR.0 as u64, - ..unsafe { MaybeUninit::uninit().assume_init() } + ..unsafe { MaybeUninit::zeroed().assume_init() } }; let status = unsafe { ucp_conn_request_query(self.handle, &mut attr) }; Error::from_status(status)?; @@ -75,7 +74,7 @@ impl Listener { arg: &*sender as *const mpsc::UnboundedSender as _, }, }; - let mut handle = MaybeUninit::uninit(); + let mut handle = MaybeUninit::<*mut ucp_listener>::uninit(); let status = unsafe { ucp_listener_create(worker.handle, ¶ms, handle.as_mut_ptr()) }; Error::from_status(status)?; trace!("create listener={:?}", handle); @@ -88,10 +87,9 @@ impl Listener { /// Returns the local socket address of this listener. pub fn socket_addr(&self) -> Result { - #[allow(clippy::uninit_assumed_init)] let mut attr = ucp_listener_attr_t { field_mask: ucp_listener_attr_field::UCP_LISTENER_ATTR_FIELD_SOCKADDR.0 as u64, - sockaddr: unsafe { MaybeUninit::uninit().assume_init() }, + sockaddr: unsafe { MaybeUninit::zeroed().assume_init() }, }; let status = unsafe { ucp_listener_query(self.handle, &mut attr) }; Error::from_status(status)?; diff --git a/src/ucp/mod.rs b/src/ucp/mod.rs index 9f8d8d4..440e6be 100644 --- a/src/ucp/mod.rs +++ b/src/ucp/mod.rs @@ -27,7 +27,7 @@ pub struct Config { impl Default for Config { fn default() -> Self { - let mut handle = MaybeUninit::uninit(); + let mut handle = MaybeUninit::<*mut ucp_config>::uninit(); let status = unsafe { ucp_config_read(null(), null(), handle.as_mut_ptr()) }; Error::from_status(status).unwrap(); @@ -83,7 +83,6 @@ impl Context { #[cfg(feature = "am")] let features = features | ucp_feature::UCP_FEATURE_AM; - #[allow(clippy::uninit_assumed_init)] let params = ucp_params_t { field_mask: (ucp_params_field::UCP_PARAM_FIELD_FEATURES | ucp_params_field::UCP_PARAM_FIELD_REQUEST_SIZE @@ -92,13 +91,13 @@ impl Context { | ucp_params_field::UCP_PARAM_FIELD_MT_WORKERS_SHARED) .0 as u64, features: features.0 as u64, - request_size: std::mem::size_of::() as u64, + request_size: std::mem::size_of::() as usize, request_init: Some(Request::init), request_cleanup: Some(Request::cleanup), mt_workers_shared: 1, - ..unsafe { MaybeUninit::uninit().assume_init() } + ..unsafe { std::mem::zeroed() } }; - let mut handle = MaybeUninit::uninit(); + let mut handle = MaybeUninit::<*mut ucp_context>::uninit(); let status = unsafe { ucp_init_version( UCP_API_MAJOR, @@ -130,18 +129,10 @@ impl Context { /// Fetches information about the context. pub fn query(&self) -> Result { - #[allow(invalid_value)] - #[allow(clippy::uninit_assumed_init)] - let mut attr = ucp_context_attr { - field_mask: (ucp_context_attr_field::UCP_ATTR_FIELD_REQUEST_SIZE - | ucp_context_attr_field::UCP_ATTR_FIELD_THREAD_MODE) - .0 as u64, - ..unsafe { MaybeUninit::uninit().assume_init() } - }; - let status = unsafe { ucp_context_query(self.handle, &mut attr) }; + let mut attr = MaybeUninit::::uninit(); + let status = unsafe { ucp_context_query(self.handle, attr.as_mut_ptr()) }; Error::from_status(status)?; - - Ok(attr) + Ok(unsafe { attr.assume_init() }) } } diff --git a/src/ucp/worker.rs b/src/ucp/worker.rs index b8e273c..00abd16 100644 --- a/src/ucp/worker.rs +++ b/src/ucp/worker.rs @@ -34,7 +34,7 @@ impl Worker { ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_THREAD_MODE.0 as _; (*params.as_mut_ptr()).thread_mode = ucs_thread_mode_t::UCS_THREAD_MODE_SINGLE; }; - let mut handle = MaybeUninit::uninit(); + let mut handle = MaybeUninit::<*mut ucp_worker>::uninit(); let status = unsafe { ucp_worker_create(context.handle, params.as_ptr(), handle.as_mut_ptr()) }; Error::from_status(status)?; @@ -98,8 +98,8 @@ impl Worker { /// This address can be passed to remote instances of the UCP library /// in order to connect to this worker. pub fn address(&self) -> Result, Error> { - let mut handle = MaybeUninit::uninit(); - let mut length = MaybeUninit::uninit(); + let mut handle = MaybeUninit::<*mut ucp_address>::uninit(); + let mut length = MaybeUninit::::uninit(); let status = unsafe { ucp_worker_get_address(self.handle, handle.as_mut_ptr(), length.as_mut_ptr()) }; @@ -157,7 +157,7 @@ impl Worker { /// Returns a valid file descriptor for polling functions. pub fn event_fd(&self) -> Result { - let mut fd = MaybeUninit::uninit(); + let mut fd = MaybeUninit::::uninit(); let status = unsafe { ucp_worker_get_efd(self.handle, fd.as_mut_ptr()) }; Error::from_status(status)?; diff --git a/ucx1-sys/Cargo.toml b/ucx1-sys/Cargo.toml index e129285..d71ff71 100644 --- a/ucx1-sys/Cargo.toml +++ b/ucx1-sys/Cargo.toml @@ -13,4 +13,4 @@ categories = ["external-ffi-bindings"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [build-dependencies] -bindgen = "0.59" +bindgen = "0.66" diff --git a/ucx1-sys/ucx b/ucx1-sys/ucx index 6765970..938ffcd 160000 --- a/ucx1-sys/ucx +++ b/ucx1-sys/ucx @@ -1 +1 @@ -Subproject commit 67659706e8d5c2b6fe88af45720777748dd21503 +Subproject commit 938ffcd10122742d0f46a4f609e7395d1648c969