1 // There's a lot of scary concurrent code in this module, but it is copied from
2 // `std::sync::Once` with two changes:
3 // * no poisoning
4 // * init function can fail
5
6 use std::{
7 cell::{Cell, UnsafeCell},
8 hint::unreachable_unchecked,
9 marker::PhantomData,
10 panic::{RefUnwindSafe, UnwindSafe},
11 sync::atomic::{AtomicBool, AtomicUsize, Ordering},
12 thread::{self, Thread},
13 };
14
15 use crate::take_unchecked;
16
/// A thread-safe cell that is written at most once through shared references,
/// with an initializer that is allowed to fail.
///
/// Encoding of `state_and_queue`:
/// * the two low bits (`STATE_MASK`) hold `INCOMPLETE`, `RUNNING` or `COMPLETE`;
/// * while `RUNNING`, the remaining bits are the address of the head of a
///   linked list of `Waiter` nodes for threads blocked on this cell.
#[derive(Debug)]
pub(crate) struct OnceCell<T> {
    // This `state` word is actually an encoded version of just a pointer to a
    // `Waiter`, so we add the `PhantomData` appropriately.
    state_and_queue: AtomicUsize,
    // Records that the state word (logically) stores `*mut Waiter` pointers.
    // Because raw pointers are neither `Send` nor `Sync`, this also suppresses
    // the automatic `Send`/`Sync` impls; this module re-adds them by hand with
    // explicit bounds on `T`.
    _marker: PhantomData<*mut Waiter>,
    // `None` until an initializer succeeds.
    value: UnsafeCell<Option<T>>,
}
25
// SAFETY (`Sync`): a shared `&OnceCell<T>` lets any thread run the initializer
// (creating and storing a `T`) and lets every thread obtain a `&T`. `T: Sync`
// is therefore needed for the shared access, and `T: Send` because the value
// may be created on one thread and dropped or taken on another.
//
// Why do we need `T: Send`?
// Thread A creates a `OnceCell` and shares it with
// scoped thread B, which fills the cell, which is
// then destroyed by A. That is, destructor observes
// a sent value.
unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
// SAFETY (`Send`): moving the cell moves the contained `Option<T>`. Moving
// requires ownership, so no thread can be inside `initialize`/`wait` on this
// cell at that point, and no `Waiter` pointers are stored in the state word.
unsafe impl<T: Send> Send for OnceCell<T> {}
33
// Unwind safety: there is no poisoning. If the initializer panics, the
// `WaiterQueue` guard resets the state to `INCOMPLETE`, so the cell never
// exposes a half-written value. The impls are written by hand presumably
// because the auto traits would otherwise be suppressed by the `UnsafeCell`
// field (and, for `UnwindSafe`, by the `PhantomData<*mut Waiter>` marker,
// since `Waiter` holds a `Cell`). They just forward the unwind-safety of `T`.
impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}
36
// The two low bits of `OnceCell::state_and_queue` encode one of three states.

/// No initializer has succeeded yet, and none is currently running.
const INCOMPLETE: usize = 0b00;
/// An initializer is running; the non-state bits address the waiter queue.
const RUNNING: usize = 0b01;
/// The value has been stored; this is the terminal state.
const COMPLETE: usize = 0b10;

/// Selects the state bits of the word. Every other bit is the address of the
/// head of the waiter queue, which is only meaningful while `RUNNING`.
const STATE_MASK: usize = 0b11;
46
// Representation of a node in the linked list of waiters in the RUNNING state.
//
// Each node lives on the stack of a thread blocked in `wait`. Once `signaled`
// is set, the owning thread may return and destroy the node, so a waker must
// read everything it needs from the node before setting that flag.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
    // Handle used to unpark the waiting thread. The waker takes it (leaving
    // `None`) before it sets `signaled`.
    thread: Cell<Option<Thread>>,
    // Set with `Release` by the waker. The waiter re-checks it with `Acquire`
    // around `thread::park`, which may return spuriously.
    signaled: AtomicBool,
    // The previous head of the queue (an older waiter), or null at the tail.
    next: *const Waiter,
}
54
// Head of a linked list of waiters.
// Every node is a struct on the stack of a waiting thread.
// Will wake up the waiters when it gets dropped, i.e. also on panic.
//
// Drop guard held by the thread that moved the cell to `RUNNING`. Dropping it
// stores `set_state_on_drop_to` into the state word, which also detaches the
// waiter list, and then wakes every queued waiter.
struct WaiterQueue<'a> {
    // The cell's state word, shared with all waiters.
    state_and_queue: &'a AtomicUsize,
    // State published on drop. `initialize_inner` starts it as `INCOMPLETE`,
    // so a panic resets the cell, and switches it to `COMPLETE` only after a
    // successful initializer.
    set_state_on_drop_to: usize,
}
62
63 impl<T> OnceCell<T> {
new() -> OnceCell<T>64 pub(crate) const fn new() -> OnceCell<T> {
65 OnceCell {
66 state_and_queue: AtomicUsize::new(INCOMPLETE),
67 _marker: PhantomData,
68 value: UnsafeCell::new(None),
69 }
70 }
71
72 /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst).
73 #[inline]
is_initialized(&self) -> bool74 pub(crate) fn is_initialized(&self) -> bool {
75 // An `Acquire` load is enough because that makes all the initialization
76 // operations visible to us, and, this being a fast path, weaker
77 // ordering helps with performance. This `Acquire` synchronizes with
78 // `SeqCst` operations on the slow path.
79 self.state_and_queue.load(Ordering::Acquire) == COMPLETE
80 }
81
82 /// Safety: synchronizes with store to value via SeqCst read from state,
83 /// writes value only once because we never get to INCOMPLETE state after a
84 /// successful write.
85 #[cold]
initialize<F, E>(&self, f: F) -> Result<(), E> where F: FnOnce() -> Result<T, E>,86 pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E>
87 where
88 F: FnOnce() -> Result<T, E>,
89 {
90 let mut f = Some(f);
91 let mut res: Result<(), E> = Ok(());
92 let slot: *mut Option<T> = self.value.get();
93 initialize_inner(&self.state_and_queue, &mut || {
94 let f = unsafe { take_unchecked(&mut f) };
95 match f() {
96 Ok(value) => {
97 unsafe { *slot = Some(value) };
98 true
99 }
100 Err(err) => {
101 res = Err(err);
102 false
103 }
104 }
105 });
106 res
107 }
108
109 /// Get the reference to the underlying value, without checking if the cell
110 /// is initialized.
111 ///
112 /// # Safety
113 ///
114 /// Caller must ensure that the cell is in initialized state, and that
115 /// the contents are acquired by (synchronized to) this thread.
get_unchecked(&self) -> &T116 pub(crate) unsafe fn get_unchecked(&self) -> &T {
117 debug_assert!(self.is_initialized());
118 let slot: &Option<T> = &*self.value.get();
119 match slot {
120 Some(value) => value,
121 // This unsafe does improve performance, see `examples/bench`.
122 None => {
123 debug_assert!(false);
124 unreachable_unchecked()
125 }
126 }
127 }
128
129 /// Gets the mutable reference to the underlying value.
130 /// Returns `None` if the cell is empty.
get_mut(&mut self) -> Option<&mut T>131 pub(crate) fn get_mut(&mut self) -> Option<&mut T> {
132 // Safe b/c we have a unique access.
133 unsafe { &mut *self.value.get() }.as_mut()
134 }
135
136 /// Consumes this `OnceCell`, returning the wrapped value.
137 /// Returns `None` if the cell was empty.
138 #[inline]
into_inner(self) -> Option<T>139 pub(crate) fn into_inner(self) -> Option<T> {
140 // Because `into_inner` takes `self` by value, the compiler statically
141 // verifies that it is not currently borrowed.
142 // So, it is safe to move out `Option<T>`.
143 self.value.into_inner()
144 }
145 }
146
147 // Corresponds to `std::sync::Once::call_inner`
148 // Note: this is intentionally monomorphic
149 #[inline(never)]
initialize_inner(my_state_and_queue: &AtomicUsize, init: &mut dyn FnMut() -> bool) -> bool150 fn initialize_inner(my_state_and_queue: &AtomicUsize, init: &mut dyn FnMut() -> bool) -> bool {
151 let mut state_and_queue = my_state_and_queue.load(Ordering::Acquire);
152
153 loop {
154 match state_and_queue {
155 COMPLETE => return true,
156 INCOMPLETE => {
157 let exchange = my_state_and_queue.compare_exchange(
158 state_and_queue,
159 RUNNING,
160 Ordering::Acquire,
161 Ordering::Acquire,
162 );
163 if let Err(old) = exchange {
164 state_and_queue = old;
165 continue;
166 }
167 let mut waiter_queue = WaiterQueue {
168 state_and_queue: my_state_and_queue,
169 set_state_on_drop_to: INCOMPLETE, // Difference, std uses `POISONED`
170 };
171 let success = init();
172
173 // Difference, std always uses `COMPLETE`
174 waiter_queue.set_state_on_drop_to = if success { COMPLETE } else { INCOMPLETE };
175 return success;
176 }
177 _ => {
178 assert!(state_and_queue & STATE_MASK == RUNNING);
179 wait(&my_state_and_queue, state_and_queue);
180 state_and_queue = my_state_and_queue.load(Ordering::Acquire);
181 }
182 }
183 }
184 }
185
// Copy-pasted from std exactly.
//
// Blocks the current thread until the cell leaves the `RUNNING` state.
//
// `current_state` is the most recently observed value of the state word.
// While it is `RUNNING`, a `Waiter` node for this thread is pushed onto the
// front of the queue with a compare-exchange. The node lives in this stack
// frame, so the function must not return until the waker has set `signaled`,
// after which the waker no longer touches the node. Returns immediately if the
// observed state is not `RUNNING`. The caller (`initialize_inner`) reloads the
// state afterwards.
fn wait(state_and_queue: &AtomicUsize, mut current_state: usize) {
    loop {
        if current_state & STATE_MASK != RUNNING {
            return;
        }

        let node = Waiter {
            thread: Cell::new(Some(thread::current())),
            signaled: AtomicBool::new(false),
            next: (current_state & !STATE_MASK) as *const Waiter,
        };
        let me = &node as *const Waiter as usize;

        // Push `node` as the new head. `Release` publishes the node's fields
        // to the thread that later swaps the queue out in `WaiterQueue::drop`.
        let exchange = state_and_queue.compare_exchange(
            current_state,
            me | RUNNING,
            Ordering::Release,
            Ordering::Relaxed,
        );
        if let Err(old) = exchange {
            // The word changed (another waiter pushed, or the state moved on).
            // `node` was never published; retry with the fresh value.
            current_state = old;
            continue;
        }

        // `park` can return spuriously, so loop until the waker has actually
        // signaled us. `Acquire` pairs with the waker's `Release` store.
        while !node.signaled.load(Ordering::Acquire) {
            thread::park();
        }
        break;
    }
}
217
// Copy-pasted from std exactly.
//
// Publishes the final state and wakes every queued waiter. Runs both on a
// normal return from `initialize_inner` and during unwinding if the
// initializer panics.
impl Drop for WaiterQueue<'_> {
    fn drop(&mut self) {
        // Atomically publish the new state and detach the whole waiter list.
        // Release makes the initialized value visible to threads that observe
        // `COMPLETE`; acquire makes the waiters' node writes visible to us.
        let state_and_queue =
            self.state_and_queue.swap(self.set_state_on_drop_to, Ordering::AcqRel);

        // Only the thread that moved the cell to `RUNNING` creates a `WaiterQueue`.
        assert_eq!(state_and_queue & STATE_MASK, RUNNING);

        // SAFETY: each node is on the stack of a thread blocked in `wait`,
        // which cannot return until its `signaled` flag is set. So `next` and
        // `thread` are read *before* that store, and the node is not touched
        // after it.
        unsafe {
            let mut queue = (state_and_queue & !STATE_MASK) as *const Waiter;
            while !queue.is_null() {
                let next = (*queue).next;
                let thread = (*queue).thread.replace(None).unwrap();
                (*queue).signaled.store(true, Ordering::Release);
                // `*queue` may already be gone; only the local copies are used.
                queue = next;
                thread.unpark();
            }
        }
    }
}
238
// These tests are snatched from std as well, adapted to the fact that this
// cell has no poisoning: a panicking initializer just leaves it empty.
#[cfg(test)]
mod tests {
    use std::panic;
    use std::{sync::mpsc::channel, thread};

    use super::OnceCell;

    impl<T> OnceCell<T> {
        /// Test helper: runs an initializer that cannot fail.
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }

    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }

    #[test]
    #[cfg(not(miri))]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        static mut RUN: bool = false;

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }

        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // A panicking initializer (which would poison a `std::sync::Once`)...
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // ...just leaves the cell empty, so the next initializer runs.
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // Once initialized, later initializers are not run.
        let mut called_again = false;
        O.init(|| {
            called_again = true;
        });
        assert!(!called_again);
    }

    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // Leave a panicked (and therefore reset) initialization behind.
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // Make sure one thread is running the initializer and blocked inside it.
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // Put another waiter on the cell; its initializer must never run.
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }

    #[test]
    fn failed_init_can_be_retried() {
        let mut cell: OnceCell<u32> = OnceCell::new();
        assert_eq!(cell.get_mut(), None);

        // A failing initializer reports its error and leaves the cell empty.
        assert_eq!(cell.initialize(|| Err("nope")), Err("nope"));
        assert!(!cell.is_initialized());
        assert_eq!(cell.get_mut(), None);

        // A later initializer runs and fills the cell.
        assert_eq!(cell.initialize(|| Ok::<u32, &str>(92)), Ok(()));
        assert!(cell.is_initialized());
        assert_eq!(unsafe { cell.get_unchecked() }, &92);

        // Further initializers are skipped entirely.
        let mut called = false;
        let res: Result<(), &str> = cell.initialize(|| {
            called = true;
            Ok(0)
        });
        assert_eq!(res, Ok(()));
        assert!(!called);

        *cell.get_mut().unwrap() += 1;
        assert_eq!(cell.into_inner(), Some(93));
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
}
367