Skip to content

Commit 6466c29

Browse files
committed
simplify creation of WakerArray and WakerVec and make waker_from_redirect_position safer
1 parent 4f164de commit 6466c29

File tree

4 files changed

+78
-77
lines changed

4 files changed

+78
-77
lines changed

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ harness = false
2828
bitvec = { version = "1.0.1", default-features = false, features = ["alloc"] }
2929
futures-core = "0.3"
3030
pin-project = "1.0.8"
31-
sptr = "0.3.2"
3231

3332
[dev-dependencies]
3433
futures = "0.3.25"

src/utils/wakers/array/waker_array.rs

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use core::array;
22
use core::task::Waker;
3-
use std::sync::{Arc, Mutex};
3+
use std::sync::{Arc, Mutex, Weak};
44

55
use super::{
66
super::shared_arc::{waker_from_redirect_position, SharedArcContent},
@@ -22,30 +22,18 @@ struct WakerArrayInner<const N: usize> {
2222
impl<const N: usize> WakerArray<N> {
2323
/// Create a new instance of `WakerArray`.
2424
pub(crate) fn new() -> Self {
25-
let mut inner = Arc::new(WakerArrayInner {
26-
readiness: Mutex::new(ReadinessArray::new()),
27-
redirect: [std::ptr::null(); N], // We don't know the Arc's address yet so put null for now.
28-
});
29-
let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address.
30-
31-
// At this point the strong count is 2. Decrement it to 1.
32-
// Each time we create/clone a Waker the count will be incremented by 1.
33-
// So N Wakers -> count = N+1.
34-
unsafe { Arc::decrement_strong_count(raw) }
35-
36-
// Make redirect all point to the Arc itself.
37-
Arc::get_mut(&mut inner).unwrap().redirect = [raw; N];
38-
39-
// Now the redirect array is complete. Time to create the actual Wakers.
40-
let wakers = array::from_fn(|i| {
41-
let data = inner.redirect.get(i).unwrap();
42-
unsafe {
43-
waker_from_redirect_position::<WakerArrayInner<N>>(
44-
data as *const *const WakerArrayInner<N>,
45-
)
25+
let inner = Arc::new_cyclic(|w| {
26+
// `Weak::as_ptr` on a live Weak gives the same thing as `Arc::into_raw`.
27+
let raw = Weak::as_ptr(w);
28+
WakerArrayInner {
29+
readiness: Mutex::new(ReadinessArray::new()),
30+
redirect: [raw; N],
4631
}
4732
});
4833

34+
let wakers =
35+
array::from_fn(|i| unsafe { waker_from_redirect_position(Arc::clone(&inner), i) });
36+
4937
Self { inner, wakers }
5038
}
5139

@@ -59,7 +47,8 @@ impl<const N: usize> WakerArray<N> {
5947
}
6048
}
6149

62-
impl<const N: usize> SharedArcContent for WakerArrayInner<N> {
50+
#[deny(unsafe_op_in_unsafe_fn)]
51+
unsafe impl<const N: usize> SharedArcContent for WakerArrayInner<N> {
6352
fn get_redirect_slice(&self) -> &[*const Self] {
6453
&self.redirect
6554
}
@@ -84,7 +73,10 @@ mod tests {
8473
#[test]
8574
fn check_refcount() {
8675
let mut wa = WakerArray::<5>::new();
76+
77+
// Each waker holds 1 ref, and the combinator itself holds 1.
8778
assert_eq!(Arc::strong_count(&wa.inner), 6);
79+
8880
wa.wakers[4] = dummy_waker();
8981
assert_eq!(Arc::strong_count(&wa.inner), 5);
9082
let cloned = wa.wakers[3].clone();
@@ -105,13 +97,17 @@ mod tests {
10597
let taken = std::mem::replace(&mut wa.wakers[2], dummy_waker());
10698
assert_eq!(Arc::strong_count(&wa.inner), 4);
10799
taken.wake_by_ref();
108-
taken.wake_by_ref();
109-
taken.wake_by_ref();
100+
assert_eq!(Arc::strong_count(&wa.inner), 4);
101+
taken.clone().wake();
110102
assert_eq!(Arc::strong_count(&wa.inner), 4);
111103
taken.wake();
112104
assert_eq!(Arc::strong_count(&wa.inner), 3);
113105

114106
wa.wakers = array::from_fn(|_| dummy_waker());
115107
assert_eq!(Arc::strong_count(&wa.inner), 1);
108+
109+
let weak = Arc::downgrade(&wa.inner);
110+
drop(wa);
111+
assert_eq!(weak.strong_count(), 0);
116112
}
117113
}

src/utils/wakers/shared_arc.rs

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,43 +43,46 @@ use std::sync::Arc;
4343

4444
/// A trait to be implemented on [super::WakerArray] and [super::WakerVec] for polymorphism.
4545
/// These are the type that goes in the Arc. They both contain the Readiness and the redirect array/vec.
46-
pub(super) trait SharedArcContent {
47-
/// Get the reference of the redirect slice. This is used to compute the index.
46+
/// # Safety
47+
/// The `get_redirect_slice` method MUST always return the same slice for the same self.
48+
pub(super) unsafe trait SharedArcContent {
49+
/// Get the reference of the redirect slice.
4850
fn get_redirect_slice(&self) -> &[*const Self];
51+
4952
/// Called when the subfuture at the specified index should be polled.
5053
/// Should call `Readiness::set_ready`.
5154
fn wake_index(&self, index: usize);
5255
}
5356

5457
/// Create one waker following the mechanism described in the [module][self] doc.
55-
/// The following must be upheld for safety:
56-
/// - `pointer` must points to a slot in the redirect array.
57-
/// - that slot must contain a pointer obtained by `Arc::<A>::into_raw`.
58-
/// - the Arc must still be alive at the time this function is called.
59-
/// The following should be upheld for correct behavior:
60-
/// - calling `SharedArcContent::get_redirect_slice` on the content of the Arc should give the redirect array within which `pointer` points to.
58+
/// For safety, the index MUST be within bounds of the slice returned by `A::get_redirect_slice()`.
6159
#[deny(unsafe_op_in_unsafe_fn)]
6260
pub(super) unsafe fn waker_from_redirect_position<A: SharedArcContent>(
63-
pointer: *const *const A,
61+
arc: Arc<A>,
62+
index: usize,
6463
) -> Waker {
65-
/// Create a Waker from a type-erased pointer.
66-
/// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
67-
unsafe fn create_waker<A: SharedArcContent>(pointer: *const ()) -> RawWaker {
64+
// For `create_waker`, `wake_by_ref`, `wake`, and `drop_waker`, the following MUST be upheld for safety:
65+
// - `pointer` must points to a slot in the redirect array.
66+
// - that slot must contain a pointer of an Arc obtained from `Arc::<A>::into_raw`.
67+
// - that Arc must still be alive (strong count > 0) at the time the function is called.
68+
69+
/// Clone a Waker from a type-erased pointer.
70+
/// The pointer must satisfy the safety constraints listed in the code comments above.
71+
unsafe fn clone_waker<A: SharedArcContent>(pointer: *const ()) -> RawWaker {
6872
// Retype the type-erased pointer.
6973
let pointer = pointer as *const *const A;
7074

71-
// We're creating a new Waker, so we need to increment the count.
72-
// SAFETY: The constraints listed for the wrapping function documentation means
75+
// Increment the count so that the Arc won't die before this new Waker we're creating.
76+
// SAFETY: The required constraints means
7377
// - `*pointer` is an `*const A` obtained from `Arc::<A>::into_raw`.
74-
// - the Arc is alive.
75-
// So this operation is safe.
78+
// - the Arc is alive right now.
7679
unsafe { Arc::increment_strong_count(*pointer) };
7780

7881
RawWaker::new(pointer as *const (), create_vtable::<A>())
7982
}
8083

8184
/// Invoke `SharedArcContent::wake_index` with the index in the redirect slice where this pointer points to.
82-
/// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
85+
/// The pointer must satisfy the safety constraints listed in the code comments above.
8386
unsafe fn wake_by_ref<A: SharedArcContent>(pointer: *const ()) {
8487
// Retype the type-erased pointer.
8588
let pointer = pointer as *const *const A;
@@ -89,31 +92,28 @@ pub(super) unsafe fn waker_from_redirect_position<A: SharedArcContent>(
8992
// SAFETY: we are already requiring the pointer in the redirect array slot to be obtained from `Arc::into_raw`.
9093
let arc_content: &A = unsafe { &*raw };
9194

92-
// Calculate the index.
93-
// This is your familiar pointer math
94-
// `item_address = array_address + (index * item_size)`
95-
// rearranged to
96-
// `index = (item_address - array_address) / item_size`.
97-
let item_address = sptr::Strict::addr(pointer);
98-
let redirect_slice_address = sptr::Strict::addr(arc_content.get_redirect_slice().as_ptr());
99-
let redirect_item_size = core::mem::size_of::<*const A>(); // the size of each item in the redirect slice
100-
let index = (item_address - redirect_slice_address) / redirect_item_size;
95+
let slice_start = arc_content.get_redirect_slice().as_ptr();
96+
97+
// We'll switch to [`sub_ptr`](https://github.com/rust-lang/rust/issues/95892) once that's stable.
98+
let index = unsafe { pointer.offset_from(slice_start) } as usize;
10199

102100
arc_content.wake_index(index);
103101
}
104102

105-
/// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
103+
/// Drop the waker (and drop the shared Arc if other Wakers and the combinator have all been dropped).
104+
/// The pointer must satisfy the safety constraints listed in the code comments above.
106105
unsafe fn drop_waker<A: SharedArcContent>(pointer: *const ()) {
107106
// Retype the type-erased pointer.
108107
let pointer = pointer as *const *const A;
109108

110109
// SAFETY: we are already requiring `pointer` to point to a slot in the redirect array.
111-
let raw = unsafe { *pointer };
110+
let raw: *const A = unsafe { *pointer };
112111
// SAFETY: we are already requiring the pointer in the redirect array slot to be obtained from `Arc::into_raw`.
113112
unsafe { Arc::decrement_strong_count(raw) };
114113
}
115114

116-
/// The pointer must satisfy the safety constraints listed in the wrapping function's documentation.
115+
/// Like `wake_by_ref` but consumes the Waker.
116+
/// The pointer must satisfy the safety constraints listed in the code comments above.
117117
unsafe fn wake<A: SharedArcContent>(pointer: *const ()) {
118118
// SAFETY: we are already requiring the constraints of `wake_by_ref` and `drop_waker`.
119119
unsafe {
@@ -124,13 +124,27 @@ pub(super) unsafe fn waker_from_redirect_position<A: SharedArcContent>(
124124

125125
fn create_vtable<A: SharedArcContent>() -> &'static RawWakerVTable {
126126
&RawWakerVTable::new(
127-
create_waker::<A>,
127+
clone_waker::<A>,
128128
wake::<A>,
129129
wake_by_ref::<A>,
130130
drop_waker::<A>,
131131
)
132132
}
133+
134+
let redirect_slice: &[*const A] = arc.get_redirect_slice();
135+
136+
debug_assert!(redirect_slice.len() > index);
137+
138+
// SAFETY: we are already requiring that index be in bound of the slice.
139+
let pointer: *const *const A = unsafe { redirect_slice.as_ptr().add(index) };
140+
// Type-erase the pointer because the Waker methods expect so.
141+
let pointer = pointer as *const ();
142+
143+
// We want to transfer management of the one strong count associated with `arc` to the Waker we're creating.
144+
// That count should only be decremented when the Waker is dropped (by `drop_waker`).
145+
core::mem::forget(arc);
146+
133147
// SAFETY: All our vtable functions adhere to the RawWakerVTable contract,
134148
// and we are already requiring that `pointer` is what our functions expect.
135-
unsafe { Waker::from_raw(create_waker::<A>(pointer as *const ())) }
149+
unsafe { Waker::from_raw(RawWaker::new(pointer, create_vtable::<A>())) }
136150
}

src/utils/wakers/vec/waker_vec.rs

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use core::task::Waker;
2-
use std::sync::{Arc, Mutex};
2+
use std::sync::{Arc, Mutex, Weak};
33

44
use super::{
55
super::shared_arc::{waker_from_redirect_position, SharedArcContent},
@@ -21,27 +21,18 @@ struct WakerVecInner {
2121
impl WakerVec {
2222
/// Create a new instance of `WakerVec`.
2323
pub(crate) fn new(len: usize) -> Self {
24-
let mut inner = Arc::new(WakerVecInner {
25-
readiness: Mutex::new(ReadinessVec::new(len)),
26-
redirect: Vec::new(),
24+
let inner = Arc::new_cyclic(|w| {
25+
// `Weak::as_ptr` on a live Weak gives the same thing as `Arc::into_raw`.
26+
let raw = Weak::as_ptr(w);
27+
WakerVecInner {
28+
readiness: Mutex::new(ReadinessVec::new(len)),
29+
redirect: vec![raw; len],
30+
}
2731
});
28-
let raw = Arc::into_raw(Arc::clone(&inner)); // The Arc's address.
29-
30-
// At this point the strong count is 2. Decrement it to 1.
31-
// Each time we create/clone a Waker the count will be incremented by 1.
32-
// So N Wakers -> count = N+1.
33-
unsafe { Arc::decrement_strong_count(raw) }
34-
35-
// Make redirect all point to the Arc itself.
36-
Arc::get_mut(&mut inner).unwrap().redirect = vec![raw; len];
3732

3833
// Now the redirect vec is complete. Time to create the actual Wakers.
39-
let wakers = inner
40-
.redirect
41-
.iter()
42-
.map(|data| unsafe {
43-
waker_from_redirect_position::<WakerVecInner>(data as *const *const WakerVecInner)
44-
})
34+
let wakers = (0..len)
35+
.map(|i| unsafe { waker_from_redirect_position(Arc::clone(&inner), i) })
4536
.collect();
4637

4738
Self { inner, wakers }
@@ -56,7 +47,8 @@ impl WakerVec {
5647
}
5748
}
5849

59-
impl SharedArcContent for WakerVecInner {
50+
#[deny(unsafe_op_in_unsafe_fn)]
51+
unsafe impl SharedArcContent for WakerVecInner {
6052
fn get_redirect_slice(&self) -> &[*const Self] {
6153
&self.redirect
6254
}

0 commit comments

Comments
 (0)