@@ -18,8 +18,11 @@ impl Worker {
18
18
buf : & mut [ MaybeUninit < u8 > ] ,
19
19
) -> Result < ( u64 , usize ) , Error > {
20
20
match self . tag_recv_impl ( tag, tag_mask, buf) ? {
21
- Status :: Completed ( r) => r,
22
- Status :: Scheduled ( request_handle) => request_handle. await ,
21
+ Status :: Completed ( r) => r. map ( |info| ( info. sender_tag , info. length as usize ) ) ,
22
+ Status :: Scheduled ( request_handle) => {
23
+ let info = request_handle. await ?;
24
+ Ok ( ( info. sender_tag , info. length as usize ) )
25
+ }
23
26
}
24
27
}
25
28
@@ -70,15 +73,15 @@ impl Worker {
70
73
poll_fn : poll_tag,
71
74
}
72
75
. await
73
- . map ( |info| info. 1 )
76
+ . map ( |info| info. length as usize )
74
77
}
75
78
76
79
pub ( super ) fn tag_recv_impl (
77
80
& self ,
78
81
tag : u64 ,
79
82
tag_mask : u64 ,
80
83
buf : & mut [ MaybeUninit < u8 > ] ,
81
- ) -> Result < Status < ( u64 , usize ) > , Error > {
84
+ ) -> Result < Status < ucp_tag_recv_info > , Error > {
82
85
trace ! (
83
86
"tag_recv: worker={:?}, tag={}, mask={:#x} len={}" ,
84
87
self . handle,
@@ -104,7 +107,10 @@ impl Worker {
104
107
let request = & mut * ( request as * mut Request ) ;
105
108
request. waker . wake ( ) ;
106
109
}
107
- let param = RequestParam :: new ( ) . cb_tag_recv ( Some ( callback) ) ;
110
+ let mut info = MaybeUninit :: < ucp_tag_recv_info > :: uninit ( ) ;
111
+ let param = RequestParam :: new ( )
112
+ . cb_tag_recv ( Some ( callback) )
113
+ . recv_tag_info ( info. as_mut_ptr ( ) as _ ) ;
108
114
let status = unsafe {
109
115
ucp_tag_recv_nbx (
110
116
self . handle ,
@@ -115,7 +121,7 @@ impl Worker {
115
121
param. as_ref ( ) ,
116
122
)
117
123
} ;
118
- Ok ( Status :: from ( status, MaybeUninit :: uninit ( ) , poll_tag) )
124
+ Ok ( Status :: from ( status, info , poll_tag) )
119
125
}
120
126
}
121
127
@@ -208,14 +214,14 @@ impl Endpoint {
208
214
}
209
215
}
210
216
211
- fn poll_tag ( ptr : ucs_status_ptr_t ) -> Poll < Result < ( u64 , usize ) , Error > > {
217
+ fn poll_tag ( ptr : ucs_status_ptr_t ) -> Poll < Result < ucp_tag_recv_info , Error > > {
212
218
let mut info = MaybeUninit :: < ucp_tag_recv_info > :: uninit ( ) ;
213
219
let status = unsafe { ucp_tag_recv_request_test ( ptr as _ , info. as_mut_ptr ( ) as _ ) } ;
214
220
match status {
215
221
ucs_status_t:: UCS_INPROGRESS => Poll :: Pending ,
216
222
ucs_status_t:: UCS_OK => {
217
223
let info = unsafe { info. assume_init ( ) } ;
218
- Poll :: Ready ( Ok ( ( info. sender_tag , info . length as usize ) ) )
224
+ Poll :: Ready ( Ok ( info) )
219
225
}
220
226
status => Poll :: Ready ( Err ( Error :: from_error ( status) ) ) ,
221
227
}
@@ -281,4 +287,64 @@ mod tests {
281
287
assert_eq ! ( endpoint2. close( true ) . await , Ok ( ( ) ) ) ;
282
288
assert_eq ! ( endpoint2. get_rc( ) , ( 1 , 0 ) ) ;
283
289
}
290
+
291
+ #[ test_log:: test]
292
+ fn multi_tag ( ) {
293
+ for i in 0 ..20_usize {
294
+ spawn_thread ! ( _multi_tag( 4 << i) ) . join ( ) . unwrap ( ) ;
295
+ }
296
+ }
297
+
298
+ async fn _multi_tag ( msg_size : usize ) {
299
+ let context1 = Context :: new ( ) . unwrap ( ) ;
300
+ let worker1 = context1. create_worker ( ) . unwrap ( ) ;
301
+ let context2 = Context :: new ( ) . unwrap ( ) ;
302
+ let worker2 = context2. create_worker ( ) . unwrap ( ) ;
303
+ tokio:: task:: spawn_local ( worker1. clone ( ) . polling ( ) ) ;
304
+ tokio:: task:: spawn_local ( worker2. clone ( ) . polling ( ) ) ;
305
+
306
+ // connect with each other
307
+ let mut listener = worker1
308
+ . create_listener ( "0.0.0.0:0" . parse ( ) . unwrap ( ) )
309
+ . unwrap ( ) ;
310
+ let listen_port = listener. socket_addr ( ) . unwrap ( ) . port ( ) ;
311
+ println ! ( "listen at port {}" , listen_port) ;
312
+ let mut addr: SocketAddr = "127.0.0.1:0" . parse ( ) . unwrap ( ) ;
313
+ addr. set_port ( listen_port) ;
314
+
315
+ let ( endpoint1, endpoint2) = tokio:: join!(
316
+ async {
317
+ let conn1 = listener. next( ) . await ;
318
+ worker1. accept( conn1) . await . unwrap( )
319
+ } ,
320
+ async { worker2. connect_socket( addr) . await . unwrap( ) } ,
321
+ ) ;
322
+
323
+ // send tag message
324
+ tokio:: join!(
325
+ async {
326
+ // send
327
+ let mut buf = vec![ 0 ; msg_size] ;
328
+ endpoint2. tag_send( 3 , & mut buf) . await . unwrap( ) ;
329
+ println!( "tag sended" ) ;
330
+ } ,
331
+ async {
332
+ // recv
333
+ let mut buf = vec![ MaybeUninit :: <u8 >:: uninit( ) ; msg_size] ;
334
+ let ( tag, size) = worker1. tag_recv_mask( 0 , 0 , & mut buf) . await . unwrap( ) ;
335
+ assert_eq!( size, msg_size) ;
336
+ assert_eq!( tag, 3 ) ;
337
+ println!( "tag recved" ) ;
338
+ }
339
+ ) ;
340
+
341
+ assert_eq ! ( endpoint1. get_rc( ) , ( 1 , 1 ) ) ;
342
+ assert_eq ! ( endpoint2. get_rc( ) , ( 1 , 1 ) ) ;
343
+ assert_eq ! ( endpoint1. close( false ) . await , Ok ( ( ) ) ) ;
344
+ assert_eq ! ( endpoint2. close( false ) . await , Err ( Error :: ConnectionReset ) ) ;
345
+ assert_eq ! ( endpoint1. get_rc( ) , ( 1 , 0 ) ) ;
346
+ assert_eq ! ( endpoint2. get_rc( ) , ( 1 , 1 ) ) ;
347
+ assert_eq ! ( endpoint2. close( true ) . await , Ok ( ( ) ) ) ;
348
+ assert_eq ! ( endpoint2. get_rc( ) , ( 1 , 0 ) ) ;
349
+ }
284
350
}
0 commit comments