@@ -109,7 +109,7 @@ impl<'a> AmMsg<'a> {
109
109
AmMsg { worker, msg }
110
110
}
111
111
112
- /// Get the message ID.
112
+ /// Get the ActiveStream id
113
113
#[ inline]
114
114
pub fn id ( & self ) -> u16 {
115
115
self . msg . id
@@ -121,10 +121,10 @@ impl<'a> AmMsg<'a> {
121
121
self . msg . header . as_ref ( )
122
122
}
123
123
124
- /// Get the message header length .
124
+ /// Returns `true` if the message contains data. Otherwise, `false` .
125
125
#[ inline]
126
126
pub fn contains_data ( & self ) -> bool {
127
- self . data_type ( ) . is_some ( )
127
+ self . msg . data . is_some ( )
128
128
}
129
129
130
130
/// Get the message data type.
@@ -133,10 +133,14 @@ impl<'a> AmMsg<'a> {
133
133
}
134
134
135
135
/// Get the message data.
136
- /// Returns `None` if the message doesn't contain data.
136
+ /// Returns `None` if needs to receive data.
137
+ /// Returns `Some(slice)` if the message contains concrete data.
137
138
#[ inline]
138
139
pub fn get_data ( & self ) -> Option < & [ u8 ] > {
139
- self . msg . data . as_ref ( ) . and_then ( |data| data. data ( ) )
140
+ match self . msg . data {
141
+ Some ( ref amdata) => amdata. data ( ) ,
142
+ None => Some ( & [ ] ) ,
143
+ }
140
144
}
141
145
142
146
/// Get the message data length.
@@ -151,6 +155,11 @@ impl<'a> AmMsg<'a> {
151
155
match self . msg . data . take ( ) {
152
156
None => Ok ( Vec :: new ( ) ) ,
153
157
Some ( AmData :: Eager ( vec) ) => Ok ( vec) ,
158
+ Some ( AmData :: Data ( data) ) => {
159
+ let v = data. to_vec ( ) ;
160
+ self . drop_msg ( AmData :: Data ( data) ) ;
161
+ Ok ( v)
162
+ }
154
163
Some ( data) => {
155
164
self . msg . data = Some ( data) ;
156
165
let mut buf = Vec :: with_capacity ( self . data_len ( ) ) ;
@@ -181,104 +190,110 @@ impl<'a> AmMsg<'a> {
181
190
182
191
/// Receive the message data.
183
192
pub async fn recv_data_vectored ( & mut self , iov : & [ IoSliceMut < ' _ > ] ) -> Result < usize , Error > {
184
- let data = self . msg . data . take ( ) ;
185
- if let Some ( data) = data {
186
- if let AmData :: Eager ( data) = data {
187
- // return error if buffer size < data length, same with ucx
188
- let cap = iov. iter ( ) . fold ( 0_usize , |cap, buf| cap + buf. len ( ) ) ;
189
- if cap < data. len ( ) {
190
- return Err ( Error :: MessageTruncated ) ;
191
- }
193
+ fn copy_data_to_iov ( data : & [ u8 ] , iov : & [ IoSliceMut < ' _ > ] ) -> Result < usize , Error > {
194
+ // return error if buffer size < data length, same with ucx
195
+ let cap = iov. iter ( ) . fold ( 0_usize , |cap, buf| cap + buf. len ( ) ) ;
196
+ if cap < data. len ( ) {
197
+ return Err ( Error :: MessageTruncated ) ;
198
+ }
192
199
193
- let mut copied = 0_usize ;
194
- for buf in iov {
195
- let len = std:: cmp:: min ( data. len ( ) - copied, buf. len ( ) ) ;
196
- if len == 0 {
197
- break ;
198
- }
200
+ let mut copied = 0_usize ;
201
+ for buf in iov {
202
+ let len = std:: cmp:: min ( data. len ( ) - copied, buf. len ( ) ) ;
203
+ if len == 0 {
204
+ break ;
205
+ }
199
206
200
- let buf = & buf[ ..len] ;
201
- unsafe {
202
- std:: ptr:: copy_nonoverlapping (
203
- data[ copied..] . as_ptr ( ) ,
204
- buf. as_ptr ( ) as _ ,
205
- len,
206
- )
207
- }
208
- copied += len;
207
+ let buf = & buf[ ..len] ;
208
+ unsafe {
209
+ std:: ptr:: copy_nonoverlapping ( data[ copied..] . as_ptr ( ) , buf. as_ptr ( ) as _ , len)
209
210
}
210
- return Ok ( copied) ;
211
+ copied += len ;
211
212
}
213
+ Ok ( copied)
214
+ }
215
+ let data = self . msg . data . take ( ) ;
212
216
213
- let ( data_desc, data_len) = match data {
214
- AmData :: Data ( data) => ( data. as_ptr ( ) , data. len ( ) ) ,
215
- AmData :: Rndv ( data) => ( data. as_ptr ( ) , data. len ( ) ) ,
216
- _ => unreachable ! ( ) ,
217
- } ;
218
-
219
- unsafe extern "C" fn callback (
220
- request : * mut c_void ,
221
- status : ucs_status_t ,
222
- _length : usize ,
223
- _data : * mut c_void ,
224
- ) {
225
- // todo: handle error & fix real data length
217
+ match data {
218
+ Some ( AmData :: Eager ( data) ) => {
219
+ // eager message, no need to receive
220
+ copy_data_to_iov ( & data, iov)
221
+ }
222
+ Some ( AmData :: Data ( data) ) => {
223
+ // data message, no need to receive
224
+ let size = copy_data_to_iov ( & data, iov) ?;
225
+ self . drop_msg ( AmData :: Data ( data) ) ;
226
+ Ok ( size)
227
+ }
228
+ Some ( AmData :: Rndv ( desc) ) => {
229
+ // rndv message, need to receive
230
+ let ( data_desc, data_len) = ( desc. as_ptr ( ) , desc. len ( ) ) ;
231
+
232
+ unsafe extern "C" fn callback (
233
+ request : * mut c_void ,
234
+ status : ucs_status_t ,
235
+ _length : usize ,
236
+ _data : * mut c_void ,
237
+ ) {
238
+ // todo: handle error & fix real data length
239
+ trace ! (
240
+ "recv_data_vectored: complete, req={:?}, status={:?}" ,
241
+ request,
242
+ status
243
+ ) ;
244
+ let request = & mut * ( request as * mut Request ) ;
245
+ request. waker . wake ( ) ;
246
+ }
226
247
trace ! (
227
- "recv_data_vectored: complete, req ={:?}, status={:? }" ,
228
- request ,
229
- status
248
+ "recv_data_vectored: worker ={:?} iov.len={ }" ,
249
+ self . worker . handle ,
250
+ iov . len ( )
230
251
) ;
231
- let request = & mut * ( request as * mut Request ) ;
232
- request. waker . wake ( ) ;
233
- }
234
- trace ! (
235
- "recv_data_vectored: worker={:?} iov.len={}" ,
236
- self . worker. handle,
237
- iov. len( )
238
- ) ;
239
- let mut param = MaybeUninit :: < ucp_request_param_t > :: uninit ( ) ;
240
- let ( buffer, count) = unsafe {
241
- let param = & mut * param. as_mut_ptr ( ) ;
242
- param. op_attr_mask = ucp_op_attr_t:: UCP_OP_ATTR_FIELD_CALLBACK as u32
243
- | ucp_op_attr_t:: UCP_OP_ATTR_FIELD_DATATYPE as u32 ;
244
- param. cb = ucp_request_param_t__bindgen_ty_1 {
245
- recv_am : Some ( callback) ,
252
+ let mut param = MaybeUninit :: < ucp_request_param_t > :: uninit ( ) ;
253
+ let ( buffer, count) = unsafe {
254
+ let param = & mut * param. as_mut_ptr ( ) ;
255
+ param. op_attr_mask = ucp_op_attr_t:: UCP_OP_ATTR_FIELD_CALLBACK as u32
256
+ | ucp_op_attr_t:: UCP_OP_ATTR_FIELD_DATATYPE as u32 ;
257
+ param. cb = ucp_request_param_t__bindgen_ty_1 {
258
+ recv_am : Some ( callback) ,
259
+ } ;
260
+
261
+ if iov. len ( ) == 1 {
262
+ param. datatype = ucp_dt_make_contig ( 1 ) ;
263
+ ( iov[ 0 ] . as_ptr ( ) , iov[ 0 ] . len ( ) )
264
+ } else {
265
+ param. datatype = ucp_dt_type:: UCP_DATATYPE_IOV as _ ;
266
+ ( iov. as_ptr ( ) as _ , iov. len ( ) )
267
+ }
246
268
} ;
247
269
248
- if iov. len ( ) == 1 {
249
- param. datatype = ucp_dt_make_contig ( 1 ) ;
250
- ( iov[ 0 ] . as_ptr ( ) , iov[ 0 ] . len ( ) )
270
+ let status = unsafe {
271
+ ucp_am_recv_data_nbx (
272
+ self . worker . handle ,
273
+ data_desc as _ ,
274
+ buffer as _ ,
275
+ count as _ ,
276
+ param. as_ptr ( ) ,
277
+ )
278
+ } ;
279
+ if status. is_null ( ) {
280
+ trace ! ( "recv_data_vectored: complete" ) ;
281
+ Ok ( data_len)
282
+ } else if UCS_PTR_IS_PTR ( status) {
283
+ RequestHandle {
284
+ ptr : status,
285
+ poll_fn : poll_recv,
286
+ }
287
+ . await ;
288
+ Ok ( data_len)
251
289
} else {
252
- param. datatype = ucp_dt_type:: UCP_DATATYPE_IOV as _ ;
253
- ( iov. as_ptr ( ) as _ , iov. len ( ) )
290
+ Err ( Error :: from_ptr ( status) . unwrap_err ( ) )
254
291
}
255
- } ;
256
-
257
- let status = unsafe {
258
- ucp_am_recv_data_nbx (
259
- self . worker . handle ,
260
- data_desc as _ ,
261
- buffer as _ ,
262
- count as _ ,
263
- param. as_ptr ( ) ,
264
- )
265
- } ;
266
- if status. is_null ( ) {
267
- trace ! ( "recv_data_vectored: complete" ) ;
268
- Ok ( data_len)
269
- } else if UCS_PTR_IS_PTR ( status) {
270
- RequestHandle {
271
- ptr : status,
272
- poll_fn : poll_recv,
273
- }
274
- . await ;
275
- Ok ( data_len)
276
- } else {
277
- Err ( Error :: from_ptr ( status) . unwrap_err ( ) )
278
292
}
279
- } else {
280
- // no data
281
- Ok ( 0 )
293
+ None => {
294
+ // no data
295
+ Ok ( 0 )
296
+ }
282
297
}
283
298
}
284
299
@@ -321,18 +336,24 @@ impl<'a> AmMsg<'a> {
321
336
assert ! ( self . need_reply( ) ) ;
322
337
am_send ( self . msg . reply_ep , id, header, data, need_reply, proto) . await
323
338
}
339
+
340
+ fn drop_msg ( & mut self , data : AmData ) {
341
+ match data {
342
+ AmData :: Eager ( _) => ( ) ,
343
+ AmData :: Data ( data) => unsafe {
344
+ ucp_am_data_release ( self . worker . handle , data. as_ptr ( ) as _ ) ;
345
+ } ,
346
+ AmData :: Rndv ( data) => unsafe {
347
+ ucp_am_data_release ( self . worker . handle , data. as_ptr ( ) as _ ) ;
348
+ } ,
349
+ }
350
+ }
324
351
}
325
352
326
353
impl < ' a > Drop for AmMsg < ' a > {
327
354
fn drop ( & mut self ) {
328
- match self . msg . data . take ( ) {
329
- Some ( AmData :: Data ( desc) ) => unsafe {
330
- ucp_am_data_release ( self . worker . handle , desc. as_ptr ( ) as _ ) ;
331
- } ,
332
- Some ( AmData :: Rndv ( desc) ) => unsafe {
333
- ucp_am_data_release ( self . worker . handle , desc. as_ptr ( ) as _ ) ;
334
- } ,
335
- _ => ( ) ,
355
+ if let Some ( data) = self . msg . data . take ( ) {
356
+ self . drop_msg ( data) ;
336
357
}
337
358
}
338
359
}
@@ -502,6 +523,8 @@ impl Endpoint {
502
523
}
503
524
504
525
/// Active message protocol
526
+ #[ derive( Debug , Clone , Copy ) ]
527
+ #[ repr( u32 ) ]
505
528
pub enum AmProto {
506
529
/// Eager protocol
507
530
Eager ,
@@ -594,12 +617,20 @@ mod tests {
594
617
595
618
#[ test_log:: test]
596
619
fn am ( ) {
597
- for i in 0 ..20_usize {
598
- spawn_thread ! ( send_recv( 4 << i) ) . join ( ) . unwrap ( ) ;
620
+ let protos = vec ! [ None , Some ( AmProto :: Eager ) , Some ( AmProto :: Rndv ) ] ;
621
+ for block_size_shift in 0 ..20_usize {
622
+ for p in protos. iter ( ) {
623
+ let rt = tokio:: runtime:: Builder :: new_current_thread ( )
624
+ . enable_time ( )
625
+ . build ( )
626
+ . unwrap ( ) ;
627
+ let local = tokio:: task:: LocalSet :: new ( ) ;
628
+ local. block_on ( & rt, send_recv ( 4 << block_size_shift, * p) ) ;
629
+ }
599
630
}
600
631
}
601
632
602
- async fn send_recv ( data_size : usize ) {
633
+ async fn send_recv ( data_size : usize , proto : Option < AmProto > ) {
603
634
let context1 = Context :: new ( ) . unwrap ( ) ;
604
635
let worker1 = context1. create_worker ( ) . unwrap ( ) ;
605
636
let context2 = Context :: new ( ) . unwrap ( ) ;
@@ -631,13 +662,7 @@ mod tests {
631
662
async {
632
663
// send msg
633
664
let result = endpoint2
634
- . am_send(
635
- 16 ,
636
- header. as_slice( ) ,
637
- data. as_slice( ) ,
638
- true ,
639
- Some ( AmProto :: Rndv ) ,
640
- )
665
+ . am_send( 16 , header. as_slice( ) , data. as_slice( ) , true , proto)
641
666
. await ;
642
667
assert!( result. is_ok( ) ) ;
643
668
} ,
@@ -662,10 +687,7 @@ mod tests {
662
687
tokio:: join!(
663
688
async {
664
689
// send reply
665
- let result = unsafe {
666
- msg. reply( 12 , & header, & data, false , Some ( AmProto :: Rndv ) )
667
- . await
668
- } ;
690
+ let result = unsafe { msg. reply( 12 , & header, & data, false , proto) . await } ;
669
691
assert!( result. is_ok( ) ) ;
670
692
} ,
671
693
async {
0 commit comments