1{-# LANGUAGE CPP #-}
2
3-- ---------------------------------------------------------------------------
4-- |
5-- Module      : Data.Vector.Algorithms.Optimal
6-- Copyright   : (c) 2008-2010 Dan Doel
7-- Maintainer  : Dan Doel
8-- Stability   : Experimental
9-- Portability : Portable
10--
11-- Optimal sorts for very small array sizes, or for small numbers of
12-- particular indices in a larger array (to be used, for instance, for
13-- sorting a median of 3 values into the lowest position in an array
14-- for a median-of-3 quicksort).
15
16-- The code herein was adapted from a C algorithm for optimal sorts
17-- of small arrays. The original code was produced for the article
18-- /Sorting Revisited/ by Paul Hsieh, available here:
19--
20--   http://www.azillionmonkeys.com/qed/sort.html
21--
22-- The LICENSE file contains the relevant copyright information for
23-- the reference C code.
24
25module Data.Vector.Algorithms.Optimal
26       ( sort2ByIndex
27       , sort2ByOffset
28       , sort3ByIndex
29       , sort3ByOffset
30       , sort4ByIndex
31       , sort4ByOffset
32       , Comparison
33       ) where
34
35import Prelude hiding (read, length)
36
37import Control.Monad.Primitive
38
39import Data.Vector.Generic.Mutable
40
41import Data.Vector.Algorithms.Common (Comparison)
42
43#include "vector.h"
44
-- | Compare-and-swap the adjacent pair at positions @off@ and @off + 1@,
-- leaving the two elements in ascending order according to @cmp@. This is
-- 'sort2ByIndex' applied to two neighbouring indices.
sort2ByOffset :: (PrimMonad m, MVector v e)
              => Comparison e -> v (PrimState m) e -> Int -> m ()
sort2ByOffset cmp arr off = sort2ByIndex cmp arr lo hi
  where
    lo = off
    hi = off + 1
{-# INLINABLE sort2ByOffset #-}
51
-- | Compare-and-swap the elements at indices @i@ and @j@. The two are
-- exchanged exactly when @cmp@ ranks the element at @i@ 'GT' the element
-- at @j@, so equal elements stay put. The first index is expected to be
-- the lower one, which makes this an ascending sort of the pair.
--
-- Both indices go through the UNSAFE_CHECK index checks (from the included
-- vector.h) before the unchecked reads and writes.
sort2ByIndex :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> Int -> Int -> m ()
sort2ByIndex cmp a i j =
    UNSAFE_CHECK(checkIndex) "sort2ByIndex" i (length a) $
    UNSAFE_CHECK(checkIndex) "sort2ByIndex" j (length a) $
    unsafeRead a i >>= \x ->
    unsafeRead a j >>= \y ->
      if cmp x y == GT
        then unsafeWrite a i y >> unsafeWrite a j x
        else return ()
{-# INLINABLE sort2ByIndex #-}
65
-- | Sort the three consecutive elements at positions @off@, @off + 1@ and
-- @off + 2@ into ascending order according to @cmp@; this is
-- 'sort3ByIndex' applied to three neighbouring indices.
sort3ByOffset :: (PrimMonad m, MVector v e)
              => Comparison e -> v (PrimState m) e -> Int -> m ()
sort3ByOffset cmp arr off =
    let mid = off + 1
        top = off + 2
    in  sort3ByIndex cmp arr off mid top
{-# INLINABLE sort3ByOffset #-}
71
-- | Sort the elements found at indices @i@, @j@ and @k@ so that afterwards
-- @i@ holds the least and @k@ the greatest of them according to @cmp@.
-- Positions that already hold their final element are not written.
--
-- The indices are normally given in increasing order, but nothing forces
-- that: for positions @l < m < u@, calling @sort3ByIndex cmp a m l u@ leaves
-- the median of the three elements at @l@, the lowest of the positions,
-- which is how a median-of-3 quicksort can pick its pivot.
sort3ByIndex :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
sort3ByIndex cmp a i j k =
    UNSAFE_CHECK(checkIndex) "sort3ByIndex" i (length a) $
    UNSAFE_CHECK(checkIndex) "sort3ByIndex" j (length a) $
    UNSAFE_CHECK(checkIndex) "sort3ByIndex" k (length a) $ do
      x <- unsafeRead a i
      y <- unsafeRead a j
      z <- unsafeRead a k
      if cmp x y == GT
        then
          if cmp x z == GT
            then
              if cmp z y == LT
                then do -- z < y < x: y is already in the middle
                  unsafeWrite a i z
                  unsafeWrite a k x
                else do -- y <= z < x
                  unsafeWrite a i y
                  unsafeWrite a j z
                  unsafeWrite a k x
            else do -- y < x <= z
              unsafeWrite a i y
              unsafeWrite a j x
        else
          if cmp y z == GT
            then
              if cmp x z == GT
                then do -- z < x <= y
                  unsafeWrite a i z
                  unsafeWrite a j x
                  unsafeWrite a k y
                else do -- x <= z < y
                  unsafeWrite a j z
                  unsafeWrite a k y
            else return () -- x <= y <= z: nothing to do
{-# INLINABLE sort3ByIndex #-}
103
-- | Sort the four consecutive elements at positions @off@ through
-- @off + 3@ into ascending order according to @cmp@; this is
-- 'sort4ByIndex' applied to four neighbouring indices.
sort4ByOffset :: (PrimMonad m, MVector v e)
              => Comparison e -> v (PrimState m) e -> Int -> m ()
sort4ByOffset cmp arr off = sort4ByIndex cmp arr off i1 i2 i3
  where
    i1 = off + 1
    i2 = off + 2
    i3 = off + 3
{-# INLINABLE sort4ByOffset #-}
109
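-- The four-element sort below is long because it is a fully unrolled
-- decision tree with one leaf per ordering of the four elements.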
111
-- | Sorts the elements at the four given indices. Like the 2 and 3 element
-- versions, this assumes that the indices are given in increasing order, so
-- it can be used to sort medians into particular positions and so on.
--
-- Afterwards the four elements originally at @i@, @j@, @k@ and @l@ are
-- stored at those positions in ascending order according to @cmp@.
--
-- The body is a hand-unrolled decision tree over the four values read up
-- front (@a0@..@a3@ from @i@..@l@). It has 24 leaves, one per ordering of
-- the four values, and every path makes either 4 or 5 comparisons. Each
-- leaf writes the values into their final positions. Most writes that would
-- store a value back into the position it was read from are commented out.
-- A few such writes remain in the first eight leaves and are marked below.
--
-- Notation in the branch comments: @x < y@ records that a comparison put
-- @x@ strictly before @y@, and @x <= y@ that @cmp x y@ did not return 'GT'.
-- Chaining these relations assumes @cmp@ is a consistent total ordering.
--
-- NOTE(review): the indices are presumably required to be distinct. With
-- aliased indices the final contents depend on the exact order of the
-- writes, so the write sequences below should not be reordered casually.
sort4ByIndex :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> Int -> m ()
sort4ByIndex cmp a i j k l = UNSAFE_CHECK(checkIndex) "sort4ByIndex" i (length a)
                           $ UNSAFE_CHECK(checkIndex) "sort4ByIndex" j (length a)
                           $ UNSAFE_CHECK(checkIndex) "sort4ByIndex" k (length a)
                           $ UNSAFE_CHECK(checkIndex) "sort4ByIndex" l (length a) $ do
  -- Read all four values once; the tree below only compares these copies.
  a0 <- unsafeRead a i
  a1 <- unsafeRead a j
  a2 <- unsafeRead a k
  a3 <- unsafeRead a l
  case cmp a0 a1 of
    -- a1 < a0
    GT -> case cmp a0 a2 of
            -- a2 < a0
            GT -> case cmp a1 a2 of
                    -- a2 < a1 < a0
                    GT -> case cmp a1 a3 of
                            -- a3 < a1
                            GT -> case cmp a2 a3 of
                                    -- a3 < a2 < a1 < a0
                                    GT -> do unsafeWrite a i a3
                                             unsafeWrite a j a2
                                             unsafeWrite a k a1
                                             unsafeWrite a l a0
                                    -- a2 <= a3 < a1 < a0
                                    _  -> do unsafeWrite a i a2
                                             unsafeWrite a j a3
                                             unsafeWrite a k a1
                                             unsafeWrite a l a0
                            -- a1 <= a3
                            _  -> case cmp a0 a3 of
                                    -- a2 < a1 <= a3 < a0
                                    GT -> do unsafeWrite a i a2
                                             unsafeWrite a j a1 -- NOTE: j already holds a1
                                             unsafeWrite a k a3
                                             unsafeWrite a l a0
                                    -- a2 < a1 < a0 <= a3
                                    _  -> do unsafeWrite a i a2
                                             unsafeWrite a j a1 -- NOTE: j already holds a1
                                             unsafeWrite a k a0
                                             unsafeWrite a l a3 -- NOTE: l already holds a3
                    -- a1 <= a2 < a0
                    _ -> case cmp a2 a3 of
                           -- a3 < a2
                           GT -> case cmp a1 a3 of
                                   -- a3 < a1 <= a2 < a0
                                   GT -> do unsafeWrite a i a3
                                            unsafeWrite a j a1 -- NOTE: j already holds a1
                                            unsafeWrite a k a2 -- NOTE: k already holds a2
                                            unsafeWrite a l a0
                                   -- a1 <= a3 < a2 < a0
                                   _  -> do unsafeWrite a i a1
                                            unsafeWrite a j a3
                                            unsafeWrite a k a2 -- NOTE: k already holds a2
                                            unsafeWrite a l a0
                           -- a2 <= a3
                           _  -> case cmp a0 a3 of
                                   -- a1 <= a2 <= a3 < a0
                                   GT -> do unsafeWrite a i a1
                                            unsafeWrite a j a2
                                            unsafeWrite a k a3
                                            unsafeWrite a l a0
                                   -- a1 <= a2 < a0 <= a3
                                   _  -> do unsafeWrite a i a1
                                            unsafeWrite a j a2
                                            unsafeWrite a k a0
                                            -- unsafeWrite a l a3
            -- a1 < a0 <= a2
            _  -> case cmp a0 a3 of
                    -- a3 < a0
                    GT -> case cmp a1 a3 of
                            -- a3 < a1 < a0 <= a2
                            GT -> do unsafeWrite a i a3
                                     -- unsafeWrite a j a1
                                     unsafeWrite a k a0
                                     unsafeWrite a l a2
                            -- a1 <= a3 < a0 <= a2
                            _  -> do unsafeWrite a i a1
                                     unsafeWrite a j a3
                                     unsafeWrite a k a0
                                     unsafeWrite a l a2
                    -- a0 <= a3
                    _  -> case cmp a2 a3 of
                            -- a1 < a0 <= a3 < a2
                            GT -> do unsafeWrite a i a1
                                     unsafeWrite a j a0
                                     unsafeWrite a k a3
                                     unsafeWrite a l a2
                            -- a1 < a0 <= a2 <= a3
                            _  -> do unsafeWrite a i a1
                                     unsafeWrite a j a0
                                     -- unsafeWrite a k a2
                                     -- unsafeWrite a l a3
    -- a0 <= a1
    _  -> case cmp a1 a2 of
            -- a2 < a1
            GT -> case cmp a0 a2 of
                    -- a2 < a0 <= a1
                    GT -> case cmp a0 a3 of
                            -- a3 < a0
                            GT -> case cmp a2 a3 of
                                    -- a3 < a2 < a0 <= a1
                                    GT -> do unsafeWrite a i a3
                                             unsafeWrite a j a2
                                             unsafeWrite a k a0
                                             unsafeWrite a l a1
                                    -- a2 <= a3 < a0 <= a1
                                    _  -> do unsafeWrite a i a2
                                             unsafeWrite a j a3
                                             unsafeWrite a k a0
                                             unsafeWrite a l a1
                            -- a0 <= a3
                            _  -> case cmp a1 a3 of
                                    -- a2 < a0 <= a3 < a1
                                    GT -> do unsafeWrite a i a2
                                             unsafeWrite a j a0
                                             unsafeWrite a k a3
                                             unsafeWrite a l a1
                                    -- a2 < a0 <= a1 <= a3
                                    _  -> do unsafeWrite a i a2
                                             unsafeWrite a j a0
                                             unsafeWrite a k a1
                                             -- unsafeWrite a l a3
                    -- a0 <= a2 < a1
                    _  -> case cmp a2 a3 of
                            -- a3 < a2
                            GT -> case cmp a0 a3 of
                                    -- a3 < a0 <= a2 < a1
                                    GT -> do unsafeWrite a i a3
                                             unsafeWrite a j a0
                                             -- unsafeWrite a k a2
                                             unsafeWrite a l a1
                                    -- a0 <= a3 < a2 < a1
                                    _  -> do -- unsafeWrite a i a0
                                             unsafeWrite a j a3
                                             -- unsafeWrite a k a2
                                             unsafeWrite a l a1
                            -- a2 <= a3
                            _  -> case cmp a1 a3 of
                                    -- a0 <= a2 <= a3 < a1
                                    GT -> do -- unsafeWrite a i a0
                                             unsafeWrite a j a2
                                             unsafeWrite a k a3
                                             unsafeWrite a l a1
                                    -- a0 <= a2 < a1 <= a3
                                    _  -> do -- unsafeWrite a i a0
                                             unsafeWrite a j a2
                                             unsafeWrite a k a1
                                             -- unsafeWrite a l a3
            -- a0 <= a1 <= a2
            _  -> case cmp a1 a3 of
                    -- a3 < a1
                    GT -> case cmp a0 a3 of
                            -- a3 < a0 <= a1 <= a2
                            GT -> do unsafeWrite a i a3
                                     unsafeWrite a j a0
                                     unsafeWrite a k a1
                                     unsafeWrite a l a2
                            -- a0 <= a3 < a1 <= a2
                            _  -> do -- unsafeWrite a i a0
                                     unsafeWrite a j a3
                                     unsafeWrite a k a1
                                     unsafeWrite a l a2
                    -- a1 <= a3
                    _  -> case cmp a2 a3 of
                            -- a0 <= a1 <= a3 < a2
                            GT -> do -- unsafeWrite a i a0
                                     -- unsafeWrite a j a1
                                     unsafeWrite a k a3
                                     unsafeWrite a l a2
                            -- a0 <= a1 <= a2 <= a3: already in order
                            _  -> do -- unsafeWrite a i a0
                                     -- unsafeWrite a j a1
                                     -- unsafeWrite a k a2
                                     -- unsafeWrite a l a3
                                     return ()
{-# INLINABLE sort4ByIndex #-}
246