// There's a lot of scary concurrent code in this module, but it is copied from
// `std::sync::Once` with two changes:
//   * no poisoning
//   * init function can fail

use std::{
    cell::{Cell, UnsafeCell},
    hint::unreachable_unchecked,
    marker::PhantomData,
    panic::{RefUnwindSafe, UnwindSafe},
    sync::atomic::{AtomicBool, AtomicUsize, Ordering},
    thread::{self, Thread},
};

use crate::take_unchecked;

#[derive(Debug)]
pub(crate) struct OnceCell<T> {
    // `state_and_queue` is an encoded version of a pointer to a `Waiter` plus
    // the two low state bits, so we add the `PhantomData` appropriately.
    state_and_queue: AtomicUsize,
    _marker: PhantomData<*mut Waiter>,
    value: UnsafeCell<Option<T>>,
}

// Why do we need `T: Send`?
// Thread A creates a `OnceCell` and shares it with scoped thread B, which
// fills the cell. The cell is then destroyed by A, so A's destructor observes
// a value that was sent from B.
unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
unsafe impl<T: Send> Send for OnceCell<T> {}

impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}

// Three states that a OnceCell can be in, encoded into the lower bits of
// `state_and_queue` in the OnceCell structure.
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
const STATE_MASK: usize = 0x3;

// Representation of a node in the linked list of waiters in the RUNNING state.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
    thread: Cell<Option<Thread>>,
    signaled: AtomicBool,
    next: *const Waiter,
}
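
// Sanity check, not in std: a `Waiter` pointer must have its two low bits free
// so they can carry the state; the `repr(align(4))` above guarantees this.
// (Assumes a toolchain where `assert!` works in `const` items.)
const _: () = assert!(std::mem::align_of::<Waiter>() >= STATE_MASK + 1);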

// Head of a linked list of waiters.
// Every node is a struct on the stack of a waiting thread.
// Will wake up the waiters when it gets dropped, i.e. also on panic.
struct WaiterQueue<'a> {
    state_and_queue: &'a AtomicUsize,
    set_state_on_drop_to: usize,
}

impl<T> OnceCell<T> {
    pub(crate) const fn new() -> OnceCell<T> {
        OnceCell {
            state_and_queue: AtomicUsize::new(INCOMPLETE),
            _marker: PhantomData,
            value: UnsafeCell::new(None),
        }
    }

    /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst).
    #[inline]
    pub(crate) fn is_initialized(&self) -> bool {
        // An `Acquire` load is enough because that makes all the initialization
        // operations visible to us, and, this being a fast path, weaker
        // ordering helps with performance. This `Acquire` synchronizes with
        // `SeqCst` operations on the slow path.
        self.state_and_queue.load(Ordering::Acquire) == COMPLETE
    }

    /// Safety: synchronizes with store to value via SeqCst read from state,
    /// writes value only once because we never get to INCOMPLETE state after a
    /// successful write.
    #[cold]
    pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        let mut f = Some(f);
        let mut res: Result<(), E> = Ok(());
        let slot: *mut Option<T> = self.value.get();
        initialize_inner(&self.state_and_queue, &mut || {
            let f = unsafe { take_unchecked(&mut f) };
            match f() {
                Ok(value) => {
                    unsafe { *slot = Some(value) };
                    true
                }
                Err(err) => {
                    res = Err(err);
                    false
                }
            }
        });
        res
    }

    /// Gets the reference to the underlying value, without checking if the
    /// cell is initialized.
    ///
    /// # Safety
    ///
    /// Caller must ensure that the cell is in the initialized state, and that
    /// the contents are acquired by (synchronized to) this thread.
    pub(crate) unsafe fn get_unchecked(&self) -> &T {
        debug_assert!(self.is_initialized());
        let slot: &Option<T> = &*self.value.get();
        match slot {
            Some(value) => value,
            // Using `unreachable_unchecked` here does improve performance,
            // see `examples/bench`.
            None => {
                debug_assert!(false);
                unreachable_unchecked()
            }
        }
    }

    /// Gets the mutable reference to the underlying value.
    /// Returns `None` if the cell is empty.
    pub(crate) fn get_mut(&mut self) -> Option<&mut T> {
        // Safe because we have unique access.
        unsafe { &mut *self.value.get() }.as_mut()
    }

    /// Consumes this `OnceCell`, returning the wrapped value.
    /// Returns `None` if the cell was empty.
    #[inline]
    pub(crate) fn into_inner(self) -> Option<T> {
        // Because `into_inner` takes `self` by value, the compiler statically
        // verifies that it is not currently borrowed.
        // So, it is safe to move out `Option<T>`.
        self.value.into_inner()
    }
}

// Corresponds to `std::sync::Once::call_inner`.
// Note: this is intentionally monomorphic.
#[inline(never)]
fn initialize_inner(my_state_and_queue: &AtomicUsize, init: &mut dyn FnMut() -> bool) -> bool {
    let mut state_and_queue = my_state_and_queue.load(Ordering::Acquire);

    loop {
        match state_and_queue {
            COMPLETE => return true,
            INCOMPLETE => {
                let exchange = my_state_and_queue.compare_exchange(
                    state_and_queue,
                    RUNNING,
                    Ordering::Acquire,
                    Ordering::Acquire,
                );
                if let Err(old) = exchange {
                    state_and_queue = old;
                    continue;
                }
                let mut waiter_queue = WaiterQueue {
                    state_and_queue: my_state_and_queue,
                    set_state_on_drop_to: INCOMPLETE, // Difference, std uses `POISONED`
                };
                let success = init();

                // Difference, std always uses `COMPLETE`
                waiter_queue.set_state_on_drop_to = if success { COMPLETE } else { INCOMPLETE };
                return success;
            }
            _ => {
                assert!(state_and_queue & STATE_MASK == RUNNING);
                wait(my_state_and_queue, state_and_queue);
                state_and_queue = my_state_and_queue.load(Ordering::Acquire);
            }
        }
    }
}

// Copy-pasted from std exactly.
fn wait(state_and_queue: &AtomicUsize, mut current_state: usize) {
    loop {
        if current_state & STATE_MASK != RUNNING {
            return;
        }

        let node = Waiter {
            thread: Cell::new(Some(thread::current())),
            signaled: AtomicBool::new(false),
            next: (current_state & !STATE_MASK) as *const Waiter,
        };
        let me = &node as *const Waiter as usize;

        let exchange = state_and_queue.compare_exchange(
            current_state,
            me | RUNNING,
            Ordering::Release,
            Ordering::Relaxed,
        );
        if let Err(old) = exchange {
            current_state = old;
            continue;
        }

        while !node.signaled.load(Ordering::Acquire) {
            thread::park();
        }
        break;
    }
}

// Copy-pasted from std exactly.
impl Drop for WaiterQueue<'_> {
    fn drop(&mut self) {
        let state_and_queue =
            self.state_and_queue.swap(self.set_state_on_drop_to, Ordering::AcqRel);

        assert_eq!(state_and_queue & STATE_MASK, RUNNING);

        unsafe {
            let mut queue = (state_and_queue & !STATE_MASK) as *const Waiter;
            while !queue.is_null() {
                let next = (*queue).next;
                let thread = (*queue).thread.replace(None).unwrap();
                (*queue).signaled.store(true, Ordering::Release);
                queue = next;
                thread.unpark();
            }
        }
    }
}
238 
239 // These test are snatched from std as well.
240 #[cfg(test)]
241 mod tests {
242     use std::panic;
243     use std::{sync::mpsc::channel, thread};
244 
245     use super::OnceCell;
246 
247     impl<T> OnceCell<T> {
init(&self, f: impl FnOnce() -> T)248         fn init(&self, f: impl FnOnce() -> T) {
249             enum Void {}
250             let _ = self.initialize(|| Ok::<T, Void>(f()));
251         }
252     }
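
    // Not snatched from std: checks the "init function can fail" difference
    // noted at the top of this module. A failed `initialize` must leave the
    // cell empty, so a later call can still succeed.
    #[test]
    fn initialize_failure_leaves_cell_empty() {
        let cell: OnceCell<u32> = OnceCell::new();
        assert_eq!(cell.initialize(|| Err::<u32, i32>(-1)), Err(-1));
        assert!(!cell.is_initialized());
        assert_eq!(cell.initialize(|| Ok::<u32, i32>(92)), Ok(()));
        assert_eq!(unsafe { cell.get_unchecked() }, &92);
    }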

    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }
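
    // Not snatched from std: minimal coverage for `get_mut` and `into_inner`,
    // matching the behavior documented on those methods.
    #[test]
    fn get_mut_and_into_inner() {
        let mut cell: OnceCell<u32> = OnceCell::new();
        assert!(cell.get_mut().is_none());
        assert_eq!(cell.into_inner(), None);

        let mut cell: OnceCell<u32> = OnceCell::new();
        cell.init(|| 92);
        *cell.get_mut().unwrap() += 1;
        assert_eq!(cell.into_inner(), Some(93));
    }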

    #[test]
    #[cfg(not(miri))]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        static mut RUN: bool = false;

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }

        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // we can subvert poisoning, however
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // once any success happens, we stop propagating the poison
        O.init(|| {});
    }

    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // make sure someone's waiting inside the once via a force
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // put another waiter on the once
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }
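
    // Not snatched from std: exercises the scenario from the `Send`/`Sync`
    // comment near the top of this module, where another thread fills the cell
    // and the owning thread later observes (and drops) the sent value.
    // Assumes a toolchain with `std::thread::scope` (Rust 1.63+).
    #[test]
    fn value_sent_from_scoped_thread() {
        let cell: OnceCell<String> = OnceCell::new();
        thread::scope(|s| {
            s.spawn(|| cell.init(|| String::from("hello")));
        });
        // Joining the scope synchronizes with the spawned thread, so the value
        // is visible here; dropping `cell` drops the sent `String` on this thread.
        assert!(cell.is_initialized());
        assert_eq!(unsafe { cell.get_unchecked() }, "hello");
    }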

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
}