Skip to content

Commit 09b79fd

Browse files
committed
feat: improve tag receive implementation with ucp_tag_recv_info
1 parent 35ce5c4 commit 09b79fd

File tree

3 files changed

+88
-15
lines changed

3 files changed

+88
-15
lines changed

src/ucp/endpoint/param.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ impl RequestParam {
4242
self
4343
}
4444

45+
pub fn recv_tag_info(mut self, info: *mut ucp_tag_recv_info) -> Self {
46+
self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_RECV_INFO as u32;
47+
self.inner.recv_info.tag_info = info;
48+
self
49+
}
50+
4551
#[cfg(feature = "am")]
4652
pub fn cb_recv_am(mut self, callback: ucp_am_recv_data_nbx_callback_t) -> Self {
4753
self.inner.op_attr_mask |= ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32;

src/ucp/endpoint/tag.rs

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@ impl Worker {
1818
buf: &mut [MaybeUninit<u8>],
1919
) -> Result<(u64, usize), Error> {
2020
match self.tag_recv_impl(tag, tag_mask, buf)? {
21-
Status::Completed(r) => r,
22-
Status::Scheduled(request_handle) => request_handle.await,
21+
Status::Completed(r) => r.map(|info| (info.sender_tag, info.length as usize)),
22+
Status::Scheduled(request_handle) => {
23+
let info = request_handle.await?;
24+
Ok((info.sender_tag, info.length as usize))
25+
}
2326
}
2427
}
2528

@@ -70,15 +73,15 @@ impl Worker {
7073
poll_fn: poll_tag,
7174
}
7275
.await
73-
.map(|info| info.1)
76+
.map(|info| info.length as usize)
7477
}
7578

7679
pub(super) fn tag_recv_impl(
7780
&self,
7881
tag: u64,
7982
tag_mask: u64,
8083
buf: &mut [MaybeUninit<u8>],
81-
) -> Result<Status<(u64, usize)>, Error> {
84+
) -> Result<Status<ucp_tag_recv_info>, Error> {
8285
trace!(
8386
"tag_recv: worker={:?}, tag={}, mask={:#x} len={}",
8487
self.handle,
@@ -104,7 +107,10 @@ impl Worker {
104107
let request = &mut *(request as *mut Request);
105108
request.waker.wake();
106109
}
107-
let param = RequestParam::new().cb_tag_recv(Some(callback));
110+
let mut info = MaybeUninit::<ucp_tag_recv_info>::uninit();
111+
let param = RequestParam::new()
112+
.cb_tag_recv(Some(callback))
113+
.recv_tag_info(info.as_mut_ptr() as _);
108114
let status = unsafe {
109115
ucp_tag_recv_nbx(
110116
self.handle,
@@ -115,7 +121,7 @@ impl Worker {
115121
param.as_ref(),
116122
)
117123
};
118-
Ok(Status::from(status, MaybeUninit::uninit(), poll_tag))
124+
Ok(Status::from(status, info, poll_tag))
119125
}
120126
}
121127

@@ -208,14 +214,14 @@ impl Endpoint {
208214
}
209215
}
210216

211-
fn poll_tag(ptr: ucs_status_ptr_t) -> Poll<Result<(u64, usize), Error>> {
217+
fn poll_tag(ptr: ucs_status_ptr_t) -> Poll<Result<ucp_tag_recv_info, Error>> {
212218
let mut info = MaybeUninit::<ucp_tag_recv_info>::uninit();
213219
let status = unsafe { ucp_tag_recv_request_test(ptr as _, info.as_mut_ptr() as _) };
214220
match status {
215221
ucs_status_t::UCS_INPROGRESS => Poll::Pending,
216222
ucs_status_t::UCS_OK => {
217223
let info = unsafe { info.assume_init() };
218-
Poll::Ready(Ok((info.sender_tag, info.length as usize)))
224+
Poll::Ready(Ok(info))
219225
}
220226
status => Poll::Ready(Err(Error::from_error(status))),
221227
}
@@ -281,4 +287,64 @@ mod tests {
281287
assert_eq!(endpoint2.close(true).await, Ok(()));
282288
assert_eq!(endpoint2.get_rc(), (1, 0));
283289
}
290+
291+
#[test_log::test]
292+
fn multi_tag() {
293+
for i in 0..20_usize {
294+
spawn_thread!(_multi_tag(4 << i)).join().unwrap();
295+
}
296+
}
297+
298+
async fn _multi_tag(msg_size: usize) {
299+
let context1 = Context::new().unwrap();
300+
let worker1 = context1.create_worker().unwrap();
301+
let context2 = Context::new().unwrap();
302+
let worker2 = context2.create_worker().unwrap();
303+
tokio::task::spawn_local(worker1.clone().polling());
304+
tokio::task::spawn_local(worker2.clone().polling());
305+
306+
// connect with each other
307+
let mut listener = worker1
308+
.create_listener("0.0.0.0:0".parse().unwrap())
309+
.unwrap();
310+
let listen_port = listener.socket_addr().unwrap().port();
311+
println!("listen at port {}", listen_port);
312+
let mut addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
313+
addr.set_port(listen_port);
314+
315+
let (endpoint1, endpoint2) = tokio::join!(
316+
async {
317+
let conn1 = listener.next().await;
318+
worker1.accept(conn1).await.unwrap()
319+
},
320+
async { worker2.connect_socket(addr).await.unwrap() },
321+
);
322+
323+
// send tag message
324+
tokio::join!(
325+
async {
326+
// send
327+
let mut buf = vec![0; msg_size];
328+
endpoint2.tag_send(3, &mut buf).await.unwrap();
329+
println!("tag sended");
330+
},
331+
async {
332+
// recv
333+
let mut buf = vec![MaybeUninit::<u8>::uninit(); msg_size];
334+
let (tag, size) = worker1.tag_recv_mask(0, 0, &mut buf).await.unwrap();
335+
assert_eq!(size, msg_size);
336+
assert_eq!(tag, 3);
337+
println!("tag recved");
338+
}
339+
);
340+
341+
assert_eq!(endpoint1.get_rc(), (1, 1));
342+
assert_eq!(endpoint2.get_rc(), (1, 1));
343+
assert_eq!(endpoint1.close(false).await, Ok(()));
344+
assert_eq!(endpoint2.close(false).await, Err(Error::ConnectionReset));
345+
assert_eq!(endpoint1.get_rc(), (1, 0));
346+
assert_eq!(endpoint2.get_rc(), (1, 1));
347+
assert_eq!(endpoint2.close(true).await, Ok(()));
348+
assert_eq!(endpoint2.get_rc(), (1, 0));
349+
}
284350
}

src/ucp/endpoint/util.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ impl Worker {
4242
}
4343
}
4444
/// make tag read stream with mask
45+
/// not suggested to use this function, because actual received tag should be checked by user
4546
pub fn tag_read_stream_mask(&self, tag: u64, tag_mask: u64) -> TagReadStream {
4647
TagReadStream {
4748
worker: self,
@@ -211,7 +212,7 @@ pub struct TagReadStream<'a> {
211212
tag: u64,
212213
tag_mask: u64,
213214
#[pin]
214-
request: Option<RequestHandle<Result<(u64, usize), Error>>>,
215+
request: Option<RequestHandle<Result<ucp_tag_recv_info, Error>>>,
215216
}
216217

217218
impl<'a> AsyncRead for TagReadStream<'a> {
@@ -222,11 +223,11 @@ impl<'a> AsyncRead for TagReadStream<'a> {
222223
) -> Poll<Result<(), std::io::Error>> {
223224
if let Some(mut req) = self.as_mut().project().request.as_pin_mut() {
224225
let r = match ready!(req.poll_unpin(cx)) {
225-
Ok((_, n)) => {
226+
Ok(info) => {
226227
// Safety: The buffer was filled by the recv operation.
227228
unsafe {
228-
out_buf.assume_init(n);
229-
out_buf.advance(n);
229+
out_buf.assume_init(info.length as usize);
230+
out_buf.advance(info.length as usize);
230231
}
231232
Ok(())
232233
}
@@ -239,11 +240,11 @@ impl<'a> AsyncRead for TagReadStream<'a> {
239240
match self.worker.tag_recv_impl(self.tag, self.tag_mask, buf) {
240241
Ok(Status::Completed(n_result)) => {
241242
match n_result {
242-
Ok((_, n)) => {
243+
Ok(info) => {
243244
// Safety: The buffer was filled by the recv operation.
244245
unsafe {
245-
out_buf.assume_init(n);
246-
out_buf.advance(n);
246+
out_buf.assume_init(info.length as usize);
247+
out_buf.advance(info.length as usize);
247248
}
248249
Poll::Ready(Ok(()))
249250
}

0 commit comments

Comments
 (0)