1{-# LANGUAGE BangPatterns #-}
2{-# LANGUAGE TypeFamilies #-}
3{-# LANGUAGE TypeOperators #-}
4{-# LANGUAGE ScopedTypeVariables #-}
5
6-- ---------------------------------------------------------------------------
7-- |
8-- Module      : Data.Vector.Algorithms.Intro
9-- Copyright   : (c) 2008-2015 Dan Doel
10-- Maintainer  : Dan Doel <dan.doel@gmail.com>
11-- Stability   : Experimental
12-- Portability : Non-portable (type operators, bang patterns)
13--
14-- This module implements various algorithms based on the introsort algorithm,
15-- originally described by David R. Musser in the paper /Introspective Sorting
16-- and Selection Algorithms/. It is also in widespread practical use, as the
17-- standard unstable sort used in the C++ Standard Template Library.
18--
19-- Introsort is at its core a quicksort. The version implemented here has the
20-- following optimizations that make it perform better in practice:
21--
22--   * Small segments of the array are left unsorted until a final insertion
23--     sort pass. This is faster than recursing all the way down to
24--     one-element arrays.
25--
26--   * The pivot for segment [l,u) is chosen as the median of the elements at
27--     l, u-1 and (u+l)/2. This yields good behavior on mostly sorted (or
28--     reverse-sorted) arrays.
29--
30--   * The algorithm tracks its recursion depth, and if it decides it is
31--     taking too long (depth greater than 2 * lg n), it switches to a heap
32--     sort to maintain O(n lg n) worst case behavior. (This is what makes the
33--     algorithm introsort).
34
35module Data.Vector.Algorithms.Intro
36       ( -- * Sorting
37         sort
38       , sortBy
39       , sortByBounds
40         -- * Selecting
41       , select
42       , selectBy
43       , selectByBounds
44         -- * Partial sorting
45       , partialSort
46       , partialSortBy
47       , partialSortByBounds
48       , Comparison
49       ) where
50
51import Prelude hiding (read, length)
52
53import Control.Monad
54import Control.Monad.Primitive
55
56import Data.Bits
57import Data.Vector.Generic.Mutable
58
59import Data.Vector.Algorithms.Common (Comparison, midPoint)
60
61import qualified Data.Vector.Algorithms.Insertion as I
62import qualified Data.Vector.Algorithms.Optimal   as O
63import qualified Data.Vector.Algorithms.Heap      as H
64
-- | Sorts every element of a mutable array in place, ordering the elements
-- with the 'compare' method of their 'Ord' instance.
sort
  :: (PrimMonad m, MVector v e, Ord e)
  => v (PrimState m) e -- ^ array to sort in place
  -> m ()
sort = sortBy compare
{-# INLINABLE sort #-}
69
-- | Sorts every element of a mutable array in place, ordering the elements
-- with the supplied comparison. This is 'sortByBounds' applied to the full
-- index range @[0, length a)@.
sortBy
  :: (PrimMonad m, MVector v e)
  => Comparison e      -- ^ ordering to sort by
  -> v (PrimState m) e -- ^ array to sort in place
  -> m ()
sortBy cmp vec = sortByBounds cmp vec 0 (length vec)
{-# INLINE sortBy #-}
74
-- | Sorts the slice [l,u) of an array according to a custom ordering.
--
-- Slices with fewer than two elements are left alone, slices of two, three
-- or four elements are handed directly to the fixed-size sorters of
-- "Data.Vector.Algorithms.Optimal", and anything longer is introsorted with
-- a recursion budget of @ilg (u - l)@.
sortByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower index, l
  -> Int -- ^ upper index, u
  -> m ()
sortByBounds cmp a l u = case u - l of
  2 -> O.sort2ByOffset cmp a l
  3 -> O.sort3ByOffset cmp a l
  4 -> O.sort4ByOffset cmp a l
  n | n < 2     -> return ()
    | otherwise -> introsort cmp a (ilg n) l u
{-# INLINE sortByBounds #-}
91
-- Core introsort driver, shared with the partial sorting functions so that
-- they can supply their own remaining recursion budget @i@.
--
-- The quicksort phase ('quick') stops subdividing a segment once it is
-- shorter than 'threshold', and falls back to heap sort on a segment when the
-- budget reaches zero. A single insertion sort over the whole of [l,u) then
-- finishes off whatever the quicksort phase left unsorted.
introsort :: (PrimMonad m, MVector v e)
          => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
introsort cmp a i l u = do
  quick i l u
  I.sortByBounds cmp a l u
 where
  -- Budget exhausted: heap sort this segment.
  quick 0 lo hi = H.sortByBounds cmp a lo hi
  quick depth lo hi
    | hi - lo < threshold = return ()
    | otherwise = do
        -- Order the elements at the middle, first and last index so that
        -- the median of the three ends up in the first position.
        O.sort3ByIndex cmp a (midPoint hi lo) lo (hi - 1)
        pivot <- unsafeRead a lo
        split <- partitionBy cmp a pivot (lo + 1) hi
        -- Drop the pivot between the two partitions.
        unsafeSwap a lo (split - 1)
        quick (depth - 1) split hi
        quick (depth - 1) lo (split - 1)
{-# INLINE introsort #-}
111
-- | Rearranges an array so that its k smallest elements, according to the
-- 'Ord' instance of the element type, occupy the first k positions. The
-- order among those k elements is unspecified.
select
  :: (PrimMonad m, MVector v e, Ord e)
  => v (PrimState m) e -- ^ array to rearrange in place
  -> Int               -- ^ number of elements to select, k
  -> m ()
select = selectBy compare
{-# INLINE select #-}
121
-- | Rearranges an array so that its k smallest elements, as judged by the
-- supplied comparison, occupy the first k positions. The order among those
-- k elements is unspecified. This is 'selectByBounds' applied to the full
-- index range @[0, length a)@.
selectBy
  :: (PrimMonad m, MVector v e)
  => Comparison e      -- ^ ordering that defines \"least\"
  -> v (PrimState m) e -- ^ array to rearrange in place
  -> Int               -- ^ number of elements to select, k
  -> m ()
selectBy cmp vec k = selectByBounds cmp vec k 0 (length vec)
{-# INLINE selectBy #-}
132
-- | Moves the least k elements in the interval [l,u) to the positions
-- [l,k+l) in no particular order.
--
-- When @k <= 0@ there is nothing to select, and when @k >= u - l@ the whole
-- interval is selected; either request is satisfied by any arrangement, so
-- the array is left untouched. Handling these cases up front also keeps the
-- partitioning loop from ever being asked to work on an empty segment, where
-- it would pass indices outside [l,u) to the unchecked element accessors.
selectByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ number of elements to select, k
  -> Int -- ^ lower bound, l
  -> Int -- ^ upper bound, u
  -> m ()
selectByBounds cmp a k l u
  | l >= u    = return ()
  | k <= 0    = return ()
  | k >= len  = return ()
  | otherwise = go (ilg len) l (l + k) u
 where
 len = u - l
 -- go n l m u: n is the remaining iteration budget, [l,u) the segment still
 -- under consideration and m the absolute index that must end up separating
 -- the selected elements from the rest.
 --
 -- Budget exhausted: delegate the rest to the heap-based selection.
 go 0 l m u = H.selectByBounds cmp a (m - l) l u
 go n l m u = do -- median of first, middle and last element becomes the pivot
                 O.sort3ByIndex cmp a c l (u-1)
                 p <- unsafeRead a l
                 mid <- partitionBy cmp a p (l+1) u
                 -- place the pivot between the two partitions
                 unsafeSwap a l (mid - 1)
                 if m > mid
                   then go (n-1) mid m u
                   else if m < mid - 1
                        then go (n-1) l m (mid - 1)
                        -- m is mid or mid - 1: the split already sits at m
                        else return ()
  where c = midPoint u l
{-# INLINE selectByBounds #-}
160
-- | Rearranges an array so that its first k positions hold the k smallest
-- elements in ascending order, using the 'Ord' instance of the element type.
partialSort
  :: (PrimMonad m, MVector v e, Ord e)
  => v (PrimState m) e -- ^ array to rearrange in place
  -> Int               -- ^ number of elements to sort, k
  -> m ()
partialSort = partialSortBy compare
{-# INLINE partialSort #-}
169
-- | Rearranges an array so that its first k positions hold the k smallest
-- elements, sorted, where both \"smallest\" and \"sorted\" are judged by the
-- supplied comparison. This is 'partialSortByBounds' applied to the full
-- index range @[0, length a)@.
partialSortBy
  :: (PrimMonad m, MVector v e)
  => Comparison e      -- ^ ordering to sort by
  -> v (PrimState m) e -- ^ array to rearrange in place
  -> Int               -- ^ number of elements to sort, k
  -> m ()
partialSortBy cmp vec k = partialSortByBounds cmp vec k 0 (length vec)
{-# INLINE partialSortBy #-}
180
-- | Moves the least k elements in the interval [l,u) to the positions
-- [l,k+l), sorted.
--
-- A non-positive k leaves the array untouched, and a k larger than the
-- interval is treated as the interval's length (the whole of [l,u) is
-- sorted). This keeps the partitioning loop from ever being asked to work on
-- an empty segment, where it would pass indices outside [l,u) to the
-- unchecked element accessors.
partialSortByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ number of elements to sort, k
  -> Int -- ^ lower index, l
  -> Int -- ^ upper index, u
  -> m ()
partialSortByBounds cmp a k l u
  | l >= u    = return ()
  | k <= 0    = return ()
  | otherwise = go (ilg len) l (l + min k len) u
 where
 isort = introsort cmp a
 {-# INLINE [1] isort #-}
 len = u - l
 -- go n l m u: n is the remaining iteration budget, [l,u) the segment still
 -- under consideration and m the absolute index up to which the result must
 -- be sorted.
 --
 -- Budget exhausted: let the heap-based partial sort finish the current
 -- segment (its own bounds, not the bounds of the original request).
 go 0 l m u = H.partialSortByBounds cmp a (m - l) l u
 go n l m u
   | l == m    = return ()
   | otherwise = do -- median of first, middle and last element becomes the pivot
                    O.sort3ByIndex cmp a c l (u-1)
                    p <- unsafeRead a l
                    mid <- partitionBy cmp a p (l+1) u
                    -- place the pivot between the two partitions
                    unsafeSwap a l (mid - 1)
                    case compare m mid of
                      -- everything up to the pivot is wanted: sort it fully,
                      -- then keep looking in the upper partition
                      GT -> do isort (n-1) l (mid - 1)
                               go (n-1) mid m u
                      -- exactly the lower partition plus pivot is wanted
                      EQ -> isort (n-1) l m
                      -- only part of the lower partition is wanted; this
                      -- step consumes budget like every other partitioning
                      -- step so the heap fallback can still trigger
                      LT -> go (n-1) l m (mid - 1)
  where c = midPoint u l
{-# INLINE partialSortByBounds #-}
212
-- Partitions the segment [l,u) of the array around a pivot value (passed by
-- value, not by index) and returns the index at which the two partitions
-- meet.
--
-- The scan alternates between two phases: 'scanUp' advances the lower index
-- past elements that compare less than the pivot, and 'scanDown' retreats
-- the upper index past elements that the pivot compares less than. When both
-- phases have stopped on an out-of-place element the two are swapped and the
-- upward scan resumes.
partitionBy :: forall m v e. (PrimMonad m, MVector v e)
            => Comparison e -> v (PrimState m) e -> e -> Int -> Int -> m Int
partitionBy cmp a = scanUp
 where
 scanUp :: e -> Int -> Int -> m Int
 scanUp pivot lo hi
   | lo >= hi  = return lo
   | otherwise = do
       x <- unsafeRead a lo
       case cmp x pivot of
         LT -> scanUp pivot (lo + 1) hi
         _  -> scanDown pivot lo (hi - 1)

 scanDown :: e -> Int -> Int -> m Int
 scanDown pivot lo hi
   | lo >= hi  = return lo
   | otherwise = do
       y <- unsafeRead a hi
       case cmp pivot y of
         LT -> scanDown pivot lo (hi - 1)
         _  -> do unsafeSwap a lo hi
                  scanUp pivot (lo + 1) hi
{-# INLINE partitionBy #-}
233
-- Computes the number of recursive calls after which heapsort should be
-- invoked, given the length of the segment to be sorted: twice the integer
-- base-2 logarithm of that length (rounded down) for any positive length.
-- The argument must not be negative, since the shifting loop only stops once
-- the value reaches zero.
ilg :: Int -> Int
ilg size = 2 * bitLength size (-1)
 where
 -- Counts how many right shifts it takes to reach zero, starting the count
 -- at -1 so that the result is the index of the highest set bit.
 bitLength 0 !acc = acc
 bitLength x !acc = bitLength (x `shiftR` 1) (acc + 1)
241
-- Segment length below which the quicksort phase of introsort stops
-- subdividing; shorter segments are left for the final insertion sort pass.
{-# INLINE threshold #-}
threshold :: Int
threshold = 18
246