1 use std::{
2     cell::UnsafeCell,
3     cmp::min,
4     mem::{self, MaybeUninit},
5     ptr::{self, copy},
6     sync::{
7         atomic::{AtomicUsize, Ordering},
8         Arc,
9     },
10 };
11 
12 use crate::{consumer::Consumer, producer::Producer};
13 
/// Interior-mutable wrapper around a `Vec<T>` that lets several owners (e.g. the producer and
/// consumer halves of a ring buffer) reach the same storage through shared references.
///
/// No access checking is performed: callers of [`get_ref`](Self::get_ref) and
/// [`get_mut`](Self::get_mut) are responsible for never creating aliasing references that
/// violate Rust's borrowing rules.
pub(crate) struct SharedVec<T: Sized> {
    cell: UnsafeCell<Vec<T>>,
}

// SAFETY: `SharedVec` hands out `&mut Vec<T>` through `&self`, so sharing it between threads
// lets another thread mutate the vector and move items out of it. That is only sound when
// `T: Send` (the same bound `std::sync::Mutex<T>` requires for `Sync`). Ensuring the accesses
// themselves do not alias is the responsibility of callers of the `unsafe` accessors below.
unsafe impl<T: Sized + Send> Sync for SharedVec<T> {}

impl<T: Sized> SharedVec<T> {
    /// Wraps `data` for shared, unchecked access.
    pub fn new(data: Vec<T>) -> Self {
        Self {
            cell: UnsafeCell::new(data),
        }
    }

    /// Returns a shared reference to the wrapped vector.
    ///
    /// # Safety
    ///
    /// No mutable reference obtained from [`get_mut`](Self::get_mut) may be alive while the
    /// returned reference is in use.
    pub unsafe fn get_ref(&self) -> &Vec<T> {
        &*self.cell.get()
    }

    /// Returns a mutable reference to the wrapped vector.
    ///
    /// # Safety
    ///
    /// The caller must guarantee exclusive access: no other reference returned by
    /// [`get_ref`](Self::get_ref) or `get_mut` may be alive while this one is in use.
    // Handing out `&mut` from `&self` is the purpose of this type; exclusivity is part of the
    // documented safety contract instead of being enforced by the borrow checker.
    #[allow(clippy::mut_from_ref)]
    pub unsafe fn get_mut(&self) -> &mut Vec<T> {
        &mut *self.cell.get()
    }
}
34 
35 /// Ring buffer itself.
36 pub struct RingBuffer<T: Sized> {
37     pub(crate) data: SharedVec<MaybeUninit<T>>,
38     pub(crate) head: AtomicUsize,
39     pub(crate) tail: AtomicUsize,
40 }
41 
42 impl<T: Sized> RingBuffer<T> {
43     /// Creates a new instance of a ring buffer.
new(capacity: usize) -> Self44     pub fn new(capacity: usize) -> Self {
45         let mut data = Vec::new();
46         data.resize_with(capacity + 1, MaybeUninit::uninit);
47         Self {
48             data: SharedVec::new(data),
49             head: AtomicUsize::new(0),
50             tail: AtomicUsize::new(0),
51         }
52     }
53 
54     /// Splits ring buffer into producer and consumer.
split(self) -> (Producer<T>, Consumer<T>)55     pub fn split(self) -> (Producer<T>, Consumer<T>) {
56         let arc = Arc::new(self);
57         (Producer { rb: arc.clone() }, Consumer { rb: arc })
58     }
59 
60     /// Returns capacity of the ring buffer.
capacity(&self) -> usize61     pub fn capacity(&self) -> usize {
62         unsafe { self.data.get_ref() }.len() - 1
63     }
64 
65     /// Checks if the ring buffer is empty.
is_empty(&self) -> bool66     pub fn is_empty(&self) -> bool {
67         let head = self.head.load(Ordering::Acquire);
68         let tail = self.tail.load(Ordering::Acquire);
69         head == tail
70     }
71 
72     /// Checks if the ring buffer is full.
is_full(&self) -> bool73     pub fn is_full(&self) -> bool {
74         let head = self.head.load(Ordering::Acquire);
75         let tail = self.tail.load(Ordering::Acquire);
76         (tail + 1) % (self.capacity() + 1) == head
77     }
78 
79     /// The length of the data in the buffer.
len(&self) -> usize80     pub fn len(&self) -> usize {
81         let head = self.head.load(Ordering::Acquire);
82         let tail = self.tail.load(Ordering::Acquire);
83         (tail + self.capacity() + 1 - head) % (self.capacity() + 1)
84     }
85 
86     /// The remaining space in the buffer.
remaining(&self) -> usize87     pub fn remaining(&self) -> usize {
88         self.capacity() - self.len()
89     }
90 }
91 
92 impl<T: Sized> Drop for RingBuffer<T> {
drop(&mut self)93     fn drop(&mut self) {
94         let data = unsafe { self.data.get_mut() };
95 
96         let head = self.head.load(Ordering::Acquire);
97         let tail = self.tail.load(Ordering::Acquire);
98         let len = data.len();
99 
100         let slices = if head <= tail {
101             (head..tail, 0..0)
102         } else {
103             (head..len, 0..tail)
104         };
105 
106         let drop = |elem_ref: &mut MaybeUninit<T>| unsafe {
107             mem::replace(elem_ref, MaybeUninit::uninit()).assume_init();
108         };
109         for elem in data[slices.0].iter_mut() {
110             drop(elem);
111         }
112         for elem in data[slices.1].iter_mut() {
113             drop(elem);
114         }
115     }
116 }
117 
/// Raw, length-tagged cursor over a slice, used by `move_items` to walk through the source
/// and destination regions.
struct SlicePtr<T> {
    /// Pointer to the first element of the region that is still left.
    pub ptr: *mut T,
    /// Number of elements that are still left.
    pub len: usize,
}

impl<T> SlicePtr<T> {
    /// Returns an empty cursor that points nowhere.
    fn null() -> Self {
        SlicePtr {
            len: 0,
            ptr: ptr::null_mut(),
        }
    }

    /// Creates a cursor covering the whole of `slice`.
    fn new(slice: &mut [T]) -> Self {
        let len = slice.len();
        let ptr = slice.as_mut_ptr();
        SlicePtr { ptr, len }
    }

    /// Moves the cursor `count` elements forward.
    ///
    /// # Safety
    ///
    /// `count` must not exceed `self.len`; otherwise the pointer leaves the original slice.
    unsafe fn shift(&mut self, count: usize) {
        let advanced = self.ptr.add(count);
        self.ptr = advanced;
        self.len -= count;
    }
}
141 
142 /// Moves at most `count` items from the `src` consumer to the `dst` producer.
143 /// Consumer and producer may be of different buffers as well as of the same one.
144 ///
145 /// `count` is the number of items being moved, if `None` - as much as possible items will be moved.
146 ///
147 /// Returns number of items been moved.
move_items<T>(src: &mut Consumer<T>, dst: &mut Producer<T>, count: Option<usize>) -> usize148 pub fn move_items<T>(src: &mut Consumer<T>, dst: &mut Producer<T>, count: Option<usize>) -> usize {
149     unsafe {
150         src.pop_access(|src_left, src_right| -> usize {
151             dst.push_access(|dst_left, dst_right| -> usize {
152                 let n = count.unwrap_or_else(|| {
153                     min(
154                         src_left.len() + src_right.len(),
155                         dst_left.len() + dst_right.len(),
156                     )
157                 });
158                 let mut m = 0;
159                 let mut src = (SlicePtr::new(src_left), SlicePtr::new(src_right));
160                 let mut dst = (SlicePtr::new(dst_left), SlicePtr::new(dst_right));
161 
162                 loop {
163                     let k = min(n - m, min(src.0.len, dst.0.len));
164                     if k == 0 {
165                         break;
166                     }
167                     copy(src.0.ptr, dst.0.ptr, k);
168                     if src.0.len == k {
169                         src.0 = src.1;
170                         src.1 = SlicePtr::null();
171                     } else {
172                         src.0.shift(k);
173                     }
174                     if dst.0.len == k {
175                         dst.0 = dst.1;
176                         dst.1 = SlicePtr::null();
177                     } else {
178                         dst.0.shift(k);
179                     }
180                     m += k
181                 }
182 
183                 m
184             })
185         })
186     }
187 }
188