1- use core:: {
2- marker:: PhantomData ,
3- num:: { NonZeroU32 , NonZeroU64 } ,
4- ptr:: NonNull ,
5- sync:: atomic:: { AtomicU64 , Ordering } ,
6- } ;
1+ use core:: { marker:: PhantomData , ptr:: NonNull } ;
2+
3+ #[ cfg( not( feature = "portable-atomic" ) ) ]
4+ use core:: sync:: atomic;
5+ #[ cfg( feature = "portable-atomic" ) ]
6+ use portable_atomic as atomic;
7+
8+ use atomic:: Ordering ;
79
810use super :: { Node , Stack } ;
911
12+ #[ cfg( target_pointer_width = "32" ) ]
13+ mod types {
14+ use super :: atomic;
15+
16+ pub type Inner = u64 ;
17+ pub type InnerAtomic = atomic:: AtomicU64 ;
18+ pub type InnerNonZero = core:: num:: NonZeroU64 ;
19+
20+ pub type Tag = core:: num:: NonZeroU32 ;
21+ pub type Address = u32 ;
22+ }
23+
24+ #[ cfg( target_pointer_width = "64" ) ]
25+ mod types {
26+ use super :: atomic;
27+
28+ pub type Inner = u128 ;
29+ pub type InnerAtomic = atomic:: AtomicU128 ;
30+ pub type InnerNonZero = core:: num:: NonZeroU128 ;
31+
32+ pub type Tag = core:: num:: NonZeroU64 ;
33+ pub type Address = u64 ;
34+ }
35+
36+ use types:: * ;
37+
1038pub struct AtomicPtr < N >
1139where
1240 N : Node ,
1341{
14- inner : AtomicU64 ,
42+ inner : InnerAtomic ,
1543 _marker : PhantomData < * mut N > ,
1644}
1745
1846impl < N > AtomicPtr < N >
1947where
2048 N : Node ,
2149{
50+ #[ inline]
2251 pub const fn null ( ) -> Self {
2352 Self {
24- inner : AtomicU64 :: new ( 0 ) ,
53+ inner : InnerAtomic :: new ( 0 ) ,
2554 _marker : PhantomData ,
2655 }
2756 }
@@ -35,37 +64,38 @@ where
3564 ) -> Result < ( ) , Option < NonNullPtr < N > > > {
3665 self . inner
3766 . compare_exchange_weak (
38- current
39- . map ( |pointer| pointer. into_u64 ( ) )
40- . unwrap_or_default ( ) ,
41- new. map ( |pointer| pointer. into_u64 ( ) ) . unwrap_or_default ( ) ,
67+ current. map ( NonNullPtr :: into_inner) . unwrap_or_default ( ) ,
68+ new. map ( NonNullPtr :: into_inner) . unwrap_or_default ( ) ,
4269 success,
4370 failure,
4471 )
4572 . map ( drop)
46- . map_err ( NonNullPtr :: from_u64)
73+ . map_err ( |value| {
74+ // SAFETY: `value` cam from a `NonNullPtr::into_inner` call.
75+ unsafe { NonNullPtr :: from_inner ( value) }
76+ } )
4777 }
4878
79+ #[ inline]
4980 fn load ( & self , order : Ordering ) -> Option < NonNullPtr < N > > {
50- NonZeroU64 :: new ( self . inner . load ( order ) ) . map ( |inner| NonNullPtr {
51- inner,
81+ Some ( NonNullPtr {
82+ inner : InnerNonZero :: new ( self . inner . load ( order ) ) ? ,
5283 _marker : PhantomData ,
5384 } )
5485 }
5586
87+ #[ inline]
5688 fn store ( & self , value : Option < NonNullPtr < N > > , order : Ordering ) {
57- self . inner . store (
58- value. map ( |pointer| pointer. into_u64 ( ) ) . unwrap_or_default ( ) ,
59- order,
60- )
89+ self . inner
90+ . store ( value. map ( NonNullPtr :: into_inner) . unwrap_or_default ( ) , order)
6191 }
6292}
6393
6494pub struct NonNullPtr < N >
6595where
6696 N : Node ,
6797{
68- inner : NonZeroU64 ,
98+ inner : InnerNonZero ,
6999 _marker : PhantomData < * mut N > ,
70100}
71101
@@ -84,65 +114,72 @@ impl<N> NonNullPtr<N>
84114where
85115 N : Node ,
86116{
117+ #[ inline]
87118 pub fn as_ptr ( & self ) -> * mut N {
88119 self . inner . get ( ) as * mut N
89120 }
90121
91- pub fn from_static_mut_ref ( ref_ : & ' static mut N ) -> NonNullPtr < N > {
92- let non_null = NonNull :: from ( ref_) ;
93- Self :: from_non_null ( non_null)
122+ #[ inline]
123+ pub fn from_static_mut_ref ( reference : & ' static mut N ) -> NonNullPtr < N > {
124+ // SAFETY: `reference` is a static mutable reference, i.e. a valid pointer.
125+ unsafe { Self :: new_unchecked ( initial_tag ( ) , NonNull :: from ( reference) ) }
94126 }
95127
96- fn from_non_null ( ptr : NonNull < N > ) -> Self {
97- let address = ptr. as_ptr ( ) as u32 ;
98- let tag = initial_tag ( ) . get ( ) ;
99-
100- let value = ( u64:: from ( tag) << 32 ) | u64:: from ( address) ;
128+ /// # Safety
129+ ///
130+ /// - `ptr` must be a valid pointer.
131+ #[ inline]
132+ unsafe fn new_unchecked ( tag : Tag , ptr : NonNull < N > ) -> Self {
133+ let value =
134+ ( Inner :: from ( tag. get ( ) ) << Address :: BITS ) | Inner :: from ( ptr. as_ptr ( ) as Address ) ;
101135
102136 Self {
103- inner : unsafe { NonZeroU64 :: new_unchecked ( value) } ,
137+ // SAFETY: `value` is constructed from a `Tag` which is non-zero and half the
138+ // size of the `InnerNonZero` type, and a `NonNull<N>` pointer.
139+ inner : unsafe { InnerNonZero :: new_unchecked ( value) } ,
104140 _marker : PhantomData ,
105141 }
106142 }
107143
108- fn from_u64 ( value : u64 ) -> Option < Self > {
109- NonZeroU64 :: new ( value) . map ( |inner| Self {
110- inner,
144+ /// # Safety
145+ ///
146+ /// - `value` must come from a `Self::into_inner` call.
147+ #[ inline]
148+ unsafe fn from_inner ( value : Inner ) -> Option < Self > {
149+ Some ( Self {
150+ inner : InnerNonZero :: new ( value) ?,
111151 _marker : PhantomData ,
112152 } )
113153 }
114154
155+ #[ inline]
115156 fn non_null ( & self ) -> NonNull < N > {
116- unsafe { NonNull :: new_unchecked ( self . inner . get ( ) as * mut N ) }
157+ // SAFETY: `Self` can only be constructed using a `NonNull<N>`.
158+ unsafe { NonNull :: new_unchecked ( self . as_ptr ( ) ) }
117159 }
118160
119- fn tag ( & self ) -> NonZeroU32 {
120- unsafe { NonZeroU32 :: new_unchecked ( ( self . inner . get ( ) >> 32 ) as u32 ) }
121- }
122-
123- fn into_u64 ( self ) -> u64 {
161+ #[ inline]
162+ fn into_inner ( self ) -> Inner {
124163 self . inner . get ( )
125164 }
126165
127- fn increase_tag ( & mut self ) {
128- let address = self . as_ptr ( ) as u32 ;
129-
130- let new_tag = self
131- . tag ( )
132- . get ( )
133- . checked_add ( 1 )
134- . map ( |val| unsafe { NonZeroU32 :: new_unchecked ( val) } )
135- . unwrap_or_else ( initial_tag)
136- . get ( ) ;
166+ #[ inline]
167+ fn tag ( & self ) -> Tag {
168+ // SAFETY: `self.inner` was constructed from a non-zero `Tag`.
169+ unsafe { Tag :: new_unchecked ( ( self . inner . get ( ) >> Address :: BITS ) as Address ) }
170+ }
137171
138- let value = ( u64:: from ( new_tag) << 32 ) | u64:: from ( address) ;
172+ fn increment_tag ( & mut self ) {
173+ let new_tag = self . tag ( ) . checked_add ( 1 ) . unwrap_or_else ( initial_tag) ;
139174
140- self . inner = unsafe { NonZeroU64 :: new_unchecked ( value) } ;
175+ // SAFETY: `self.non_null()` is a valid pointer.
176+ * self = unsafe { Self :: new_unchecked ( new_tag, self . non_null ( ) ) } ;
141177 }
142178}
143179
144- fn initial_tag ( ) -> NonZeroU32 {
145- unsafe { NonZeroU32 :: new_unchecked ( 1 ) }
180+ #[ inline]
181+ const fn initial_tag ( ) -> Tag {
182+ Tag :: MIN
146183}
147184
148185pub unsafe fn push < N > ( stack : & Stack < N > , new_top : NonNullPtr < N > )
@@ -184,7 +221,40 @@ where
184221 . compare_and_exchange_weak ( Some ( top) , next, Ordering :: Release , Ordering :: Relaxed )
185222 . is_ok ( )
186223 {
187- top. increase_tag ( ) ;
224+ // Prevent the ABA problem (https://en.wikipedia.org/wiki/Treiber_stack#Correctness).
225+ //
226+ // Without this, the following would be possible:
227+ //
228+ // | Thread 1 | Thread 2 | Stack |
229+ // |-------------------------------|-------------------------|------------------------------|
230+ // | push((1, 1)) | | (1, 1) |
231+ // | push((1, 2)) | | (1, 2) -> (1, 1) |
232+ // | p = try_pop()::load // (1, 2) | | (1, 2) -> (1, 1) |
233+ // | | p = try_pop() // (1, 2) | (1, 1) |
234+ // | | push((1, 3)) | (1, 3) -> (1, 1) |
235+ // | | push(p) | (1, 2) -> (1, 3) -> (1, 1) |
236+ // | try_pop()::cas(p, p.next) | | (1, 1) |
237+ //
238+ // As can be seen, the `cas` operation succeeds, wrongly removing pointer `3` from the stack.
239+ //
240+ // By incrementing the tag before returning the pointer, it cannot be pushed again with the,
241+ // same tag, preventing the `try_pop()::cas(p, p.next)` operation from succeeding.
242+ //
243+ // With this fix, `try_pop()` in thread 2 returns `(2, 2)` and the comparison between
244+ // `(1, 2)` and `(2, 2)` fails, restarting the loop and correctly removing the new top:
245+ //
246+ // | Thread 1 | Thread 2 | Stack |
247+ // |-------------------------------|-------------------------|------------------------------|
248+ // | push((1, 1)) | | (1, 1) |
249+ // | push((1, 2)) | | (1, 2) -> (1, 1) |
250+ // | p = try_pop()::load // (1, 2) | | (1, 2) -> (1, 1) |
251+ // | | p = try_pop() // (2, 2) | (1, 1) |
252+ // | | push((1, 3)) | (1, 3) -> (1, 1) |
253+ // | | push(p) | (2, 2) -> (1, 3) -> (1, 1) |
254+ // | try_pop()::cas(p, p.next) | | (2, 2) -> (1, 3) -> (1, 1) |
255+ // | p = try_pop()::load // (2, 2) | | (2, 2) -> (1, 3) -> (1, 1) |
256+ // | try_pop()::cas(p, p.next) | | (1, 3) -> (1, 1) |
257+ top. increment_tag ( ) ;
188258
189259 return Some ( top) ;
190260 }
0 commit comments