{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- ---------------------------------------------------------------------------
-- |
-- Module      : Data.Vector.Algorithms.Heap
-- Copyright   : (c) 2008-2015 Dan Doel
-- Maintainer  : Dan Doel <dan.doel@gmail.com>
-- Stability   : Experimental
-- Portability : Non-portable (type operators)
--
-- This module implements operations for working with a quaternary heap stored
-- in an unboxed array. Most heapsorts are defined in terms of a binary heap,
-- in which each internal node has at most two children. By contrast, a
-- quaternary heap has internal nodes with up to four children. This reduces
-- the number of comparisons in a heapsort slightly, and improves locality
-- (again, slightly) by flattening out the heap.

module Data.Vector.Algorithms.Heap
  ( -- * Sorting
    sort
  , sortBy
  , sortByBounds
    -- * Selection
  , select
  , selectBy
  , selectByBounds
    -- * Partial sorts
  , partialSort
  , partialSortBy
  , partialSortByBounds
    -- * Heap operations
  , heapify
  , pop
  , popTo
  , sortHeap
  , heapInsert
  , Comparison
  ) where

import Prelude hiding (read, length)

import Control.Monad
import Control.Monad.Primitive

import Data.Bits

import Data.Vector.Generic.Mutable

import Data.Vector.Algorithms.Common (Comparison)

import qualified Data.Vector.Algorithms.Optimal as O

-- | Sorts every element of the array in place, ordering them with the
-- 'Ord' instance of the element type.
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort vec = sortByBounds compare vec 0 (length vec)
{-# INLINABLE sort #-}

-- | Sorts an entire array using a custom ordering.
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp vec = sortByBounds cmp vec 0 (length vec)
{-# INLINE sortBy #-}

-- | Sorts the portion [l,u) of an array in place using a custom ordering.
--
-- Portions of two, three or four elements are handed directly to the
-- matching sorter from "Data.Vector.Algorithms.Optimal". Longer portions are
-- turned into a heap, drained with 'sortHeap' down to the first four slots,
-- and those four slots are then sorted the same way.
sortByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower index, l
  -> Int -- ^ upper index, u
  -> m ()
sortByBounds cmp a l u = case u - l of
  2 -> O.sort2ByOffset cmp a l
  3 -> O.sort3ByOffset cmp a l
  4 -> O.sort4ByOffset cmp a l
  n | n < 2     -> return ()
    | otherwise -> do
        heapify cmp a l u
        sortHeap cmp a l (l + 4) u
        O.sort4ByOffset cmp a l
{-# INLINE sortByBounds #-}

-- | Brings the k lowest elements (by the 'Ord' instance) to the front
-- of the array. No order is promised among them.
select
  :: (PrimMonad m, MVector v e, Ord e)
  => v (PrimState m) e
  -> Int -- ^ number of elements to select, k
  -> m ()
select vec k = selectBy compare vec k
{-# INLINE select #-}

-- | Brings the k lowest elements, as judged by the supplied comparison,
-- to the front of the array. No order is promised among them.
selectBy
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ number of elements to select, k
  -> m ()
selectBy cmp vec k = selectByBounds cmp vec k 0 (length vec)
{-# INLINE selectBy #-}

-- | Moves the 'lowest' k elements in the portion [l,u) of the
-- array into the positions [l,k+l). The elements will be in
-- no particular order.
selectByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ number of elements to select, k
  -> Int -- ^ lower index, l
  -> Int -- ^ upper index, u
  -> m ()
selectByBounds cmp a k l u
  -- A negative k would put the scan's lower bound below l, so the loop
  -- would read and write slots outside [l,u); treat it like any other
  -- unsatisfiable request and leave the array untouched.
  | k < 0 || l + k > u = return ()
  | otherwise          = heapify cmp a l top >> scan (u - 1)
  where
    -- One past the last slot of the k-element heap kept in [l, l+k).
    top = l + k

    -- Walk the rest of the portion from the back. The heap root at l is the
    -- largest element currently selected, so any element that compares
    -- strictly lower than it changes places with the root and the heap is
    -- rebuilt.
    scan i
      | i < top   = return ()
      | otherwise = do
          root <- unsafeRead a l
          cand <- unsafeRead a i
          case cmp cand root of
            LT -> popTo cmp a l top i
            _  -> return ()
          scan (i - 1)
{-# INLINE selectByBounds #-}

-- | Moves the lowest k elements to the front of the array, sorted.
--
-- The remaining values of the array will be in no particular order.
partialSort
  :: (PrimMonad m, MVector v e, Ord e)
  => v (PrimState m) e
  -> Int -- ^ number of elements to sort, k
  -> m ()
partialSort = partialSortBy compare
{-# INLINE partialSort #-}

-- | Moves the lowest k elements (as defined by the comparison) to
-- the front of the array, sorted.
--
-- The remaining values of the array will be in no particular order.
partialSortBy
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ number of elements to sort, k
  -> m ()
partialSortBy cmp a k = partialSortByBounds cmp a k 0 (length a)
{-# INLINE partialSortBy #-}

-- | Moves the lowest k elements in the portion [l,u) of the array
-- into positions [l,k+l), sorted.
--
-- The remaining values in [l,u) will be in no particular order. Values outside
-- the range [l,u) will be unaffected.
partialSortByBounds
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ number of elements to sort, k
  -> Int -- ^ lower index, l
  -> Int -- ^ upper index, u
  -> m ()
partialSortByBounds cmp a k l u
  -- this potentially does more work than absolutely required,
  -- but using a heap to find the least 2 of 4 elements
  -- seems unlikely to be better than just sorting all of them
  -- with an optimal sort, and the latter is obviously index
  -- correct.
  | len <  2   = return ()
  | len == 2   = O.sort2ByOffset cmp a l
  | len == 3   = O.sort3ByOffset cmp a l
  | len == 4   = O.sort4ByOffset cmp a l
  | u <= l + k = sortByBounds cmp a l u
  -- Nothing was asked for. Returning here also keeps a negative k away from
  -- the selection scan, which would otherwise walk below l.
  | k <= 0     = return ()
  | otherwise  = do
      selectByBounds cmp a k l u
      -- sortHeap finishes by swapping slot l with slot l + 4. That slot is
      -- part of the k-element heap only when k > 4; for a smaller k the swap
      -- would evict one of the selected elements. In that case the selected
      -- elements already sit inside the first four slots, so sorting those
      -- four slots is enough to put them in order at the front.
      when (k > 4) $ sortHeap cmp a l (l + 4) (l + k)
      O.sort4ByOffset cmp a l
  where
    len = u - l
{-# INLINE partialSortByBounds #-}

-- | Constructs a heap in a portion of an array [l, u), using the values therein.
--
-- Note: 'heapify' is more efficient than constructing a heap by repeated
-- insertion. Repeated insertion has complexity O(n*log n) while 'heapify' is able
-- to construct a heap in O(n), where n is the number of elements in the heap.
heapify
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower index, l
  -> Int -- ^ upper index, u
  -> m ()
heapify cmp a l u = loop $ (len - 1) `shiftR` 2
  where
    len = u - l

    -- Sift every node down, from the starting index back to the root
    -- (index 0, stored at l). For an empty portion the starting index is
    -- already negative and nothing happens.
    loop k
      | k < 0     = return ()
      | otherwise = do
          e <- unsafeRead a (l + k)
          siftByOffset cmp a e l k len
          loop (k - 1)
{-# INLINE heapify #-}

-- | Given a heap stored in a portion of an array [l,u), swaps the
-- top of the heap with the element at u and rebuilds the heap.
pop
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower heap index, l
  -> Int -- ^ upper heap index, u
  -> m ()
pop cmp a l u = popTo cmp a l u u
{-# INLINE pop #-}

-- | Given a heap stored in a portion of an array [l,u) swaps the top
-- of the heap with the element at position t, and rebuilds the heap.
popTo
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower heap index, l
  -> Int -- ^ upper heap index, u
  -> Int -- ^ index to pop to, t
  -> m ()
popTo cmp a l u t = do
  al <- unsafeRead a l
  at <- unsafeRead a t
  -- The old root lands at t; the value displaced from t is sifted down from
  -- the root of the (u - l)-element heap.
  unsafeWrite a t al
  siftByOffset cmp a at l 0 (u - l)
{-# INLINE popTo #-}

-- | Given a heap stored in a portion of an array [l,u), sorts the
-- highest values into [m,u). The elements in [l,m) are not in any
-- particular order.
--
-- The index m is expected to satisfy l <= m. When m >= u the range [m,u)
-- is empty and the array is left untouched.
sortHeap
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower heap index, l
  -> Int -- ^ lower bound of final sorted portion, m
  -> Int -- ^ upper heap index, u
  -> m ()
sortHeap cmp a l m u
  -- The closing swap moves the heap root into slot m, so it is only valid
  -- while m is still one of the heap's slots. Without this guard a call
  -- with m >= u would exchange the root with an element outside the heap.
  | m < u     = loop (u - 1) >> unsafeSwap a l m
  | otherwise = return ()
  where
    -- Pop the root into slots u-1, u-2, ..., m+1 in turn.
    loop k
      | m < k     = pop cmp a l k >> loop (k - 1)
      | otherwise = return ()
{-# INLINE sortHeap #-}

-- | Given a heap stored in a portion of an array [l,u) and an element e,
-- inserts the element into the heap, resulting in a heap in [l,u].
--
-- Note: it is best to only use this operation when incremental construction of
-- a heap is required. 'heapify' is capable of building a heap in O(n) time,
-- while repeated insertion takes O(n*log n) time.
heapInsert
  :: (PrimMonad m, MVector v e)
  => Comparison e
  -> v (PrimState m) e
  -> Int -- ^ lower heap index, l
  -> Int -- ^ upper heap index, u
  -> e -- ^ element to be inserted, e
  -> m ()
heapInsert cmp v l u e = bubbleUp (u - l)
  where
    -- The hole starts just past the current heap and climbs towards the
    -- root for as long as the parent compares lower than the new element.
    bubbleUp hole
      | hole <= 0 = unsafeWrite v l e
      | otherwise = do
          let parent = (hole - 1) `shiftR` 2
          p <- unsafeRead v (l + parent)
          case cmp p e of
            LT -> unsafeWrite v (l + hole) p >> bubbleUp parent
            _  -> unsafeWrite v (l + hole) e
{-# INLINE heapInsert #-}

-- Repairs a heap whose node at index start is a hole, working downwards.
-- Indices are relative to the heap; off is added to reach the array slot.
-- val is the value that has to end up somewhere on the path below the hole,
-- and len is the number of elements in the heap.
siftByOffset :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> e -> Int -> Int -> Int -> m ()
siftByOffset cmp a val off start len = sink start
  where
    -- Pull the largest child up into the hole while that child compares
    -- greater than val; otherwise (or at a leaf) drop val into the hole.
    sink hole
      | firstChild < len = do
          (next, biggest) <- maximumChild cmp a off firstChild len
          case cmp val biggest of
            LT -> unsafeWrite a (hole + off) biggest >> sink next
            _  -> unsafeWrite a (hole + off) val
      | otherwise = unsafeWrite a (hole + off) val
      where
        firstChild = hole `shiftL` 2 + 1
{-# INLINE siftByOffset #-}

-- Finds the maximum child of a heap node, given the index of the first child.
maximumChild :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m (Int, e)
maximumChild cmp a off child1 len
  | child1 + 3 < len = do
      c1 <- fetch child1
      c2 <- fetch (child1 + 1)
      c3 <- fetch (child1 + 2)
      c4 <- fetch (child1 + 3)
      return (c1 `orLarger` c2 `orLarger` c3 `orLarger` c4)
  | child1 + 2 < len = do
      c1 <- fetch child1
      c2 <- fetch (child1 + 1)
      c3 <- fetch (child1 + 2)
      return (c1 `orLarger` c2 `orLarger` c3)
  | child1 + 1 < len = do
      c1 <- fetch child1
      c2 <- fetch (child1 + 1)
      return (c1 `orLarger` c2)
  | otherwise = fetch child1
  where
    -- Reads the child at heap index i, pairing the index with its value.
    fetch i = do
      x <- unsafeRead a (i + off)
      return (i, x)

    -- Keeps the running maximum: the later child wins only when the current
    -- best compares strictly lower, so ties stay with the earlier child.
    -- Backtick application associates to the left, giving the same
    -- comparison sequence as a left-to-right scan over the children.
    orLarger best cand = case cmp (snd best) (snd cand) of
      LT -> cand
      _  -> best
{-# INLINE maximumChild #-}