Skip to content

fix in current rust version (rustc 1.86.0) #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/bench-multi-thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
} else {
local.run_until(server()).await;
}
Ok(())
}

async fn client(server_addr: String) -> ! {
Expand Down Expand Up @@ -72,7 +71,7 @@
loop {
ep.worker().tag_recv(tag, &mut buf).await.unwrap();
// ep.tag_send(tag, &[0]).await;
unsafe { *(&*counter as *const AtomicUsize as *mut usize) += 1 };
counter.fetch_add(1, Ordering::Relaxed);
}
});
});
Expand Down Expand Up @@ -124,7 +123,7 @@
.build()
.unwrap();
let local = tokio::task::LocalSet::new();
#[cfg(not(event))]

Check warning on line 126 in examples/bench-multi-thread.rs

View workflow job for this annotation

GitHub Actions / build

unexpected `cfg` condition name: `event`

Check warning on line 126 in examples/bench-multi-thread.rs

View workflow job for this annotation

GitHub Actions / build

unexpected `cfg` condition name: `event`
local.spawn_local(worker.clone().polling());
#[cfg(feature = "event")]
local.spawn_local(worker.clone().event_poll());
Expand Down
1 change: 0 additions & 1 deletion examples/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ async fn main() -> Result<()> {
} else {
local.run_until(server()).await;
}
Ok(())
}

async fn client(server_addr: String) -> ! {
Expand Down
14 changes: 7 additions & 7 deletions examples/rma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ async fn client(server_addr: String) -> Result<()> {
println!("client: connect to {:?}", server_addr);
let context = Context::new().unwrap();
let worker = context.create_worker().unwrap();
let endpoint = worker
.connect_socket(server_addr.parse().unwrap())
.await
.unwrap();
endpoint.print_to_stderr();
#[cfg(not(feature = "event"))]
tokio::task::spawn_local(worker.clone().polling());
#[cfg(feature = "event")]
tokio::task::spawn_local(worker.clone().event_poll());

let endpoint = worker
.connect_socket(server_addr.parse().unwrap())
.await
.unwrap();
endpoint.print_to_stderr();
// register memory region
let mut buf: Vec<u8> = (0..0x1000).map(|x| x as u8).collect();
let mem = MemoryHandle::register(&context, &mut buf);
Expand Down Expand Up @@ -62,13 +62,13 @@ async fn server() -> Result<()> {
println!("accept");
endpoint.print_to_stderr();

let mut vaddr_buf = [MaybeUninit::uninit(); 8];
let mut vaddr_buf = [MaybeUninit::<u8>::uninit(); 8];
let len = endpoint.stream_recv(&mut vaddr_buf).await.unwrap();
assert_eq!(len, 8);
let vaddr = u64::from_ne_bytes(unsafe { transmute(vaddr_buf) });
println!("recv: vaddr={:#x}", vaddr);

let mut rkey_buf = [MaybeUninit::uninit(); 100];
let mut rkey_buf = [MaybeUninit::<u8>::uninit(); 100];
let len = endpoint.stream_recv(&mut rkey_buf).await.unwrap();
println!("recv rkey: len={}", len);

Expand Down
2 changes: 1 addition & 1 deletion examples/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async fn server() -> Result<()> {
println!("accept");
endpoint.print_to_stderr();

let mut buf = [MaybeUninit::uninit(); 10];
let mut buf = [MaybeUninit::<u8>::uninit(); 10];
let len = endpoint.stream_recv(&mut buf).await.unwrap();
let msg = std::str::from_utf8(unsafe { transmute(&buf[..len]) });
println!("recv: {:?}", msg);
Expand Down
2 changes: 1 addition & 1 deletion examples/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async fn server() -> Result<()> {
let _endpoint = worker.accept(connection).await.unwrap();
println!("accept");

let mut buf = [MaybeUninit::uninit(); 0x1005];
let mut buf = [MaybeUninit::<u8>::uninit(); 0x1005];
let len = worker.tag_recv(100, &mut buf).await.unwrap();
let msg = std::str::from_utf8(unsafe { transmute(&buf[..len]) }).unwrap();
println!("recv: {:?}", msg);
Expand Down
48 changes: 43 additions & 5 deletions src/ucp/endpoint/am.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,20 @@ use std::{
sync::atomic::AtomicBool,
};

//// Active message protocol.
/// Active message protocol is a mechanism for sending and receiving messages
/// between processes in a distributed system.
/// It allows a process to send a message to another process, which can then
/// handle the message and perform some action based on its contents.
/// Active messages are typically used in high-performance computing (HPC)
/// applications, where low-latency communication is critical.
#[derive(Debug, PartialEq, Eq)]
pub enum AmDataType {
/// Eager message
Eager,
/// Data message
Data,
/// Rendezvous message
Rndv,
}

Expand Down Expand Up @@ -88,6 +98,7 @@ impl RawMsg {
}
}

/// Active message message.
pub struct AmMsg<'a> {
worker: &'a Worker,
msg: RawMsg,
Expand All @@ -98,35 +109,44 @@ impl<'a> AmMsg<'a> {
AmMsg { worker, msg }
}

/// Get the message ID.
#[inline]
pub fn id(&self) -> u16 {
self.msg.id
}

/// Get the message header.
#[inline]
pub fn header(&self) -> &[u8] {
self.msg.header.as_ref()
}

/// Get the message header length.
#[inline]
pub fn contains_data(&self) -> bool {
self.data_type().is_some()
}

/// Get the message data type.
pub fn data_type(&self) -> Option<AmDataType> {
self.msg.data.as_ref().map(|data| data.data_type())
}

/// Get the message data.
/// Returns `None` if the message doesn't contain data.
#[inline]
pub fn get_data(&self) -> Option<&[u8]> {
self.msg.data.as_ref().and_then(|data| data.data())
}

/// Get the message data length.
/// Returns `0` if the message doesn't contain data.
#[inline]
pub fn data_len(&self) -> usize {
self.msg.data.as_ref().map_or(0, |data| data.len())
}

/// Receive the message data.
pub async fn recv_data(&mut self) -> Result<Vec<u8>, Error> {
match self.msg.data.take() {
None => Ok(Vec::new()),
Expand All @@ -144,6 +164,12 @@ impl<'a> AmMsg<'a> {
}
}

/// Receive the message data.
/// Returns `0` if the message doesn't contain data.
/// Returns the number of bytes received.
/// # Safety
/// User needs to ensure that the buffer is large enough to hold the data.
/// Otherwise, it will cause memory corruption.
pub async fn recv_data_single(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
if !self.contains_data() {
Ok(0)
Expand All @@ -153,6 +179,7 @@ impl<'a> AmMsg<'a> {
}
}

/// Receive the message data.
pub async fn recv_data_vectored(&mut self, iov: &[IoSliceMut<'_>]) -> Result<usize, Error> {
let data = self.msg.data.take();
if let Some(data) = data {
Expand Down Expand Up @@ -192,7 +219,7 @@ impl<'a> AmMsg<'a> {
unsafe extern "C" fn callback(
request: *mut c_void,
status: ucs_status_t,
_length: u64,
_length: usize,
_data: *mut c_void,
) {
// todo: handle error & fix real data length
Expand Down Expand Up @@ -255,6 +282,7 @@ impl<'a> AmMsg<'a> {
}
}

/// Check if the message needs a reply.
#[inline]
pub fn need_reply(&self) -> bool {
self.msg.attr & (ucp_am_recv_attr_t::UCP_AM_RECV_ATTR_FIELD_REPLY_EP as u64) != 0
Expand Down Expand Up @@ -309,6 +337,7 @@ impl<'a> Drop for AmMsg<'a> {
}
}

/// Active message stream.
#[derive(Clone)]
pub struct AmStream<'a> {
worker: &'a Worker,
Expand Down Expand Up @@ -383,9 +412,9 @@ impl Worker {
unsafe extern "C" fn callback(
arg: *mut c_void,
header: *const c_void,
header_len: u64,
header_len: usize,
data: *mut c_void,
data_len: u64,
data_len: usize,
param: *const ucp_am_recv_param_t,
) -> ucs_status_t {
let handler = &*(arg as *const AmStreamInner);
Expand Down Expand Up @@ -442,7 +471,9 @@ impl Worker {
}
}

/// Active message endpoint.
impl Endpoint {
/// Send active message.
pub async fn am_send(
&self,
id: u32,
Expand All @@ -456,6 +487,7 @@ impl Endpoint {
.await
}

/// Send active message.
pub async fn am_send_vectorized(
&self,
id: u32,
Expand All @@ -469,8 +501,11 @@ impl Endpoint {
}
}

/// Active message protocol
pub enum AmProto {
/// Eager protocol
Eager,
/// Rendezvous protocol
Rndv,
}

Expand Down Expand Up @@ -601,7 +636,7 @@ mod tests {
header.as_slice(),
data.as_slice(),
true,
Some(AmProto::Eager),
Some(AmProto::Rndv),
)
.await;
assert!(result.is_ok());
Expand All @@ -627,7 +662,10 @@ mod tests {
tokio::join!(
async {
// send reply
let result = unsafe { msg.reply(12, &header, &data, false, None).await };
let result = unsafe {
msg.reply(12, &header, &data, false, Some(AmProto::Rndv))
.await
};
assert!(result.is_ok());
},
async {
Expand Down
14 changes: 4 additions & 10 deletions src/ucp/endpoint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ mod tag;
#[cfg(feature = "am")]
pub use self::am::*;
pub use self::rma::*;
pub use self::stream::*;
pub use self::tag::*;

// State associate with ucp_ep_h
// todo: Add a `get_user_data` to UCX
Expand Down Expand Up @@ -111,7 +109,7 @@ impl Endpoint {
arg: std::ptr::null_mut(), // override by user_data
};

let mut handle = MaybeUninit::uninit();
let mut handle = MaybeUninit::<*mut ucp_ep>::uninit();
let status = unsafe { ucp_ep_create(worker.handle, &params, handle.as_mut_ptr()) };
if let Err(err) = Error::from_status(status) {
// error happened, drop reference
Expand Down Expand Up @@ -142,7 +140,7 @@ impl Endpoint {
addrlen: sockaddr.len(),
},
err_mode: ucp_err_handling_mode_t::UCP_ERR_HANDLING_MODE_PEER,
..unsafe { MaybeUninit::uninit().assume_init() }
..unsafe { MaybeUninit::zeroed().assume_init() }
};
let endpoint = Endpoint::create(worker, params)?;

Expand All @@ -157,15 +155,13 @@ impl Endpoint {
worker: &Rc<Worker>,
addr: *const ucp_address_t,
) -> Result<Self, Error> {
#[allow(invalid_value)]
#[allow(clippy::uninit_assumed_init)]
let params = ucp_ep_params {
field_mask: (ucp_ep_params_field::UCP_EP_PARAM_FIELD_REMOTE_ADDRESS
| ucp_ep_params_field::UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE)
.0 as u64,
address: addr,
err_mode: ucp_err_handling_mode_t::UCP_ERR_HANDLING_MODE_PEER,
..unsafe { MaybeUninit::uninit().assume_init() }
..unsafe { MaybeUninit::zeroed().assume_init() }
};
Endpoint::create(worker, params)
}
Expand All @@ -174,12 +170,10 @@ impl Endpoint {
worker: &Rc<Worker>,
connection: ConnectionRequest,
) -> Result<Self, Error> {
#[allow(invalid_value)]
#[allow(clippy::uninit_assumed_init)]
let params = ucp_ep_params {
field_mask: ucp_ep_params_field::UCP_EP_PARAM_FIELD_CONN_REQUEST.0 as u64,
conn_request: connection.handle,
..unsafe { MaybeUninit::uninit().assume_init() }
..unsafe { MaybeUninit::zeroed().assume_init() }
};
let endpoint = Endpoint::create(worker, params)?;

Expand Down
12 changes: 6 additions & 6 deletions src/ucp/endpoint/rma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ impl MemoryHandle {
.0 as u64,
address: region.as_ptr() as _,
length: region.len() as _,
..unsafe { MaybeUninit::uninit().assume_init() }
..unsafe { MaybeUninit::zeroed().assume_init() }
};
let mut handle = MaybeUninit::uninit();
let mut handle = MaybeUninit::<*mut ucp_mem>::uninit();
let status = unsafe { ucp_mem_map(context.handle, &params, handle.as_mut_ptr()) };
assert_eq!(status, ucs_status_t::UCS_OK);
MemoryHandle {
Expand All @@ -32,8 +32,8 @@ impl MemoryHandle {

/// Packs into the buffer a remote access key (RKEY) object.
pub fn pack(&self) -> RKeyBuffer {
let mut buf = MaybeUninit::uninit();
let mut len = MaybeUninit::uninit();
let mut buf = MaybeUninit::<*mut c_void>::uninit();
let mut len = MaybeUninit::<usize>::uninit();
let status = unsafe {
ucp_rkey_pack(
self.context.handle,
Expand All @@ -60,7 +60,7 @@ impl Drop for MemoryHandle {
#[derive(Debug)]
pub struct RKeyBuffer {
buf: *mut c_void,
len: u64,
len: usize,
}

impl AsRef<[u8]> for RKeyBuffer {
Expand All @@ -87,7 +87,7 @@ unsafe impl Sync for RKey {}
impl RKey {
/// Create remote access key from packed buffer.
pub fn unpack(endpoint: &Endpoint, rkey_buffer: &[u8]) -> Self {
let mut handle = MaybeUninit::uninit();
let mut handle = MaybeUninit::<*mut ucp_rkey>::uninit();
let status = unsafe {
ucp_ep_rkey_unpack(
endpoint.handle,
Expand Down
4 changes: 2 additions & 2 deletions src/ucp/endpoint/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl Endpoint {
/// Receives data from stream.
pub async fn stream_recv(&self, buf: &mut [MaybeUninit<u8>]) -> Result<usize, Error> {
trace!("stream_recv: endpoint={:?} len={}", self.handle, buf.len());
unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, length: u64) {
unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t, length: usize) {
trace!(
"stream_recv: complete. req={:?}, status={:?}, len={}",
request,
Expand All @@ -50,7 +50,7 @@ impl Endpoint {
let request = &mut *(request as *mut Request);
request.waker.wake();
}
let mut length = MaybeUninit::uninit();
let mut length = MaybeUninit::<usize>::uninit();
let status = unsafe {
ucp_stream_recv_nb(
self.get_handle()?,
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/endpoint/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ mod tests {
},
async {
// recv
let mut buf = vec![MaybeUninit::uninit(); msg_size];
let mut buf = vec![MaybeUninit::<u8>::uninit(); msg_size];
worker1.tag_recv(1, &mut buf).await.unwrap();
println!("tag recved");
}
Expand Down
Loading
Loading