1 #![cfg_attr(not(feature = "full"), allow(dead_code))]
2 
//! An intrusive doubly linked list of data.
4 //!
5 //! The data structure supports tracking pinned nodes. Most of the data
6 //! structure's APIs are `unsafe` as they require the caller to ensure the
7 //! specified node is actually contained by the list.
8 
9 use core::cell::UnsafeCell;
10 use core::fmt;
11 use core::marker::{PhantomData, PhantomPinned};
12 use core::mem::ManuallyDrop;
13 use core::ptr::{self, NonNull};
14 
/// An intrusive linked list.
///
/// Currently, the list is not emptied on drop. It is the caller's
/// responsibility to ensure the list is empty before dropping it.
pub(crate) struct LinkedList<L, T> {
    /// Linked list head: the front of the list, where `push_front` inserts.
    /// `None` when the list is empty.
    head: Option<NonNull<T>>,

    /// Linked list tail: the back of the list, the node `pop_back` removes.
    /// `None` when the list is empty.
    tail: Option<NonNull<T>>,

    /// Node type marker.
    ///
    /// Ties the list to its `Link` implementation `L` without storing one.
    _marker: PhantomData<*const L>,
}

// The raw pointers held above (`NonNull<T>`, `*const L`) make the list `!Send`
// and `!Sync` by default; these impls opt back in based on the node type only.
unsafe impl<L: Link> Send for LinkedList<L, L::Target> where L::Target: Send {}
unsafe impl<L: Link> Sync for LinkedList<L, L::Target> where L::Target: Sync {}
32 
/// Defines how a type is tracked within a linked list.
///
/// In order to support storing a single type within multiple lists, accessing
/// the list pointers is decoupled from the entry type.
///
/// # Safety
///
/// Implementations must guarantee that `Target` types are pinned in memory. In
/// other words, when a node is inserted, the value will not be moved as long as
/// it is stored in the list.
pub(crate) unsafe trait Link {
    /// Handle to the list entry.
    ///
    /// This is usually a pointer-ish type.
    ///
    /// `LinkedList::push_front` takes ownership of a handle (wrapping it in
    /// `ManuallyDrop` so it is not dropped), and `pop_back` / `remove` hand it
    /// back by calling `from_raw`.
    type Handle;

    /// Node type.
    type Target;

    /// Convert the handle to a raw pointer without consuming the handle.
    // NOTE(review): presumably silences clippy's `as_*` naming heuristic, since
    // this associated function takes a handle reference rather than `self` —
    // confirm.
    #[allow(clippy::wrong_self_convention)]
    fn as_raw(handle: &Self::Handle) -> NonNull<Self::Target>;

    /// Convert the raw pointer to a handle.
    ///
    /// # Safety
    ///
    /// Within this file, `ptr` is always a pointer previously returned by
    /// `as_raw` for a handle whose ownership was transferred into a list; the
    /// returned handle takes that ownership back. Implementations presumably
    /// rely on this — TODO confirm the exact contract.
    unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self::Handle;

    /// Return the pointers for a node.
    ///
    /// # Safety
    ///
    /// NOTE(review): callers in this file only pass pointers to nodes that are
    /// linked into, or being inserted into, a list; implementations presumably
    /// may assume `target` is valid to access — confirm.
    unsafe fn pointers(target: NonNull<Self::Target>) -> NonNull<Pointers<Self::Target>>;
}
62 
/// Previous / next pointers.
///
/// Embedded in each node and reached through `Link::pointers`. The link fields
/// live inside an `UnsafeCell` so the accessors in `impl Pointers` can read and
/// write them through raw pointers.
pub(crate) struct Pointers<T> {
    inner: UnsafeCell<PointersInner<T>>,
}
/// We do not want the compiler to put the `noalias` attribute on mutable
/// references to this type, so the type has been made `!Unpin` with a
/// `PhantomPinned` field.
///
/// Additionally, we never access the `prev` or `next` fields directly, as any
/// such access would implicitly involve the creation of a reference to the
/// field, which we want to avoid since the field types are `Unpin` and would
/// hence be given the `noalias` attribute if we were to do such an access.
/// As an alternative to accessing the fields directly, the `Pointers` type
/// provides getters and setters for the two fields, and those are implemented
/// using raw pointer casts and offsets, which is valid since the struct is
/// `#[repr(C)]`.
///
/// See this link for more information:
/// <https://github.com/rust-lang/rust/pull/82834>
#[repr(C)]
struct PointersInner<T> {
    /// The previous node in the list. `None` if there is no previous node.
    ///
    /// This field is accessed through pointer manipulation, so it is not dead code.
    /// It must stay the first field: `Pointers` accesses it at offset 0.
    #[allow(dead_code)]
    prev: Option<NonNull<T>>,

    /// The next node in the list. `None` if there is no next node.
    ///
    /// This field is accessed through pointer manipulation, so it is not dead code.
    /// It must stay the second field with the same type as `prev`: `Pointers`
    /// locates it one `Option<NonNull<T>>` past `prev`.
    #[allow(dead_code)]
    next: Option<NonNull<T>>,

    /// This type is !Unpin due to the heuristic from:
    /// <https://github.com/rust-lang/rust/pull/82834>
    _pin: PhantomPinned,
}
100 
// `Pointers` stores `NonNull<T>` values (neither `Send` nor `Sync`) inside an
// `UnsafeCell` (not `Sync`), so the auto traits are implemented manually,
// conditioned on `T`.
unsafe impl<T: Send> Send for Pointers<T> {}
unsafe impl<T: Sync> Sync for Pointers<T> {}
103 
104 // ===== impl LinkedList =====
105 
106 impl<L, T> LinkedList<L, T> {
107     /// Creates an empty linked list.
new() -> LinkedList<L, T>108     pub(crate) const fn new() -> LinkedList<L, T> {
109         LinkedList {
110             head: None,
111             tail: None,
112             _marker: PhantomData,
113         }
114     }
115 }
116 
117 impl<L: Link> LinkedList<L, L::Target> {
118     /// Adds an element first in the list.
push_front(&mut self, val: L::Handle)119     pub(crate) fn push_front(&mut self, val: L::Handle) {
120         // The value should not be dropped, it is being inserted into the list
121         let val = ManuallyDrop::new(val);
122         let ptr = L::as_raw(&*val);
123         assert_ne!(self.head, Some(ptr));
124         unsafe {
125             L::pointers(ptr).as_mut().set_next(self.head);
126             L::pointers(ptr).as_mut().set_prev(None);
127 
128             if let Some(head) = self.head {
129                 L::pointers(head).as_mut().set_prev(Some(ptr));
130             }
131 
132             self.head = Some(ptr);
133 
134             if self.tail.is_none() {
135                 self.tail = Some(ptr);
136             }
137         }
138     }
139 
140     /// Removes the last element from a list and returns it, or None if it is
141     /// empty.
pop_back(&mut self) -> Option<L::Handle>142     pub(crate) fn pop_back(&mut self) -> Option<L::Handle> {
143         unsafe {
144             let last = self.tail?;
145             self.tail = L::pointers(last).as_ref().get_prev();
146 
147             if let Some(prev) = L::pointers(last).as_ref().get_prev() {
148                 L::pointers(prev).as_mut().set_next(None);
149             } else {
150                 self.head = None
151             }
152 
153             L::pointers(last).as_mut().set_prev(None);
154             L::pointers(last).as_mut().set_next(None);
155 
156             Some(L::from_raw(last))
157         }
158     }
159 
160     /// Returns whether the linked list does not contain any node
is_empty(&self) -> bool161     pub(crate) fn is_empty(&self) -> bool {
162         if self.head.is_some() {
163             return false;
164         }
165 
166         assert!(self.tail.is_none());
167         true
168     }
169 
170     /// Removes the specified node from the list
171     ///
172     /// # Safety
173     ///
174     /// The caller **must** ensure that `node` is currently contained by
175     /// `self` or not contained by any other list.
remove(&mut self, node: NonNull<L::Target>) -> Option<L::Handle>176     pub(crate) unsafe fn remove(&mut self, node: NonNull<L::Target>) -> Option<L::Handle> {
177         if let Some(prev) = L::pointers(node).as_ref().get_prev() {
178             debug_assert_eq!(L::pointers(prev).as_ref().get_next(), Some(node));
179             L::pointers(prev)
180                 .as_mut()
181                 .set_next(L::pointers(node).as_ref().get_next());
182         } else {
183             if self.head != Some(node) {
184                 return None;
185             }
186 
187             self.head = L::pointers(node).as_ref().get_next();
188         }
189 
190         if let Some(next) = L::pointers(node).as_ref().get_next() {
191             debug_assert_eq!(L::pointers(next).as_ref().get_prev(), Some(node));
192             L::pointers(next)
193                 .as_mut()
194                 .set_prev(L::pointers(node).as_ref().get_prev());
195         } else {
196             // This might be the last item in the list
197             if self.tail != Some(node) {
198                 return None;
199             }
200 
201             self.tail = L::pointers(node).as_ref().get_prev();
202         }
203 
204         L::pointers(node).as_mut().set_next(None);
205         L::pointers(node).as_mut().set_prev(None);
206 
207         Some(L::from_raw(node))
208     }
209 }
210 
211 impl<L: Link> fmt::Debug for LinkedList<L, L::Target> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result212     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213         f.debug_struct("LinkedList")
214             .field("head", &self.head)
215             .field("tail", &self.tail)
216             .finish()
217     }
218 }
219 
#[cfg(any(
    feature = "fs",
    all(unix, feature = "process"),
    feature = "signal",
    feature = "sync",
))]
impl<L: Link> LinkedList<L, L::Target> {
    /// Returns a shared reference to the tail node, or `None` when the list
    /// is empty.
    pub(crate) fn last(&self) -> Option<&L::Target> {
        let tail = self.tail?;
        // SAFETY: `tail` is a node of this list; the returned reference
        // borrows `self`, so the list cannot be mutated through `self` while
        // it is alive.
        Some(unsafe { tail.as_ref() })
    }
}
232 
233 impl<L: Link> Default for LinkedList<L, L::Target> {
default() -> Self234     fn default() -> Self {
235         Self::new()
236     }
237 }
238 
239 // ===== impl DrainFilter =====
240 
241 cfg_io_readiness! {
242     pub(crate) struct DrainFilter<'a, T: Link, F> {
243         list: &'a mut LinkedList<T, T::Target>,
244         filter: F,
245         curr: Option<NonNull<T::Target>>,
246     }
247 
248     impl<T: Link> LinkedList<T, T::Target> {
249         pub(crate) fn drain_filter<F>(&mut self, filter: F) -> DrainFilter<'_, T, F>
250         where
251             F: FnMut(&mut T::Target) -> bool,
252         {
253             let curr = self.head;
254             DrainFilter {
255                 curr,
256                 filter,
257                 list: self,
258             }
259         }
260     }
261 
262     impl<'a, T, F> Iterator for DrainFilter<'a, T, F>
263     where
264         T: Link,
265         F: FnMut(&mut T::Target) -> bool,
266     {
267         type Item = T::Handle;
268 
269         fn next(&mut self) -> Option<Self::Item> {
270             while let Some(curr) = self.curr {
271                 // safety: the pointer references data contained by the list
272                 self.curr = unsafe { T::pointers(curr).as_ref() }.get_next();
273 
274                 // safety: the value is still owned by the linked list.
275                 if (self.filter)(unsafe { &mut *curr.as_ptr() }) {
276                     return unsafe { self.list.remove(curr) };
277                 }
278             }
279 
280             None
281         }
282     }
283 }
284 
285 // ===== impl Pointers =====
286 
287 impl<T> Pointers<T> {
288     /// Create a new set of empty pointers
new() -> Pointers<T>289     pub(crate) fn new() -> Pointers<T> {
290         Pointers {
291             inner: UnsafeCell::new(PointersInner {
292                 prev: None,
293                 next: None,
294                 _pin: PhantomPinned,
295             }),
296         }
297     }
298 
get_prev(&self) -> Option<NonNull<T>>299     fn get_prev(&self) -> Option<NonNull<T>> {
300         // SAFETY: prev is the first field in PointersInner, which is #[repr(C)].
301         unsafe {
302             let inner = self.inner.get();
303             let prev = inner as *const Option<NonNull<T>>;
304             ptr::read(prev)
305         }
306     }
get_next(&self) -> Option<NonNull<T>>307     fn get_next(&self) -> Option<NonNull<T>> {
308         // SAFETY: next is the second field in PointersInner, which is #[repr(C)].
309         unsafe {
310             let inner = self.inner.get();
311             let prev = inner as *const Option<NonNull<T>>;
312             let next = prev.add(1);
313             ptr::read(next)
314         }
315     }
316 
set_prev(&mut self, value: Option<NonNull<T>>)317     fn set_prev(&mut self, value: Option<NonNull<T>>) {
318         // SAFETY: prev is the first field in PointersInner, which is #[repr(C)].
319         unsafe {
320             let inner = self.inner.get();
321             let prev = inner as *mut Option<NonNull<T>>;
322             ptr::write(prev, value);
323         }
324     }
set_next(&mut self, value: Option<NonNull<T>>)325     fn set_next(&mut self, value: Option<NonNull<T>>) {
326         // SAFETY: next is the second field in PointersInner, which is #[repr(C)].
327         unsafe {
328             let inner = self.inner.get();
329             let prev = inner as *mut Option<NonNull<T>>;
330             let next = prev.add(1);
331             ptr::write(next, value);
332         }
333     }
334 }
335 
336 impl<T> fmt::Debug for Pointers<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result337     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
338         let prev = self.get_prev();
339         let next = self.get_next();
340         f.debug_struct("Pointers")
341             .field("prev", &prev)
342             .field("next", &next)
343             .finish()
344     }
345 }
346 
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use super::*;

    use std::pin::Pin;

    /// Test node.
    ///
    /// `#[repr(C)]` with `pointers` as the first field guarantees the
    /// `Pointers` live at offset 0, which `Link::pointers` below relies on.
    #[derive(Debug)]
    #[repr(C)]
    struct Entry {
        pointers: Pointers<Entry>,
        val: i32,
    }

    unsafe impl<'a> Link for &'a Entry {
        type Handle = Pin<&'a Entry>;
        type Target = Entry;

        fn as_raw(handle: &Pin<&'_ Entry>) -> NonNull<Entry> {
            NonNull::from(handle.get_ref())
        }

        unsafe fn from_raw(ptr: NonNull<Entry>) -> Pin<&'a Entry> {
            Pin::new_unchecked(&*ptr.as_ptr())
        }

        unsafe fn pointers(target: NonNull<Entry>) -> NonNull<Pointers<Entry>> {
            // Do not go through `target.as_mut()`. That would create a
            // `&mut Entry` while the tests still hold shared `Pin<&Entry>`
            // handles to the same node, which Miri's Stacked Borrows model
            // reports as undefined behaviour. `Entry` is `#[repr(C)]` with
            // `pointers` first, so a plain pointer cast yields the field.
            target.cast()
        }
    }

    /// Allocates a pinned, unlinked entry holding `val`.
    fn entry(val: i32) -> Pin<Box<Entry>> {
        Box::pin(Entry {
            pointers: Pointers::new(),
            val,
        })
    }

    /// Raw node pointer for an entry, as accepted by `LinkedList::remove`.
    fn ptr(r: &Pin<Box<Entry>>) -> NonNull<Entry> {
        r.as_ref().get_ref().into()
    }

    /// Drains `list` from the back, returning the values in pop order.
    fn collect_list(list: &mut LinkedList<&'_ Entry, <&'_ Entry as Link>::Target>) -> Vec<i32> {
        let mut ret = vec![];

        while let Some(entry) = list.pop_back() {
            ret.push(entry.val);
        }

        ret
    }

    /// Pushes each entry to the front in slice order, so the last entry ends
    /// up at the head.
    fn push_all<'a>(
        list: &mut LinkedList<&'a Entry, <&'_ Entry as Link>::Target>,
        entries: &[Pin<&'a Entry>],
    ) {
        for entry in entries.iter() {
            list.push_front(*entry);
        }
    }

    macro_rules! assert_clean {
        ($e:ident) => {{
            assert!($e.pointers.get_next().is_none());
            assert!($e.pointers.get_prev().is_none());
        }};
    }

    macro_rules! assert_ptr_eq {
        ($a:expr, $b:expr) => {{
            // Deal with mapping a Pin<&mut T> -> Option<NonNull<T>>
            assert_eq!(Some($a.as_ref().get_ref().into()), $b)
        }};
    }

    #[test]
    fn const_new() {
        const _: LinkedList<&Entry, <&Entry as Link>::Target> = LinkedList::new();
    }

    #[test]
    fn push_and_drain() {
        let a = entry(5);
        let b = entry(7);
        let c = entry(31);

        let mut list = LinkedList::new();
        assert!(list.is_empty());

        list.push_front(a.as_ref());
        assert!(!list.is_empty());
        list.push_front(b.as_ref());
        list.push_front(c.as_ref());

        let items: Vec<i32> = collect_list(&mut list);
        assert_eq!([5, 7, 31].to_vec(), items);

        assert!(list.is_empty());
    }

    #[test]
    fn push_pop_push_pop() {
        let a = entry(5);
        let b = entry(7);

        let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new();

        list.push_front(a.as_ref());

        let entry = list.pop_back().unwrap();
        assert_eq!(5, entry.val);
        assert!(list.is_empty());

        list.push_front(b.as_ref());

        let entry = list.pop_back().unwrap();
        assert_eq!(7, entry.val);

        assert!(list.is_empty());
        assert!(list.pop_back().is_none());
    }

    #[test]
    fn remove_by_address() {
        let a = entry(5);
        let b = entry(7);
        let c = entry(31);

        unsafe {
            // Remove each node in turn, always taking the current head
            let mut list = LinkedList::new();

            push_all(&mut list, &[c.as_ref(), b.as_ref(), a.as_ref()]);
            assert!(list.remove(ptr(&a)).is_some());
            assert_clean!(a);
            // `a` should be no longer there and can't be removed twice
            assert!(list.remove(ptr(&a)).is_none());
            assert!(!list.is_empty());

            assert!(list.remove(ptr(&b)).is_some());
            assert_clean!(b);
            // `b` should be no longer there and can't be removed twice
            assert!(list.remove(ptr(&b)).is_none());
            assert!(!list.is_empty());

            assert!(list.remove(ptr(&c)).is_some());
            assert_clean!(c);
            // `c` should be no longer there and can't be removed twice
            assert!(list.remove(ptr(&c)).is_none());
            assert!(list.is_empty());
        }

        unsafe {
            // Remove head
            let mut list = LinkedList::new();

            push_all(&mut list, &[c.as_ref(), b.as_ref(), a.as_ref()]);

            assert!(list.remove(ptr(&a)).is_some());
            assert_clean!(a);

            assert_ptr_eq!(b, list.head);
            assert_ptr_eq!(c, b.pointers.get_next());
            assert_ptr_eq!(b, c.pointers.get_prev());

            let items = collect_list(&mut list);
            assert_eq!([31, 7].to_vec(), items);
        }

        unsafe {
            // Remove middle
            let mut list = LinkedList::new();

            push_all(&mut list, &[c.as_ref(), b.as_ref(), a.as_ref()]);

            assert!(list.remove(ptr(&b)).is_some());
            assert_clean!(b);

            assert_ptr_eq!(c, a.pointers.get_next());
            assert_ptr_eq!(a, c.pointers.get_prev());

            let items = collect_list(&mut list);
            assert_eq!([31, 5].to_vec(), items);
        }

        unsafe {
            // Remove tail
            let mut list = LinkedList::new();

            push_all(&mut list, &[c.as_ref(), b.as_ref(), a.as_ref()]);

            assert!(list.remove(ptr(&c)).is_some());
            assert_clean!(c);

            assert!(b.pointers.get_next().is_none());
            assert_ptr_eq!(b, list.tail);

            let items = collect_list(&mut list);
            assert_eq!([7, 5].to_vec(), items);
        }

        unsafe {
            // Remove first of two
            let mut list = LinkedList::new();

            push_all(&mut list, &[b.as_ref(), a.as_ref()]);

            assert!(list.remove(ptr(&a)).is_some());

            assert_clean!(a);

            // a should be no longer there and can't be removed twice
            assert!(list.remove(ptr(&a)).is_none());

            assert_ptr_eq!(b, list.head);
            assert_ptr_eq!(b, list.tail);

            assert!(b.pointers.get_next().is_none());
            assert!(b.pointers.get_prev().is_none());

            let items = collect_list(&mut list);
            assert_eq!([7].to_vec(), items);
        }

        unsafe {
            // Remove last of two
            let mut list = LinkedList::new();

            push_all(&mut list, &[b.as_ref(), a.as_ref()]);

            assert!(list.remove(ptr(&b)).is_some());

            assert_clean!(b);

            assert_ptr_eq!(a, list.head);
            assert_ptr_eq!(a, list.tail);

            assert!(a.pointers.get_next().is_none());
            assert!(a.pointers.get_prev().is_none());

            let items = collect_list(&mut list);
            assert_eq!([5].to_vec(), items);
        }

        unsafe {
            // Remove last item
            let mut list = LinkedList::new();

            push_all(&mut list, &[a.as_ref()]);

            assert!(list.remove(ptr(&a)).is_some());
            assert_clean!(a);

            assert!(list.head.is_none());
            assert!(list.tail.is_none());
            let items = collect_list(&mut list);
            assert!(items.is_empty());
        }

        unsafe {
            // Remove missing
            let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new();

            list.push_front(b.as_ref());
            list.push_front(a.as_ref());

            assert!(list.remove(ptr(&c)).is_none());
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    proptest::proptest! {
        #[test]
        fn fuzz_linked_list(ops: Vec<usize>) {
            run_fuzz(ops);
        }
    }

    /// Replays a sequence of push / pop / remove operations against both the
    /// intrusive list and a `VecDeque` reference model, asserting they agree.
    fn run_fuzz(ops: Vec<usize>) {
        use std::collections::VecDeque;

        #[derive(Debug)]
        enum Op {
            Push,
            Pop,
            Remove(usize),
        }

        let ops = ops
            .iter()
            .map(|i| match i % 3 {
                0 => Op::Push,
                1 => Op::Pop,
                2 => Op::Remove(i / 3),
                _ => unreachable!(),
            })
            .collect::<Vec<_>>();

        let mut ll = LinkedList::<&Entry, <&Entry as Link>::Target>::new();
        let mut reference = VecDeque::new();

        let entries: Vec<_> = (0..ops.len()).map(|i| entry(i as i32)).collect();

        for (i, op) in ops.iter().enumerate() {
            match op {
                Op::Push => {
                    reference.push_front(i as i32);
                    assert_eq!(entries[i].val, i as i32);

                    ll.push_front(entries[i].as_ref());
                }
                Op::Pop => {
                    if reference.is_empty() {
                        assert!(ll.is_empty());
                        continue;
                    }

                    let v = reference.pop_back();
                    assert_eq!(v, ll.pop_back().map(|v| v.val));
                }
                Op::Remove(n) => {
                    if reference.is_empty() {
                        assert!(ll.is_empty());
                        continue;
                    }

                    let idx = n % reference.len();
                    let expect = reference.remove(idx).unwrap();

                    unsafe {
                        let entry = ll.remove(ptr(&entries[expect as usize])).unwrap();
                        assert_eq!(expect, entry.val);
                    }
                }
            }
        }
    }
}
685