{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- ---------------------------------------------------------------------------
-- |
-- Module      : Data.Vector.Algorithms.AmericanFlag
-- Copyright   : (c) 2011 Dan Doel
-- Maintainer  : Dan Doel <dan.doel@gmail.com>
-- Stability   : Experimental
-- Portability : Non-portable (FlexibleContexts, ScopedTypeVariables)
--
-- This module implements American flag sort: an in-place, unstable bucket
-- sort. Unlike radix sort, the values are inspected in big-endian order,
-- and buckets are sorted via recursive splitting. This makes it well suited
-- to sorting strings in lexicographic order (provided indexing is fast).
--
-- The algorithm works as follows: at each stage, the array is looped over,
-- counting the number of elements for each bucket. Then, starting at the
-- beginning of the array, elements are permuted in place to reside in the
-- proper bucket, following chains until they reach back to the current
-- base index. Finally, each bucket is sorted recursively. This lends itself
-- well to the aforementioned variable-length strings, so the algorithm
-- takes a stopping predicate, which is given a representative of the stripe,
-- rather than running for a fixed number of iterations.

module Data.Vector.Algorithms.AmericanFlag ( sort
                                           , sortBy
                                           , terminate
                                           , Lexicographic(..)
                                           ) where

import Prelude hiding (read, length)

import Control.Monad
import Control.Monad.Primitive

import Data.Proxy

import Data.Word
import Data.Int
import Data.Bits

import qualified Data.ByteString as B

import Data.Vector.Generic.Mutable
import qualified Data.Vector.Primitive.Mutable as PV

import qualified Data.Vector.Unboxed.Mutable as U

import Data.Vector.Algorithms.Common

import qualified Data.Vector.Algorithms.Insertion as I

import Foreign.Storable

-- | The methods of this class specify the information necessary to sort
-- arrays using the default ordering. The name 'Lexicographic' is meant
-- to convey that 'index' should return results in a similar way to
-- indexing into a string.
class Lexicographic e where
  -- | Computes the length of a representative of a stripe. It should take
  -- @n@ passes to sort values of extent @n@. The extent may not be uniform
  -- across all values of the type.
  extent :: e -> Int

  -- | The size of the bucket array necessary for sorting values of type
  -- @e@. Every result of 'index' must lie in the range @[0, size)@, since
  -- it is used directly as a position in arrays of this size.
  size :: Proxy e -> Int

  -- | Determines which bucket a given element should inhabit for a
  -- particular iteration (pass number).
  index :: Int -> e -> Int

instance Lexicographic Word8 where
  extent _ = 1
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index _ n = fromIntegral n
  {-# INLINE index #-}

instance Lexicographic Word16 where
  extent _ = 2
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index 0 n = fromIntegral $ (n `shiftR` 8) .&. 255
  index 1 n = fromIntegral $ n .&. 255
  index _ _ = 0
  {-# INLINE index #-}

instance Lexicographic Word32 where
  extent _ = 4
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index 0 n = fromIntegral $ (n `shiftR` 24) .&. 255
  index 1 n = fromIntegral $ (n `shiftR` 16) .&. 255
  index 2 n = fromIntegral $ (n `shiftR` 8) .&. 255
  index 3 n = fromIntegral $ n .&. 255
  index _ _ = 0
  {-# INLINE index #-}

instance Lexicographic Word64 where
  extent _ = 8
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index 0 n = fromIntegral $ (n `shiftR` 56) .&. 255
  index 1 n = fromIntegral $ (n `shiftR` 48) .&. 255
  index 2 n = fromIntegral $ (n `shiftR` 40) .&. 255
  index 3 n = fromIntegral $ (n `shiftR` 32) .&. 255
  index 4 n = fromIntegral $ (n `shiftR` 24) .&. 255
  index 5 n = fromIntegral $ (n `shiftR` 16) .&. 255
  index 6 n = fromIntegral $ (n `shiftR` 8) .&. 255
  index 7 n = fromIntegral $ n .&. 255
  index _ _ = 0
  {-# INLINE index #-}

instance Lexicographic Word where
  extent _ = sizeOf (0 :: Word)
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  -- Byte @i@, counted from the most significant end. The shift amount is
  -- derived from the machine word size (the same value 'extent' reports),
  -- so on a 32-bit platform the four passes read bits 31..0 instead of
  -- shifting by 56..32 bits as a hard-coded 64-bit table would.
  index i n
    | i >= 0 && i < bytes = fromIntegral $ (n `shiftR` shiftFor i) .&. 255
    | otherwise           = 0
    where
      bytes      = sizeOf (0 :: Word)
      shiftFor j = 8 * (bytes - 1 - j)
  {-# INLINE index #-}

instance Lexicographic Int8 where
  extent _ = 1
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  -- Flipping bit 7 maps -128..127 onto buckets 0..255 in order.
  index _ n = (255 .&. fromIntegral n) `xor` 128
  {-# INLINE index #-}

instance Lexicographic Int16 where
  extent _ = 2
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index 0 n = fromIntegral $ ((n `xor` minBound) `shiftR` 8) .&. 255
  index 1 n = fromIntegral $ n .&. 255
  index _ _ = 0
  {-# INLINE index #-}

instance Lexicographic Int32 where
  extent _ = 4
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index 0 n = fromIntegral $ ((n `xor` minBound) `shiftR` 24) .&. 255
  index 1 n = fromIntegral $ (n `shiftR` 16) .&. 255
  index 2 n = fromIntegral $ (n `shiftR` 8) .&. 255
  index 3 n = fromIntegral $ n .&. 255
  index _ _ = 0
  {-# INLINE index #-}

instance Lexicographic Int64 where
  extent _ = 8
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index 0 n = fromIntegral $ ((n `xor` minBound) `shiftR` 56) .&. 255
  index 1 n = fromIntegral $ (n `shiftR` 48) .&. 255
  index 2 n = fromIntegral $ (n `shiftR` 40) .&. 255
  index 3 n = fromIntegral $ (n `shiftR` 32) .&. 255
  index 4 n = fromIntegral $ (n `shiftR` 24) .&. 255
  index 5 n = fromIntegral $ (n `shiftR` 16) .&. 255
  index 6 n = fromIntegral $ (n `shiftR` 8) .&. 255
  index 7 n = fromIntegral $ n .&. 255
  index _ _ = 0
  {-# INLINE index #-}

instance Lexicographic Int where
  extent _ = sizeOf (0 :: Int)
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  -- Byte @i@, counted from the most significant end. The top byte has its
  -- sign bit flipped (xor with 'minBound') so negative values land in lower
  -- buckets than non-negative ones. Shift amounts are derived from the
  -- machine word size (the same value 'extent' reports), so 32-bit
  -- platforms inspect the right bytes.
  index i n
    | i == 0             = ((n `xor` minBound) `shiftR` shiftFor 0) .&. 255
    | i > 0 && i < bytes = (n `shiftR` shiftFor i) .&. 255
    | otherwise          = 0
    where
      bytes      = sizeOf (0 :: Int)
      shiftFor j = 8 * (bytes - 1 - j)
  {-# INLINE index #-}

instance Lexicographic B.ByteString where
  extent = B.length
  {-# INLINE extent #-}
  -- 256 byte values plus bucket 0, reserved for "past the end", so that a
  -- string sorts before any of its proper extensions.
  size _ = 257
  {-# INLINE size #-}
  index i b
    | i >= B.length b = 0
    | otherwise       = fromIntegral (B.index b i) + 1
  {-# INLINE index #-}

instance (Lexicographic a, Lexicographic b) => Lexicographic (a, b) where
  extent (a,b) = extent a + extent b
  {-# INLINE extent #-}
  size _ = size (Proxy :: Proxy a) `max` size (Proxy :: Proxy b)
  {-# INLINE size #-}
  -- Positions [0, extent a) come from the first component. Later positions
  -- index the second component starting at its own position 0. (Previously
  -- the unshifted position was passed through, which for fixed-width types
  -- such as Word16 always landed past the end of the second component.)
  -- NOTE(review): when the first component's extent varies between values
  -- (e.g. ByteString), elements with different first-component lengths can
  -- share a bucket at the same position, so the radix passes alone may not
  -- reproduce the 'Ord' ordering for stripes above the insertion-sort
  -- threshold. Confirm whether a separator position is wanted.
  index i (a,b)
    | i >= extent a = index (i - extent a) b
    | otherwise     = index i a
  {-# INLINE index #-}

instance (Lexicographic a, Lexicographic b) => Lexicographic (Either a b) where
  extent (Left a)  = 1 + extent a
  extent (Right b) = 1 + extent b
  {-# INLINE extent #-}
  -- Position 0 is the constructor tag (bucket 0 or 1), so at least two
  -- buckets are always required, whatever the component sizes are.
  size _ = 2 `max` size (Proxy :: Proxy a) `max` size (Proxy :: Proxy b)
  {-# INLINE size #-}
  index 0 (Left _)  = 0
  index 0 (Right _) = 1
  index n (Left a)  = index (n-1) a
  index n (Right b) = index (n-1) b
  {-# INLINE index #-}

-- | Given a representative of a stripe and an index number, this
-- function determines whether to stop sorting.
terminate :: Lexicographic e => e -> Int -> Bool
terminate e i = i >= extent e
{-# INLINE terminate #-}

-- | Sorts an array using the default ordering. Both 'Lexicographic' and
-- 'Ord' are necessary because the algorithm falls back to insertion sort
-- for sufficiently small arrays.
sort :: forall e m v. (PrimMonad m, MVector v e, Lexicographic e, Ord e)
     => v (PrimState m) e -> m ()
sort v = sortBy compare terminate (size p) index v
 where p :: Proxy e
       p = Proxy
{-# INLINABLE sort #-}

-- | A fully parameterized version of the sorting algorithm. Again, this
-- function takes both radix information and a comparison, because the
-- algorithm falls back to insertion sort for small arrays.
sortBy :: (PrimMonad m, MVector v e)
       => Comparison e       -- ^ a comparison for the insertion sort fallback
       -> (e -> Int -> Bool) -- ^ determines whether a stripe is complete
       -> Int                -- ^ the number of buckets necessary
       -> (Int -> e -> Int)  -- ^ the big-endian radix function
       -> v (PrimState m) e  -- ^ the array to be sorted
       -> m ()
sortBy cmp stop buckets radix v
  | length v == 0 = return ()
  | otherwise     = do
      -- Two scratch arrays of 'buckets' entries, shared by every level of
      -- the recursion in 'flagLoop'.
      count <- new buckets
      pile  <- new buckets
      -- NOTE(review): 'countLoop' lives in Common; presumably it fills
      -- 'count' with the pass-0 bucket histogram of 'v' — confirm there.
      countLoop (radix 0) v count
      flagLoop cmp stop radix count pile v
{-# INLINE sortBy #-}

-- | Recursive driver. It sorts one stripe for one pass, then recurses into
-- each resulting sub-stripe with the next pass. Stripes shorter than
-- 'threshold' are handed to insertion sort using the comparison instead.
-- On entry, 'count' must hold the bucket histogram of the stripe for the
-- current pass.
flagLoop :: (PrimMonad m, MVector v e)
         => Comparison e
         -> (e -> Int -> Bool)            -- stopping predicate for a stripe
         -> (Int -> e -> Int)             -- radix function
         -> PV.MVector (PrimState m) Int  -- auxiliary count array
         -> PV.MVector (PrimState m) Int  -- auxiliary pile array
         -> v (PrimState m) e             -- source array
         -> m ()
flagLoop cmp stop radix count pile arr = go 0 arr
  where
    -- The first element of a stripe serves as its representative: every
    -- element of the stripe shares the buckets of all previous passes.
    go pass v = do
      e <- unsafeRead v 0
      unless (stop e (pass - 1)) $ go' pass v

    go' pass v
      | len < threshold = I.sortByBounds cmp v 0 len
      | otherwise       = do
          accumulate count pile
          permute (radix pass) count pile v
          recurse 0
      where
        len   = length v
        ppass = pass + 1

        -- Walk the now-bucketed stripe. 'countStripe' finds the end of each
        -- run sharing a bucket for this pass and, as a side effect, refills
        -- 'count' with that run's histogram for the next pass, which is
        -- exactly what the recursive call expects on entry.
        recurse i
          | i < len   = do
              j <- countStripe (radix ppass) (radix pass) count v i
              go ppass (unsafeSlice i (j - i) v)
              recurse j
          | otherwise = return ()
{-# INLINE flagLoop #-}

-- | Turns a histogram into bucket boundaries, in place. Afterwards,
-- @pile[i]@ is the start offset of bucket @i@ (the running sum before it),
-- and @count[i]@ is its end offset (the running sum including it).
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int
           -> PV.MVector (PrimState m) Int
           -> m ()
accumulate count pile = loop 0 0
  where
    len = length count

    loop i !acc
      | i < len   = do
          ci <- unsafeRead count i
          let acc' = acc + ci
          unsafeWrite pile i acc
          unsafeWrite count i acc'
          loop (i + 1) acc'
      | otherwise = return ()
{-# INLINE accumulate #-}

-- | Moves every element of the source array into its bucket, in place.
-- Bucket @r@ occupies @[count[r-1], count[r])@, with @count[-1]@ taken as
-- 0. @pile[r]@ is the next slot of bucket @r@ still to be filled, and is
-- advanced as elements are placed there. Displaced elements are carried
-- along a chain until the chain returns to the scan position.
permute :: (PrimMonad m, MVector v e)
        => (e -> Int)                    -- radix function
        -> PV.MVector (PrimState m) Int  -- count array
        -> PV.MVector (PrimState m) Int  -- pile array
        -> v (PrimState m) e             -- source array
        -> m ()
permute rdx count pile v = go 0
  where
    len = length v

    go i
      | i < len = do
          e <- unsafeRead v i
          let r = rdx e
          p <- unsafeRead pile r
          m <- if r > 0
                 then unsafeRead count (r - 1)
                 else return 0
          case () of
            -- if the current element is already in the right pile,
            -- go to the end of the pile
            _ | m <= i && i < p -> go p
            -- if the current element happens to be in the right
            -- pile, bump the pile counter and go to the next element
              | i == p          -> unsafeWrite pile r (p + 1) >> go (i + 1)
            -- otherwise follow the chain
              | otherwise       -> follow i e p >> go (i + 1)
      | otherwise = return ()

    -- Place 'e' at slot 'j', picking up whatever was there ('en') and
    -- claiming 'en's own next slot via 'inc'. The chain ends when the
    -- displaced element's slot is the scan position 'i'.
    follow i e j = do
      en <- unsafeRead v j
      let r = rdx en
      p <- inc pile r
      if p == j
        -- if the target happens to be in the right pile, don't move it.
        then follow i e (j + 1)
        else do
          unsafeWrite v j e
          if i == p
            then unsafeWrite v i en
            else follow i en p
{-# INLINE permute #-}

-- | Starting at @lo@, finds the end of the run of elements whose stripe
-- function agrees with that of @v[lo]@, returning the exclusive end index.
-- As a side effect, 'count' is zeroed and then incremented at the radix
-- function's bucket for every element of the run.
countStripe :: (PrimMonad m, MVector v e)
            => (e -> Int)                    -- radix function
            -> (e -> Int)                    -- stripe function
            -> PV.MVector (PrimState m) Int  -- count array
            -> v (PrimState m) e             -- source array
            -> Int                           -- starting position
            -> m Int                         -- end of stripe: [lo,hi)
countStripe rdx str count v lo = do
    set count 0
    e <- unsafeRead v lo
    go (str e) e (lo + 1)
  where
    len = length v

    go !s e i = do
      _ <- inc count (rdx e)
      if i < len
        then do
          en <- unsafeRead v i
          if str en == s
            then go s en (i + 1)
            else return i
        else return len
{-# INLINE countStripe #-}

-- | Stripes shorter than this are sorted with insertion sort instead of
-- another bucketing pass.
threshold :: Int
threshold = 25