1 use std::cmp;
2 use std::fmt;
3 
4 use crate::errors::InvalidThreadAccess;
5 use crate::fragile::Fragile;
6 use crate::sticky::Sticky;
7 use std::mem;
8 
9 enum SemiStickyImpl<T> {
10     Fragile(Fragile<T>),
11     Sticky(Sticky<T>),
12 }
13 
14 /// A `SemiSticky<T>` keeps a value T stored in a thread if it has a drop.
15 ///
16 /// This is a combined version of `Fragile<T>` and `Sticky<T>`.  If the type
17 /// does not have a drop it will effectively be a `Fragile<T>`, otherwise it
18 /// will be internally behave like a `Sticky<T>`.
19 pub struct SemiSticky<T> {
20     inner: SemiStickyImpl<T>,
21 }
22 
23 impl<T> SemiSticky<T> {
24     /// Creates a new `SemiSticky` wrapping a `value`.
25     ///
26     /// The value that is moved into the `SemiSticky` can be non `Send` and
27     /// will be anchored to the thread that created the object.  If the
28     /// sticky wrapper type ends up being send from thread to thread
29     /// only the original thread can interact with the value.  In case the
30     /// value does not have `Drop` it will be stored in the `SemiSticky`
31     /// instead.
new(value: T) -> Self32     pub fn new(value: T) -> Self {
33         SemiSticky {
34             inner: if mem::needs_drop::<T>() {
35                 SemiStickyImpl::Sticky(Sticky::new(value))
36             } else {
37                 SemiStickyImpl::Fragile(Fragile::new(value))
38             },
39         }
40     }
41 
42     /// Returns `true` if the access is valid.
43     ///
44     /// This will be `false` if the value was sent to another thread.
is_valid(&self) -> bool45     pub fn is_valid(&self) -> bool {
46         match self.inner {
47             SemiStickyImpl::Fragile(ref inner) => inner.is_valid(),
48             SemiStickyImpl::Sticky(ref inner) => inner.is_valid(),
49         }
50     }
51 
52     /// Consumes the `SemiSticky`, returning the wrapped value.
53     ///
54     /// # Panics
55     ///
56     /// Panics if called from a different thread than the one where the
57     /// original value was created.
into_inner(self) -> T58     pub fn into_inner(self) -> T {
59         match self.inner {
60             SemiStickyImpl::Fragile(inner) => inner.into_inner(),
61             SemiStickyImpl::Sticky(inner) => inner.into_inner(),
62         }
63     }
64 
65     /// Consumes the `SemiSticky`, returning the wrapped value if successful.
66     ///
67     /// The wrapped value is returned if this is called from the same thread
68     /// as the one where the original value was created, otherwise the
69     /// `SemiSticky` is returned as `Err(self)`.
try_into_inner(self) -> Result<T, Self>70     pub fn try_into_inner(self) -> Result<T, Self> {
71         match self.inner {
72             SemiStickyImpl::Fragile(inner) => inner.try_into_inner().map_err(|inner| SemiSticky {
73                 inner: SemiStickyImpl::Fragile(inner),
74             }),
75             SemiStickyImpl::Sticky(inner) => inner.try_into_inner().map_err(|inner| SemiSticky {
76                 inner: SemiStickyImpl::Sticky(inner),
77             }),
78         }
79     }
80 
81     /// Immutably borrows the wrapped value.
82     ///
83     /// # Panics
84     ///
85     /// Panics if the calling thread is not the one that wrapped the value.
86     /// For a non-panicking variant, use [`try_get`](#method.try_get`).
get(&self) -> &T87     pub fn get(&self) -> &T {
88         match self.inner {
89             SemiStickyImpl::Fragile(ref inner) => inner.get(),
90             SemiStickyImpl::Sticky(ref inner) => inner.get(),
91         }
92     }
93 
94     /// Mutably borrows the wrapped value.
95     ///
96     /// # Panics
97     ///
98     /// Panics if the calling thread is not the one that wrapped the value.
99     /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`).
get_mut(&mut self) -> &mut T100     pub fn get_mut(&mut self) -> &mut T {
101         match self.inner {
102             SemiStickyImpl::Fragile(ref mut inner) => inner.get_mut(),
103             SemiStickyImpl::Sticky(ref mut inner) => inner.get_mut(),
104         }
105     }
106 
107     /// Tries to immutably borrow the wrapped value.
108     ///
109     /// Returns `None` if the calling thread is not the one that wrapped the value.
try_get(&self) -> Result<&T, InvalidThreadAccess>110     pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
111         match self.inner {
112             SemiStickyImpl::Fragile(ref inner) => inner.try_get(),
113             SemiStickyImpl::Sticky(ref inner) => inner.try_get(),
114         }
115     }
116 
117     /// Tries to mutably borrow the wrapped value.
118     ///
119     /// Returns `None` if the calling thread is not the one that wrapped the value.
try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess>120     pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
121         match self.inner {
122             SemiStickyImpl::Fragile(ref mut inner) => inner.try_get_mut(),
123             SemiStickyImpl::Sticky(ref mut inner) => inner.try_get_mut(),
124         }
125     }
126 }
127 
128 impl<T> From<T> for SemiSticky<T> {
129     #[inline]
from(t: T) -> SemiSticky<T>130     fn from(t: T) -> SemiSticky<T> {
131         SemiSticky::new(t)
132     }
133 }
134 
135 impl<T: Clone> Clone for SemiSticky<T> {
136     #[inline]
clone(&self) -> SemiSticky<T>137     fn clone(&self) -> SemiSticky<T> {
138         SemiSticky::new(self.get().clone())
139     }
140 }
141 
142 impl<T: Default> Default for SemiSticky<T> {
143     #[inline]
default() -> SemiSticky<T>144     fn default() -> SemiSticky<T> {
145         SemiSticky::new(T::default())
146     }
147 }
148 
149 impl<T: PartialEq> PartialEq for SemiSticky<T> {
150     #[inline]
eq(&self, other: &SemiSticky<T>) -> bool151     fn eq(&self, other: &SemiSticky<T>) -> bool {
152         *self.get() == *other.get()
153     }
154 }
155 
156 impl<T: Eq> Eq for SemiSticky<T> {}
157 
158 impl<T: PartialOrd> PartialOrd for SemiSticky<T> {
159     #[inline]
partial_cmp(&self, other: &SemiSticky<T>) -> Option<cmp::Ordering>160     fn partial_cmp(&self, other: &SemiSticky<T>) -> Option<cmp::Ordering> {
161         self.get().partial_cmp(&*other.get())
162     }
163 
164     #[inline]
lt(&self, other: &SemiSticky<T>) -> bool165     fn lt(&self, other: &SemiSticky<T>) -> bool {
166         *self.get() < *other.get()
167     }
168 
169     #[inline]
le(&self, other: &SemiSticky<T>) -> bool170     fn le(&self, other: &SemiSticky<T>) -> bool {
171         *self.get() <= *other.get()
172     }
173 
174     #[inline]
gt(&self, other: &SemiSticky<T>) -> bool175     fn gt(&self, other: &SemiSticky<T>) -> bool {
176         *self.get() > *other.get()
177     }
178 
179     #[inline]
ge(&self, other: &SemiSticky<T>) -> bool180     fn ge(&self, other: &SemiSticky<T>) -> bool {
181         *self.get() >= *other.get()
182     }
183 }
184 
185 impl<T: Ord> Ord for SemiSticky<T> {
186     #[inline]
cmp(&self, other: &SemiSticky<T>) -> cmp::Ordering187     fn cmp(&self, other: &SemiSticky<T>) -> cmp::Ordering {
188         self.get().cmp(&*other.get())
189     }
190 }
191 
192 impl<T: fmt::Display> fmt::Display for SemiSticky<T> {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>193     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
194         fmt::Display::fmt(self.get(), f)
195     }
196 }
197 
198 impl<T: fmt::Debug> fmt::Debug for SemiSticky<T> {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>199     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
200         match self.try_get() {
201             Ok(value) => f.debug_struct("SemiSticky").field("value", value).finish(),
202             Err(..) => {
203                 struct InvalidPlaceholder;
204                 impl fmt::Debug for InvalidPlaceholder {
205                     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
206                         f.write_str("<invalid thread>")
207                     }
208                 }
209 
210                 f.debug_struct("SemiSticky")
211                     .field("value", &InvalidPlaceholder)
212                     .finish()
213             }
214         }
215     }
216 }
217 
#[test]
fn test_basic() {
    use std::thread;

    let val = SemiSticky::new(true);

    // On the creating thread every accessor succeeds.
    assert_eq!(val.to_string(), "true");
    assert_eq!(val.get(), &true);
    assert!(val.try_get().is_ok());

    // After moving to another thread, fallible access must fail.
    let handle = thread::spawn(move || {
        assert!(val.try_get().is_err());
    });
    handle.join().unwrap();
}
231 
#[test]
fn test_mut() {
    let mut val = SemiSticky::new(true);

    // Overwrite the wrapped value through a mutable borrow.
    {
        let slot = val.get_mut();
        *slot = false;
    }

    assert_eq!(val.to_string(), "false");
    assert_eq!(val.get(), &false);
}
239 
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let val = SemiSticky::new(true);
    let worker = thread::spawn(move || {
        // `get` panics here because the value was wrapped on another thread;
        // `unwrap` on the join result then propagates the failure.
        let _ = val.get();
    });
    worker.join().unwrap();
}
251 
#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    // Sets the shared flag when dropped.
    struct X(Arc<AtomicBool>);

    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_called = Arc::new(AtomicBool::new(false));
    let val = SemiSticky::new(X(Arc::clone(&was_called)));

    // Dropping on the creating thread must run `X`'s destructor.
    drop(val);
    assert!(was_called.load(Ordering::SeqCst));
}
267 
#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    let was_called = Arc::new(AtomicBool::new(false));

    {
        let was_called = Arc::clone(&was_called);
        thread::spawn(move || {
            // Sets the shared flag when dropped.
            struct X(Arc<AtomicBool>);
            impl Drop for X {
                fn drop(&mut self) {
                    self.0.store(true, Ordering::SeqCst);
                }
            }

            let val = SemiSticky::new(X(Arc::clone(&was_called)));
            assert!(thread::spawn(move || {
                // Move the wrapper into this thread and let it go out of scope
                // here; `X`'s destructor must not run on this foreign thread.
                val.try_get().ok();
            })
            .join()
            .is_ok());

            assert!(!was_called.load(Ordering::SeqCst));
        })
        .join()
        .unwrap();
    }

    // The destructor is expected to have run by the time the creating
    // thread has been joined.
    assert!(was_called.load(Ordering::SeqCst));
}
302 
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::thread;

    // `Rc` is not `Send`, but the wrapper can still be moved across threads;
    // access from the other side must be rejected.
    let shared = Rc::new(true);
    let val = SemiSticky::new(shared);
    thread::spawn(move || assert!(val.try_get().is_err()))
        .join()
        .unwrap();
}
314