1 use core::{cmp, mem, ptr};
2 
3 use crate::mem::MaybeUninit;
4 
/// Returns a raw pointer to the first element of `this`, typed as `*mut T` instead of
/// `*mut MaybeUninit<T>`.
///
/// `MaybeUninit<T>` is documented to have the same size and alignment as `T`, so only the
/// pointee type changes. Reading through the result is only valid for elements that have been
/// initialised.
fn maybe_uninit_slice_as_mut_ptr<T: Copy>(this: &mut [MaybeUninit<T>]) -> *mut T {
    let first: *mut MaybeUninit<T> = this.as_mut_ptr();
    first as *mut T
}
8 
/// Drop guard that copies a single `T` from `src` into `dest` when it goes out of scope.
///
/// The sorting helpers hold an element in a temporary while shifting others around. This guard
/// writes that temporary back into the slice at the end of the scope, and also during unwinding
/// if a comparison closure panics.
///
/// The creator of the guard must ensure that `src` and `dest` are valid for one `T` and do not
/// overlap at the time the guard is dropped.
struct CopyOnDrop<T> {
    /// Location holding the value to copy.
    src: *mut T,
    /// Location that receives the value.
    dest: *mut T,
}

impl<T> Drop for CopyOnDrop<T> {
    fn drop(&mut self) {
        // SAFETY: the guard's creator guarantees `src`/`dest` are valid and non-overlapping.
        unsafe { self.dest.copy_from_nonoverlapping(self.src, 1) }
    }
}
21 
shift_tail<T, F>(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool,22 fn shift_tail<T, F>(v: &mut [T], is_less: &mut F)
23 where
24     F: FnMut(&T, &T) -> bool,
25 {
26     let len = v.len();
27     unsafe {
28         if len >= 2 && is_less(v.get_unchecked(len - 1), v.get_unchecked(len - 2)) {
29             let mut tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(len - 1)));
30             let mut hole = CopyOnDrop {
31                 src: &mut *tmp,
32                 dest: v.get_unchecked_mut(len - 2),
33             };
34             ptr::copy_nonoverlapping(v.get_unchecked(len - 2), v.get_unchecked_mut(len - 1), 1);
35 
36             for i in (0..len - 2).rev() {
37                 if !is_less(&*tmp, v.get_unchecked(i)) {
38                     break;
39                 }
40 
41                 ptr::copy_nonoverlapping(v.get_unchecked(i), v.get_unchecked_mut(i + 1), 1);
42                 hole.dest = v.get_unchecked_mut(i);
43             }
44         }
45     }
46 }
47 
insertion_sort<T, F>(v: &mut [T], is_less: &mut F) where F: FnMut(&T, &T) -> bool,48 fn insertion_sort<T, F>(v: &mut [T], is_less: &mut F)
49 where
50     F: FnMut(&T, &T) -> bool,
51 {
52     for i in 1..v.len() {
53         shift_tail(&mut v[..i + 1], is_less);
54     }
55 }
56 
partition_in_blocks<T, F>(v: &mut [T], pivot: &T, is_less: &mut F) -> usize where F: FnMut(&T, &T) -> bool,57 fn partition_in_blocks<T, F>(v: &mut [T], pivot: &T, is_less: &mut F) -> usize
58 where
59     F: FnMut(&T, &T) -> bool,
60 {
61     const BLOCK: usize = 128;
62 
63     let mut l = v.as_mut_ptr();
64     let mut block_l = BLOCK;
65     let mut start_l = ptr::null_mut();
66     let mut end_l = ptr::null_mut();
67     let mut offsets_l = [MaybeUninit::<u8>::uninit(); BLOCK];
68 
69     let mut r = unsafe { l.add(v.len()) };
70     let mut block_r = BLOCK;
71     let mut start_r = ptr::null_mut();
72     let mut end_r = ptr::null_mut();
73     let mut offsets_r = [MaybeUninit::<u8>::uninit(); BLOCK];
74 
75     fn width<T>(l: *mut T, r: *mut T) -> usize {
76         assert!(mem::size_of::<T>() > 0);
77         (r as usize - l as usize) / mem::size_of::<T>()
78     }
79 
80     loop {
81         let is_done = width(l, r) <= 2 * BLOCK;
82 
83         if is_done {
84             let mut rem = width(l, r);
85             if start_l < end_l || start_r < end_r {
86                 rem -= BLOCK;
87             }
88 
89             if start_l < end_l {
90                 block_r = rem;
91             } else if start_r < end_r {
92                 block_l = rem;
93             } else {
94                 block_l = rem / 2;
95                 block_r = rem - block_l;
96             }
97             debug_assert!(block_l <= BLOCK && block_r <= BLOCK);
98             debug_assert!(width(l, r) == block_l + block_r);
99         }
100 
101         if start_l == end_l {
102             start_l = maybe_uninit_slice_as_mut_ptr(&mut offsets_l);
103             end_l = maybe_uninit_slice_as_mut_ptr(&mut offsets_l);
104             let mut elem = l;
105 
106             for i in 0..block_l {
107                 unsafe {
108                     *end_l = i as u8;
109                     end_l = end_l.offset(!is_less(&*elem, pivot) as isize);
110                     elem = elem.offset(1);
111                 }
112             }
113         }
114 
115         if start_r == end_r {
116             start_r = maybe_uninit_slice_as_mut_ptr(&mut offsets_r);
117             end_r = maybe_uninit_slice_as_mut_ptr(&mut offsets_r);
118             let mut elem = r;
119 
120             for i in 0..block_r {
121                 unsafe {
122                     elem = elem.offset(-1);
123                     *end_r = i as u8;
124                     end_r = end_r.offset(is_less(&*elem, pivot) as isize);
125                 }
126             }
127         }
128 
129         let count = cmp::min(width(start_l, end_l), width(start_r, end_r));
130 
131         if count > 0 {
132             macro_rules! left {
133                 () => {
134                     l.offset(*start_l as isize)
135                 };
136             }
137             macro_rules! right {
138                 () => {
139                     r.offset(-(*start_r as isize) - 1)
140                 };
141             }
142 
143             unsafe {
144                 let tmp = ptr::read(left!());
145                 ptr::copy_nonoverlapping(right!(), left!(), 1);
146 
147                 for _ in 1..count {
148                     start_l = start_l.offset(1);
149                     ptr::copy_nonoverlapping(left!(), right!(), 1);
150                     start_r = start_r.offset(1);
151                     ptr::copy_nonoverlapping(right!(), left!(), 1);
152                 }
153 
154                 ptr::copy_nonoverlapping(&tmp, right!(), 1);
155                 mem::forget(tmp);
156                 start_l = start_l.offset(1);
157                 start_r = start_r.offset(1);
158             }
159         }
160 
161         if start_l == end_l {
162             l = unsafe { l.add(block_l) };
163         }
164 
165         if start_r == end_r {
166             r = unsafe { r.offset(-(block_r as isize)) };
167         }
168 
169         if is_done {
170             break;
171         }
172     }
173 
174     if start_l < end_l {
175         debug_assert_eq!(width(l, r), block_l);
176         while start_l < end_l {
177             unsafe {
178                 end_l = end_l.offset(-1);
179                 ptr::swap(l.offset(*end_l as isize), r.offset(-1));
180                 r = r.offset(-1);
181             }
182         }
183         width(v.as_mut_ptr(), r)
184     } else if start_r < end_r {
185         debug_assert_eq!(width(l, r), block_r);
186         while start_r < end_r {
187             unsafe {
188                 end_r = end_r.offset(-1);
189                 ptr::swap(l, r.offset(-(*end_r as isize) - 1));
190                 l = l.offset(1);
191             }
192         }
193         width(v.as_mut_ptr(), l)
194     } else {
195         width(v.as_mut_ptr(), l)
196     }
197 }
198 
partition<T, F>(v: &mut [T], pivot: usize, is_less: &mut F) -> (usize, bool) where F: FnMut(&T, &T) -> bool,199 fn partition<T, F>(v: &mut [T], pivot: usize, is_less: &mut F) -> (usize, bool)
200 where
201     F: FnMut(&T, &T) -> bool,
202 {
203     let (mid, was_partitioned) = {
204         v.swap(0, pivot);
205         let (pivot, v) = v.split_at_mut(1);
206         let pivot = &mut pivot[0];
207 
208         let mut tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) });
209         let _pivot_guard = CopyOnDrop {
210             src: &mut *tmp,
211             dest: pivot,
212         };
213         let pivot = &*tmp;
214 
215         let mut l = 0;
216         let mut r = v.len();
217 
218         unsafe {
219             while l < r && is_less(v.get_unchecked(l), pivot) {
220                 l += 1;
221             }
222 
223             while l < r && !is_less(v.get_unchecked(r - 1), pivot) {
224                 r -= 1;
225             }
226         }
227 
228         (
229             l + partition_in_blocks(&mut v[l..r], pivot, is_less),
230             l >= r,
231         )
232     };
233 
234     v.swap(0, mid);
235 
236     (mid, was_partitioned)
237 }
238 
partition_equal<T, F>(v: &mut [T], pivot: usize, is_less: &mut F) -> usize where F: FnMut(&T, &T) -> bool,239 fn partition_equal<T, F>(v: &mut [T], pivot: usize, is_less: &mut F) -> usize
240 where
241     F: FnMut(&T, &T) -> bool,
242 {
243     v.swap(0, pivot);
244     let (pivot, v) = v.split_at_mut(1);
245     let pivot = &mut pivot[0];
246 
247     let mut tmp = mem::ManuallyDrop::new(unsafe { ptr::read(pivot) });
248     let _pivot_guard = CopyOnDrop {
249         src: &mut *tmp,
250         dest: pivot,
251     };
252     let pivot = &*tmp;
253 
254     let mut l = 0;
255     let mut r = v.len();
256     loop {
257         unsafe {
258             while l < r && !is_less(pivot, v.get_unchecked(l)) {
259                 l += 1;
260             }
261 
262             while l < r && is_less(pivot, v.get_unchecked(r - 1)) {
263                 r -= 1;
264             }
265 
266             if l >= r {
267                 break;
268             }
269 
270             r -= 1;
271             ptr::swap(v.get_unchecked_mut(l), v.get_unchecked_mut(r));
272             l += 1;
273         }
274     }
275 
276     l + 1
277 }
278 
/// Picks a pivot index for partitioning `v`, and reports whether `v` appears already sorted.
///
/// The pivot is chosen by slice length:
/// - Fewer than 8 elements: the index `len / 4 * 2` is returned without comparing anything.
/// - Otherwise: the pivot is the median of the elements at the three quarter positions.
/// - At least 50 elements: each quarter position is first replaced by the median of itself
///   and its two neighbours.
///
/// Returns `(pivot_index, likely_sorted)`. `likely_sorted` is `true` when no comparison called
/// for a swap. If every comparison called for a swap, `v` is reversed in place and the index is
/// mirrored so it still refers to the same element; `likely_sorted` is then also `true`. This
/// requires all 12 comparisons, which only happens on the 50+ path.
fn choose_pivot<T, F>(v: &mut [T], is_less: &mut F) -> (usize, bool)
where
    F: FnMut(&T, &T) -> bool,
{
    // Minimum length at which each sample is refined by a median of three neighbours.
    const SHORTEST_MEDIAN_OF_MEDIANS: usize = 50;
    // At most four `sort3` calls, each doing at most three swaps.
    const MAX_SWAPS: usize = 4 * 3;

    let len = v.len();
    let quarter = len / 4;
    let mut a = quarter;
    let mut b = quarter * 2;
    let mut c = quarter * 3;
    let mut swaps = 0;

    if len >= 8 {
        // Orders two indices so the element at `*lo` is not greater than the one at `*hi`,
        // counting each swap.
        let mut sort2 = |lo: &mut usize, hi: &mut usize| {
            // SAFETY: the indices are `a`, `b`, `c` or (for len >= 50) their immediate
            // neighbours. With `len >= 8` all of them lie within `0..len`.
            let out_of_order = unsafe { is_less(v.get_unchecked(*hi), v.get_unchecked(*lo)) };
            if out_of_order {
                mem::swap(lo, hi);
                swaps += 1;
            }
        };

        // Orders three indices by their elements; afterwards `*y` indexes the median.
        let mut sort3 = |x: &mut usize, y: &mut usize, z: &mut usize| {
            sort2(x, y);
            sort2(y, z);
            sort2(x, y);
        };

        if len >= SHORTEST_MEDIAN_OF_MEDIANS {
            // Replace each sample index with the median of itself and its two neighbours.
            let (a0, b0, c0) = (a, b, c);
            sort3(&mut (a0 - 1), &mut a, &mut (a0 + 1));
            sort3(&mut (b0 - 1), &mut b, &mut (b0 + 1));
            sort3(&mut (c0 - 1), &mut c, &mut (c0 + 1));
        }

        sort3(&mut a, &mut b, &mut c);
    }

    if swaps >= MAX_SWAPS {
        // Every comparison was out of order. Reverse the slice and mirror the pivot index.
        v.reverse();
        return (len - 1 - b, true);
    }

    (b, swaps == 0)
}
329 
partition_at_index_loop<'a, T, F>( mut v: &'a mut [T], mut index: usize, is_less: &mut F, mut pred: Option<&'a T>, ) where F: FnMut(&T, &T) -> bool,330 fn partition_at_index_loop<'a, T, F>(
331     mut v: &'a mut [T],
332     mut index: usize,
333     is_less: &mut F,
334     mut pred: Option<&'a T>,
335 ) where
336     F: FnMut(&T, &T) -> bool,
337 {
338     loop {
339         const MAX_INSERTION: usize = 10;
340         if v.len() <= MAX_INSERTION {
341             insertion_sort(v, is_less);
342             return;
343         }
344 
345         let (pivot, _) = choose_pivot(v, is_less);
346 
347         if let Some(p) = pred {
348             if !is_less(p, &v[pivot]) {
349                 let mid = partition_equal(v, pivot, is_less);
350 
351                 if mid > index {
352                     return;
353                 }
354 
355                 v = &mut v[mid..];
356                 index -= mid;
357                 pred = None;
358                 continue;
359             }
360         }
361 
362         let (mid, _) = partition(v, pivot, is_less);
363 
364         let (left, right) = { v }.split_at_mut(mid);
365         let (pivot, right) = right.split_at_mut(1);
366         let pivot = &pivot[0];
367 
368         match mid.cmp(&index) {
369             cmp::Ordering::Less => {
370                 v = right;
371                 index = index - mid - 1;
372                 pred = Some(pivot);
373             }
374             cmp::Ordering::Greater => {
375                 v = left;
376             }
377             cmp::Ordering::Equal => {
378                 return;
379             }
380         }
381     }
382 }
383 
partition_at_index<T, F>( v: &mut [T], index: usize, mut is_less: F, ) -> (&mut [T], &mut T, &mut [T]) where F: FnMut(&T, &T) -> bool,384 pub(crate) fn partition_at_index<T, F>(
385     v: &mut [T],
386     index: usize,
387     mut is_less: F,
388 ) -> (&mut [T], &mut T, &mut [T])
389 where
390     F: FnMut(&T, &T) -> bool,
391 {
392     use self::cmp::Ordering::{Greater, Less};
393 
394     if index >= v.len() {
395         panic!(
396             "partition_at_index index {} greater than length of slice {}",
397             index,
398             v.len()
399         );
400     }
401 
402     if mem::size_of::<T>() == 0 {
403     } else if index == v.len() - 1 {
404         let (max_index, _) = v
405             .iter()
406             .enumerate()
407             .max_by(|&(_, x), &(_, y)| if is_less(x, y) { Less } else { Greater })
408             .unwrap();
409         v.swap(max_index, index);
410     } else if index == 0 {
411         let (min_index, _) = v
412             .iter()
413             .enumerate()
414             .min_by(|&(_, x), &(_, y)| if is_less(x, y) { Less } else { Greater })
415             .unwrap();
416         v.swap(min_index, index);
417     } else {
418         partition_at_index_loop(v, index, &mut is_less, None);
419     }
420 
421     let (left, right) = v.split_at_mut(index);
422     let (pivot, right) = right.split_at_mut(1);
423     let pivot = &mut pivot[0];
424     (left, pivot, right)
425 }
426