1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE TypeOperators #-}
3
4-- ---------------------------------------------------------------------------
5-- |
6-- Module      : Data.Vector.Algorithms.Heap
7-- Copyright   : (c) 2008-2015 Dan Doel
8-- Maintainer  : Dan Doel <dan.doel@gmail.com>
9-- Stability   : Experimental
10-- Portability : Non-portable (type operators)
11--
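-- This module implements operations for working with a quaternary heap stored
-- in a mutable vector (any instance of 'MVector', boxed or unboxed). Most
-- heapsorts are defined in terms of a binary heap, in which each internal node
-- has at most two children. By contrast, a quaternary heap has internal nodes
-- with up to four children: the node at heap index @i@ has its children at
-- @4*i + 1@ through @4*i + 4@. This reduces the number of comparisons in a
-- heapsort slightly, and improves locality (again, slightly) by flattening out
-- the heap. The heaps are max-heaps with respect to the supplied comparison,
-- so the greatest element sits at the lowest index of the heap's range.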
18
19module Data.Vector.Algorithms.Heap
20       ( -- * Sorting
21         sort
22       , sortBy
23       , sortByBounds
24         -- * Selection
25       , select
26       , selectBy
27       , selectByBounds
28         -- * Partial sorts
29       , partialSort
30       , partialSortBy
31       , partialSortByBounds
32         -- * Heap operations
33       , heapify
34       , pop
35       , popTo
36       , sortHeap
37       , heapInsert
38       , Comparison
39       ) where
40
41import Prelude hiding (read, length)
42
43import Control.Monad
44import Control.Monad.Primitive
45
46import Data.Bits
47
48import Data.Vector.Generic.Mutable
49
50import Data.Vector.Algorithms.Common (Comparison)
51
52import qualified Data.Vector.Algorithms.Optimal as O
53
-- | Sorts the whole mutable vector in place, in ascending order according to
-- the element type's 'Ord' instance.
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort v = sortBy compare v
{-# INLINABLE sort #-}
58
-- | Sorts the whole mutable vector in place, ordering elements with the
-- supplied comparison.
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp a = sortByBounds cmp a 0 n
  where
    n = length a
{-# INLINE sortBy #-}
63
-- | Sorts the slice [l,u) of an array in place using a custom ordering.
--
-- Slices of at most four elements are handed directly to the small optimal
-- sorts from "Data.Vector.Algorithms.Optimal". Longer slices are turned into a
-- heap, every element except the four smallest is popped into its final
-- position, and the remaining four slots at the front are then finished with
-- an optimal four-element sort.
sortByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower index, l
  -> Int -- ^ upper index, u
  -> m ()
sortByBounds cmp a l u = case u - l of
  n | n < 2     -> return ()
    | n == 2    -> O.sort2ByOffset cmp a l
    | n == 3    -> O.sort3ByOffset cmp a l
    | n == 4    -> O.sort4ByOffset cmp a l
    | otherwise -> do
        heapify cmp a l u
        -- leaves [l, l+4) unsorted; everything from l+4 up is in order
        sortHeap cmp a l (l + 4) u
        O.sort4ByOffset cmp a l
{-# INLINE sortByBounds #-}
80
-- | Rearranges the array so that its k smallest elements (according to 'Ord')
-- occupy the first k positions. Their relative order is unspecified.
select
  :: (PrimMonad m, MVector v e, Ord e)
  => v (PrimState m) e
  -> Int -- ^ number of elements to select, k
  -> m ()
select v k = selectBy compare v k
{-# INLINE select #-}
90
-- | Rearranges the array so that its k smallest elements, as judged by the
-- supplied comparison, occupy the first k positions. Their relative order is
-- unspecified.
selectBy
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ number of elements to select, k
  -> m ()
selectBy cmp a k =
  let n = length a
  in  selectByBounds cmp a k 0 n
{-# INLINE selectBy #-}
102
-- | Moves the 'lowest' k elements in the portion [l,u) of the
-- array into the positions [l,k+l). The elements will be in
-- no particular order.
--
-- The first k slots of the range are turned into a max-heap. Every later
-- element that compares below the heap's maximum then replaces it. When
-- @k <= 0@, or when @k@ exceeds the size of the range, there is nothing to
-- move and the array is left untouched.
selectByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ number of elements to select, k
  -> Int -- ^ lower index, l
  -> Int -- ^ upper index, u
  -> m ()
selectByBounds cmp a k l u
  -- An empty selection has no heap top. Without this guard the scan below
  -- would treat a[l] as a maximum and shuffle the range, and for a negative k
  -- it would walk below l with unchecked reads and writes.
  | k <= 0     = return ()
  -- Compared as a difference so that a very large k cannot overflow l + k.
  | k <= u - l = heapify cmp a l m >> go (u - 1)
  | otherwise  = return ()
  where
    -- end (exclusive) of the heap holding the k best candidates so far
    m = l + k
    -- Visit the elements outside the heap, from u-1 down to m.
    go i
      | i < m     = return ()
      | otherwise = do
          top <- unsafeRead a l
          ei  <- unsafeRead a i
          case cmp ei top of
            -- smaller than the largest kept element: it takes that slot
            LT -> popTo cmp a l m i
            _  -> return ()
          go (i - 1)
{-# INLINE selectByBounds #-}
127
-- | Places the k smallest elements (according to 'Ord') at the front of the
-- array, in ascending order.
--
-- Whatever ends up after those k positions is in unspecified order.
partialSort
  :: (PrimMonad m, MVector v e, Ord e)
  => v (PrimState m) e
  -> Int -- ^ number of elements to sort, k
  -> m ()
partialSort v k = partialSortBy compare v k
{-# INLINE partialSort #-}
138
-- | Places the k smallest elements, as judged by the supplied comparison, at
-- the front of the array in ascending order.
--
-- Whatever ends up after those k positions is in unspecified order.
partialSortBy
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ number of elements to sort, k
  -> m ()
partialSortBy cmp a k = partialSortByBounds cmp a k 0 end
  where
    end = length a
{-# INLINE partialSortBy #-}
151
-- | Moves the lowest k elements in the portion [l,u) of the array
-- into positions [l,k+l), sorted.
--
-- The remaining values in [l,u) will be in no particular order. Values outside
-- the range [l,u) will be unaffected.
--
-- Ranges of at most four elements are always sorted completely. For longer
-- ranges, a non-positive k leaves the array untouched.
partialSortByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ number of elements to sort, k
  -> Int -- ^ lower index, l
  -> Int -- ^ upper index, u
  -> m ()
partialSortByBounds cmp a k l u
  -- this potentially does more work than absolutely required,
  -- but using a heap to find the least 2 of 4 elements
  -- seems unlikely to be better than just sorting all of them
  -- with an optimal sort, and the latter is obviously index
  -- correct.
  | len <  2   = return ()
  | len == 2   = O.sort2ByOffset cmp a l
  | len == 3   = O.sort3ByOffset cmp a l
  | len == 4   = O.sort4ByOffset cmp a l
  -- nothing requested from a longer range
  | k <= 0     = return ()
  -- everything requested; a difference avoids overflowing l + k
  | len <= k   = sortByBounds cmp a l u
  | otherwise  = do
      selectByBounds cmp a k l u
      -- [l,l+k) now holds the k smallest elements as a max-heap. When the heap
      -- reaches past l+4, pop all but its four smallest members into their
      -- final slots. For k <= 4 there is nothing to pop, and sortHeap's final
      -- swap with position l+4 would drag an unselected element into the
      -- front. That case previously put wrong values in the first k slots.
      when (k > 4) $ sortHeap cmp a l (l + 4) (l + k)
      -- The first four slots now contain every selected element not yet
      -- placed (plus, when k < 4, elements that are no smaller), so sorting
      -- them puts the k smallest in order at the front. len >= 5 keeps these
      -- four slots inside [l,u).
      O.sort4ByOffset cmp a l
  where
    len = u - l
{-# INLINE partialSortByBounds #-}
182
-- | Rearranges the elements of the range [l, u) in place so that they form a
-- max-heap (with respect to the given comparison) rooted at index l.
--
-- Note: 'heapify' is more efficient than constructing a heap by repeated
-- insertion. Repeated insertion has complexity O(n*log n) while 'heapify' is able
-- to construct a heap in O(n), where n is the number of elements in the heap.
heapify
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower index, l
  -> Int -- ^ upper index, u
  -> m ()
heapify cmp a l u = siftFrom ((len - 1) `shiftR` 2)
  where
    len = u - l
    -- Sift nodes down one at a time, starting at heap index (len-1)/4 and
    -- moving back to the root. Each node's subtrees are therefore already
    -- heaps by the time it is processed.
    siftFrom k = when (k >= 0) $ do
      e <- unsafeRead a (l + k)
      siftByOffset cmp a e l k len
      siftFrom (k - 1)
{-# INLINE heapify #-}
203
-- | Given a heap stored in a portion of an array [l,u), exchanges the heap's
-- top with the element at index u (the slot just past the heap). It then
-- restores the heap property over [l,u). This is 'popTo' with the target
-- fixed to u.
pop
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower heap index, l
  -> Int -- ^ upper heap index, u
  -> m ()
pop cmp a l u = popTo cmp a l u target
  where
    -- the removed maximum is written just past the heap
    target = u
{-# INLINE pop #-}
215
-- | Given a heap stored in a portion of an array [l,u), moves the heap's top
-- to index t. The value previously at t is then sifted into the heap, so that
-- [l,u) is a heap again.
popTo
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower heap index, l
  -> Int -- ^ upper heap index, u
  -> Int -- ^ index to pop to, t
  -> m ()
popTo cmp a l u t = do
  top      <- unsafeRead a l
  incoming <- unsafeRead a t
  unsafeWrite a t top
  -- The root is now logically a hole; let the displaced value settle into it.
  siftByOffset cmp a incoming l 0 (u - l)
{-# INLINE popTo #-}
231
-- | Given a heap stored in a portion of an array [l,u), sorts the
-- highest values into [m,u). The elements in [l,m) are not in any
-- particular order.
--
-- If @m >= u@ there is no target range and nothing is moved. An @m@ below
-- @l@ is treated as @l@, which sorts the whole heap.
sortHeap
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower heap index, l
  -> Int -- ^ lower bound of final sorted portion, m
  -> Int -- ^ upper heap index, u
  -> m ()
sortHeap cmp a l m u = do
  loop (u - 1)
  -- The loop leaves the heap occupying [l, m']. Moving its top to m' completes
  -- the sorted run. When m' >= u the position m' lies outside the heap, and
  -- swapping would exchange the maximum with an unrelated element, so skip it.
  when (m' < u) $ unsafeSwap a l m'
  where
    -- never address anything below the heap's lower bound
    m' = max l m
    -- Pop the current maximum into each slot from u-1 down to m'+1.
    loop k
      | m' < k    = pop cmp a l k >> loop (k - 1)
      | otherwise = return ()
{-# INLINE sortHeap #-}
249
-- | Given a heap stored in a portion of an array [l,u) and an element e,
-- inserts the element into the heap, resulting in a heap in [l,u].
--
-- Insertion starts with a hole at index u. Parents smaller than e are shifted
-- down into the hole until e can be written in place.
--
-- Note: it is best to only use this operation when incremental construction of
-- a heap is required. 'heapify' is capable of building a heap in O(n) time,
-- while repeated insertion takes O(n*log n) time.
heapInsert
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower heap index, l
  -> Int -- ^ upper heap index, u
  -> e -- ^ element to be inserted, e
  -> m ()
heapInsert cmp v l u e = bubbleUp (u - l)
  where
    -- hole: heap-relative index of the currently empty slot
    bubbleUp hole
      | hole <= 0 = unsafeWrite v l e
      | otherwise = do
          let parent = (hole - 1) `shiftR` 2
          p <- unsafeRead v (l + parent)
          case cmp p e of
            LT -> do
              unsafeWrite v (l + hole) p
              bubbleUp parent
            _  -> unsafeWrite v (l + hole) e
{-# INLINE heapInsert #-}
273
-- Fills a hole at heap index @start@ with @val@, moving it downwards. At each
-- level the largest child is compared with @val@. If @val@ is smaller, that
-- child moves up into the hole and the descent continues from the child's
-- position. Otherwise @val@ is written into the hole. Heap indices are relative
-- to @off@, and only indices below @len@ belong to the heap.
siftByOffset :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> e -> Int -> Int -> Int -> m ()
siftByOffset cmp a val off start len = descend start
  where
    descend hole
      | firstKid < len = do
          (kid, big) <- maximumChild cmp a off firstKid len
          case cmp val big of
            LT -> do
              unsafeWrite a (hole + off) big
              descend kid
            _  -> unsafeWrite a (hole + off) val
      | otherwise = unsafeWrite a (hole + off) val
      where
        firstKid = hole `shiftL` 2 + 1
{-# INLINE siftByOffset #-}
289
-- Finds the largest child of a heap node. @child1@ is the heap index (relative
-- to @off@) of the node's first child. Up to three siblings follow it, and
-- only those whose index is below @len@ are compared. Returns the winning
-- child's heap index together with its value. When children compare equal,
-- the earlier one wins. The first child is read unconditionally; its only
-- caller in this module, 'siftByOffset', calls it only when @child1 < len@.
maximumChild :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m (Int,  e)
maximumChild cmp a off child1 len
  | child1 + 3 < len = do
      c1 <- child 0
      c2 <- child 1
      c3 <- child 2
      c4 <- child 3
      return $ keep (keep (keep c1 c2) c3) c4
  | child1 + 2 < len = do
      c1 <- child 0
      c2 <- child 1
      c3 <- child 2
      return $ keep (keep c1 c2) c3
  | child1 + 1 < len = do
      c1 <- child 0
      c2 <- child 1
      return $ keep c1 c2
  | otherwise = child 0
  where
    -- read the i-th sibling, paired with its heap index
    child i = do
      let ix = child1 + i
      x <- unsafeRead a (ix + off)
      return (ix, x)
    -- prefer the left candidate unless it compares strictly below the right
    keep p@(_, x) q@(_, y) = case cmp x y of
      LT -> q
      _  -> p
{-# INLINE maximumChild #-}
334