Skip to content

Commit bfe5896

Browse files
authored
Fix crash when sending with the AmProto::Eager protocol (#8)
* Manually drop the data in the AmData::Data case. According to the documentation of UCP_AM_RECV_ATTR_FLAG_DATA, ucp_am_data_release must be called after using AmData::Data. * format
1 parent acb197a commit bfe5896

File tree

1 file changed

+135
-113
lines changed

1 file changed

+135
-113
lines changed

src/ucp/endpoint/am.rs

Lines changed: 135 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ impl<'a> AmMsg<'a> {
109109
AmMsg { worker, msg }
110110
}
111111

112-
/// Get the message ID.
112+
/// Get the ActiveStream id
113113
#[inline]
114114
pub fn id(&self) -> u16 {
115115
self.msg.id
@@ -121,10 +121,10 @@ impl<'a> AmMsg<'a> {
121121
self.msg.header.as_ref()
122122
}
123123

124-
/// Get the message header length.
124+
/// Returns `true` if the message contains data. Otherwise, `false`.
125125
#[inline]
126126
pub fn contains_data(&self) -> bool {
127-
self.data_type().is_some()
127+
self.msg.data.is_some()
128128
}
129129

130130
/// Get the message data type.
@@ -133,10 +133,14 @@ impl<'a> AmMsg<'a> {
133133
}
134134

135135
/// Get the message data.
136-
/// Returns `None` if the message doesn't contain data.
136+
/// Returns `None` if the data still needs to be received.
137+
/// Returns `Some(slice)` if the message contains concrete data.
137138
#[inline]
138139
pub fn get_data(&self) -> Option<&[u8]> {
139-
self.msg.data.as_ref().and_then(|data| data.data())
140+
match self.msg.data {
141+
Some(ref amdata) => amdata.data(),
142+
None => Some(&[]),
143+
}
140144
}
141145

142146
/// Get the message data length.
@@ -151,6 +155,11 @@ impl<'a> AmMsg<'a> {
151155
match self.msg.data.take() {
152156
None => Ok(Vec::new()),
153157
Some(AmData::Eager(vec)) => Ok(vec),
158+
Some(AmData::Data(data)) => {
159+
let v = data.to_vec();
160+
self.drop_msg(AmData::Data(data));
161+
Ok(v)
162+
}
154163
Some(data) => {
155164
self.msg.data = Some(data);
156165
let mut buf = Vec::with_capacity(self.data_len());
@@ -181,104 +190,110 @@ impl<'a> AmMsg<'a> {
181190

182191
/// Receive the message data.
183192
pub async fn recv_data_vectored(&mut self, iov: &[IoSliceMut<'_>]) -> Result<usize, Error> {
184-
let data = self.msg.data.take();
185-
if let Some(data) = data {
186-
if let AmData::Eager(data) = data {
187-
// return error if buffer size < data length, same with ucx
188-
let cap = iov.iter().fold(0_usize, |cap, buf| cap + buf.len());
189-
if cap < data.len() {
190-
return Err(Error::MessageTruncated);
191-
}
193+
fn copy_data_to_iov(data: &[u8], iov: &[IoSliceMut<'_>]) -> Result<usize, Error> {
194+
// return error if buffer size < data length, same with ucx
195+
let cap = iov.iter().fold(0_usize, |cap, buf| cap + buf.len());
196+
if cap < data.len() {
197+
return Err(Error::MessageTruncated);
198+
}
192199

193-
let mut copied = 0_usize;
194-
for buf in iov {
195-
let len = std::cmp::min(data.len() - copied, buf.len());
196-
if len == 0 {
197-
break;
198-
}
200+
let mut copied = 0_usize;
201+
for buf in iov {
202+
let len = std::cmp::min(data.len() - copied, buf.len());
203+
if len == 0 {
204+
break;
205+
}
199206

200-
let buf = &buf[..len];
201-
unsafe {
202-
std::ptr::copy_nonoverlapping(
203-
data[copied..].as_ptr(),
204-
buf.as_ptr() as _,
205-
len,
206-
)
207-
}
208-
copied += len;
207+
let buf = &buf[..len];
208+
unsafe {
209+
std::ptr::copy_nonoverlapping(data[copied..].as_ptr(), buf.as_ptr() as _, len)
209210
}
210-
return Ok(copied);
211+
copied += len;
211212
}
213+
Ok(copied)
214+
}
215+
let data = self.msg.data.take();
212216

213-
let (data_desc, data_len) = match data {
214-
AmData::Data(data) => (data.as_ptr(), data.len()),
215-
AmData::Rndv(data) => (data.as_ptr(), data.len()),
216-
_ => unreachable!(),
217-
};
218-
219-
unsafe extern "C" fn callback(
220-
request: *mut c_void,
221-
status: ucs_status_t,
222-
_length: usize,
223-
_data: *mut c_void,
224-
) {
225-
// todo: handle error & fix real data length
217+
match data {
218+
Some(AmData::Eager(data)) => {
219+
// eager message, no need to receive
220+
copy_data_to_iov(&data, iov)
221+
}
222+
Some(AmData::Data(data)) => {
223+
// data message, no need to receive
224+
let size = copy_data_to_iov(&data, iov)?;
225+
self.drop_msg(AmData::Data(data));
226+
Ok(size)
227+
}
228+
Some(AmData::Rndv(desc)) => {
229+
// rndv message, need to receive
230+
let (data_desc, data_len) = (desc.as_ptr(), desc.len());
231+
232+
unsafe extern "C" fn callback(
233+
request: *mut c_void,
234+
status: ucs_status_t,
235+
_length: usize,
236+
_data: *mut c_void,
237+
) {
238+
// todo: handle error & fix real data length
239+
trace!(
240+
"recv_data_vectored: complete, req={:?}, status={:?}",
241+
request,
242+
status
243+
);
244+
let request = &mut *(request as *mut Request);
245+
request.waker.wake();
246+
}
226247
trace!(
227-
"recv_data_vectored: complete, req={:?}, status={:?}",
228-
request,
229-
status
248+
"recv_data_vectored: worker={:?} iov.len={}",
249+
self.worker.handle,
250+
iov.len()
230251
);
231-
let request = &mut *(request as *mut Request);
232-
request.waker.wake();
233-
}
234-
trace!(
235-
"recv_data_vectored: worker={:?} iov.len={}",
236-
self.worker.handle,
237-
iov.len()
238-
);
239-
let mut param = MaybeUninit::<ucp_request_param_t>::uninit();
240-
let (buffer, count) = unsafe {
241-
let param = &mut *param.as_mut_ptr();
242-
param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32
243-
| ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32;
244-
param.cb = ucp_request_param_t__bindgen_ty_1 {
245-
recv_am: Some(callback),
252+
let mut param = MaybeUninit::<ucp_request_param_t>::uninit();
253+
let (buffer, count) = unsafe {
254+
let param = &mut *param.as_mut_ptr();
255+
param.op_attr_mask = ucp_op_attr_t::UCP_OP_ATTR_FIELD_CALLBACK as u32
256+
| ucp_op_attr_t::UCP_OP_ATTR_FIELD_DATATYPE as u32;
257+
param.cb = ucp_request_param_t__bindgen_ty_1 {
258+
recv_am: Some(callback),
259+
};
260+
261+
if iov.len() == 1 {
262+
param.datatype = ucp_dt_make_contig(1);
263+
(iov[0].as_ptr(), iov[0].len())
264+
} else {
265+
param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _;
266+
(iov.as_ptr() as _, iov.len())
267+
}
246268
};
247269

248-
if iov.len() == 1 {
249-
param.datatype = ucp_dt_make_contig(1);
250-
(iov[0].as_ptr(), iov[0].len())
270+
let status = unsafe {
271+
ucp_am_recv_data_nbx(
272+
self.worker.handle,
273+
data_desc as _,
274+
buffer as _,
275+
count as _,
276+
param.as_ptr(),
277+
)
278+
};
279+
if status.is_null() {
280+
trace!("recv_data_vectored: complete");
281+
Ok(data_len)
282+
} else if UCS_PTR_IS_PTR(status) {
283+
RequestHandle {
284+
ptr: status,
285+
poll_fn: poll_recv,
286+
}
287+
.await;
288+
Ok(data_len)
251289
} else {
252-
param.datatype = ucp_dt_type::UCP_DATATYPE_IOV as _;
253-
(iov.as_ptr() as _, iov.len())
290+
Err(Error::from_ptr(status).unwrap_err())
254291
}
255-
};
256-
257-
let status = unsafe {
258-
ucp_am_recv_data_nbx(
259-
self.worker.handle,
260-
data_desc as _,
261-
buffer as _,
262-
count as _,
263-
param.as_ptr(),
264-
)
265-
};
266-
if status.is_null() {
267-
trace!("recv_data_vectored: complete");
268-
Ok(data_len)
269-
} else if UCS_PTR_IS_PTR(status) {
270-
RequestHandle {
271-
ptr: status,
272-
poll_fn: poll_recv,
273-
}
274-
.await;
275-
Ok(data_len)
276-
} else {
277-
Err(Error::from_ptr(status).unwrap_err())
278292
}
279-
} else {
280-
// no data
281-
Ok(0)
293+
None => {
294+
// no data
295+
Ok(0)
296+
}
282297
}
283298
}
284299

@@ -321,18 +336,24 @@ impl<'a> AmMsg<'a> {
321336
assert!(self.need_reply());
322337
am_send(self.msg.reply_ep, id, header, data, need_reply, proto).await
323338
}
339+
340+
fn drop_msg(&mut self, data: AmData) {
341+
match data {
342+
AmData::Eager(_) => (),
343+
AmData::Data(data) => unsafe {
344+
ucp_am_data_release(self.worker.handle, data.as_ptr() as _);
345+
},
346+
AmData::Rndv(data) => unsafe {
347+
ucp_am_data_release(self.worker.handle, data.as_ptr() as _);
348+
},
349+
}
350+
}
324351
}
325352

326353
impl<'a> Drop for AmMsg<'a> {
327354
fn drop(&mut self) {
328-
match self.msg.data.take() {
329-
Some(AmData::Data(desc)) => unsafe {
330-
ucp_am_data_release(self.worker.handle, desc.as_ptr() as _);
331-
},
332-
Some(AmData::Rndv(desc)) => unsafe {
333-
ucp_am_data_release(self.worker.handle, desc.as_ptr() as _);
334-
},
335-
_ => (),
355+
if let Some(data) = self.msg.data.take() {
356+
self.drop_msg(data);
336357
}
337358
}
338359
}
@@ -502,6 +523,8 @@ impl Endpoint {
502523
}
503524

504525
/// Active message protocol
526+
#[derive(Debug, Clone, Copy)]
527+
#[repr(u32)]
505528
pub enum AmProto {
506529
/// Eager protocol
507530
Eager,
@@ -594,12 +617,20 @@ mod tests {
594617

595618
#[test_log::test]
596619
fn am() {
597-
for i in 0..20_usize {
598-
spawn_thread!(send_recv(4 << i)).join().unwrap();
620+
let protos = vec![None, Some(AmProto::Eager), Some(AmProto::Rndv)];
621+
for block_size_shift in 0..20_usize {
622+
for p in protos.iter() {
623+
let rt = tokio::runtime::Builder::new_current_thread()
624+
.enable_time()
625+
.build()
626+
.unwrap();
627+
let local = tokio::task::LocalSet::new();
628+
local.block_on(&rt, send_recv(4 << block_size_shift, *p));
629+
}
599630
}
600631
}
601632

602-
async fn send_recv(data_size: usize) {
633+
async fn send_recv(data_size: usize, proto: Option<AmProto>) {
603634
let context1 = Context::new().unwrap();
604635
let worker1 = context1.create_worker().unwrap();
605636
let context2 = Context::new().unwrap();
@@ -631,13 +662,7 @@ mod tests {
631662
async {
632663
// send msg
633664
let result = endpoint2
634-
.am_send(
635-
16,
636-
header.as_slice(),
637-
data.as_slice(),
638-
true,
639-
Some(AmProto::Rndv),
640-
)
665+
.am_send(16, header.as_slice(), data.as_slice(), true, proto)
641666
.await;
642667
assert!(result.is_ok());
643668
},
@@ -662,10 +687,7 @@ mod tests {
662687
tokio::join!(
663688
async {
664689
// send reply
665-
let result = unsafe {
666-
msg.reply(12, &header, &data, false, Some(AmProto::Rndv))
667-
.await
668-
};
690+
let result = unsafe { msg.reply(12, &header, &data, false, proto).await };
669691
assert!(result.is_ok());
670692
},
671693
async {

0 commit comments

Comments
 (0)