1{-# LANGUAGE BangPatterns #-}
2{-# LANGUAGE TypeFamilies #-}
3{-# LANGUAGE FlexibleContexts #-}
4{-# lANGUAGE ScopedTypeVariables #-}
5
6-- ---------------------------------------------------------------------------
7-- |
8-- Module      : Data.Vector.Algorithms.AmericanFlag
9-- Copyright   : (c) 2011 Dan Doel
10-- Maintainer  : Dan Doel <dan.doel@gmail.com>
11-- Stability   : Experimental
12-- Portability : Non-portable (FlexibleContexts, ScopedTypeVariables)
13--
-- This module implements American flag sort: an in-place, unstable bucket
-- sort. In contrast to a (least-significant-digit) radix sort, values are
-- inspected in big-endian order, and buckets are sorted by recursive
-- splitting. This makes it well suited to sorting strings in lexicographic
-- order (provided indexing is fast).
19--
20-- The algorithm works as follows: at each stage, the array is looped over,
21-- counting the number of elements for each bucket. Then, starting at the
22-- beginning of the array, elements are permuted in place to reside in the
23-- proper bucket, following chains until they reach back to the current
24-- base index. Finally, each bucket is sorted recursively. This lends itself
25-- well to the aforementioned variable-length strings, and so the algorithm
26-- takes a stopping predicate, which is given a representative of the stripe,
27-- rather than running for a set number of iterations.
28
29module Data.Vector.Algorithms.AmericanFlag ( sort
30                                           , sortBy
31                                           , terminate
32                                           , Lexicographic(..)
33                                           ) where
34
35import Prelude hiding (read, length)
36
37import Control.Monad
38import Control.Monad.Primitive
39
40import Data.Proxy
41
42import Data.Word
43import Data.Int
44import Data.Bits
45
46import qualified Data.ByteString as B
47
48import Data.Vector.Generic.Mutable
49import qualified Data.Vector.Primitive.Mutable as PV
50
51import qualified Data.Vector.Unboxed.Mutable as U
52
53import Data.Vector.Algorithms.Common
54
55import qualified Data.Vector.Algorithms.Insertion as I
56
57import Foreign.Storable
58
-- | Radix information needed to sort arrays by their default ordering.
-- A value is inspected one position at a time, most significant position
-- first, much like indexing into a string (hence the class name).
--
-- Because 'sort' finishes short stripes with insertion sort using 'compare',
-- the bucket order produced by 'index' should agree with the 'Ord' instance.
class Lexicographic e where
  -- | Number of radix positions in a value. 'terminate' uses it to decide
  -- when a stripe needs no further splitting. Different values of the same
  -- type may have different extents.
  extent :: e -> Int

  -- | Number of buckets needed to sort values of type @e@. The bucket
  -- arrays are indexed without bounds checks, so every result of 'index'
  -- must lie in @[0, size)@.
  size :: Proxy e -> Int

  -- | The bucket a value belongs to at the given pass (pass 0 inspects the
  -- most significant position).
  index :: Int -> e -> Int
74
instance Lexicographic Word8 where
  -- A single byte: one position, and the byte value itself is the bucket.
  -- The pass number is ignored.
  extent = const 1
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  index _ = fromIntegral
  {-# INLINE index #-}
82
instance Lexicographic Word16 where
  extent = const 2
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  -- Byte @i@ counted from the most significant end; passes outside
  -- @[0, 1]@ put everything in bucket 0.
  index i n
    | i == 0 || i == 1 = fromIntegral ((n `shiftR` (8 - 8 * i)) .&. 255)
    | otherwise        = 0
  {-# INLINE index #-}
92
instance Lexicographic Word32 where
  extent = const 4
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  -- Byte @i@ counted from the most significant end; passes outside
  -- @[0, 3]@ put everything in bucket 0.
  index i n
    | 0 <= i && i < 4 = fromIntegral ((n `shiftR` (24 - 8 * i)) .&. 255)
    | otherwise       = 0
  {-# INLINE index #-}
104
instance Lexicographic Word64 where
  extent = const 8
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  -- Byte @i@ counted from the most significant end; passes outside
  -- @[0, 7]@ put everything in bucket 0.
  index i n
    | 0 <= i && i < 8 = fromIntegral ((n `shiftR` (56 - 8 * i)) .&. 255)
    | otherwise       = 0
  {-# INLINE index #-}
120
instance Lexicographic Word where
  extent _ = sizeOf (0 :: Word)
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  -- Byte @i@ counted from the most significant end; passes outside
  -- @[0, extent)@ put everything in bucket 0. The shift is derived from
  -- 'extent' instead of being hard-coded for 64 bits. Otherwise, on a 32-bit
  -- platform, the first 'extent' passes would inspect bits that do not exist
  -- and the low bytes would never be sorted.
  index i n
    | i < 0 || i >= w = 0
    | otherwise       = fromIntegral $ (n `shiftR` (8 * (w - 1 - i))) .&. 255
    where w = extent n
  {-# INLINE index #-}
136
instance Lexicographic Int8 where
  extent = const 1
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  -- Reinterpret the value as an unsigned byte and flip its top bit, so that
  -- negative values land in lower buckets than non-negative ones. The pass
  -- number is ignored.
  index _ n = fromIntegral (fromIntegral n :: Word8) `xor` 128
  {-# INLINE index #-}
144
instance Lexicographic Int16 where
  extent = const 2
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  -- Byte @i@ (most significant first) of the value with its sign bit
  -- flipped. The flip only changes the top byte, which makes negative
  -- values sort first. Passes outside @[0, 1]@ yield bucket 0.
  index i n
    | 0 <= i && i < 2 =
        fromIntegral (((n `xor` minBound) `shiftR` (8 - 8 * i)) .&. 255)
    | otherwise       = 0
  {-# INLINE index #-}
154
instance Lexicographic Int32 where
  extent = const 4
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  -- Byte @i@ (most significant first) of the value with its sign bit
  -- flipped. The flip only changes the top byte, which makes negative
  -- values sort first. Passes outside @[0, 3]@ yield bucket 0.
  index i n
    | 0 <= i && i < 4 =
        fromIntegral (((n `xor` minBound) `shiftR` (24 - 8 * i)) .&. 255)
    | otherwise       = 0
  {-# INLINE index #-}
166
instance Lexicographic Int64 where
  extent = const 8
  {-# INLINE extent #-}
  size = const 256
  {-# INLINE size #-}
  -- Byte @i@ (most significant first) of the value with its sign bit
  -- flipped. The flip only changes the top byte, which makes negative
  -- values sort first. Passes outside @[0, 7]@ yield bucket 0.
  index i n
    | 0 <= i && i < 8 =
        fromIntegral (((n `xor` minBound) `shiftR` (56 - 8 * i)) .&. 255)
    | otherwise       = 0
  {-# INLINE index #-}
182
instance Lexicographic Int where
  extent _ = sizeOf (0 :: Int)
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  -- Byte @i@ (most significant first) of the value with its sign bit
  -- flipped. Only the top byte contains the sign bit, so the flip makes
  -- negative values sort first without affecting the lower bytes. Passes
  -- outside @[0, extent)@ yield bucket 0. The shift is derived from 'extent'
  -- instead of being hard-coded for 64 bits, so 32-bit platforms also
  -- inspect every byte.
  index i n
    | i < 0 || i >= w = 0
    | otherwise       = ((n `xor` minBound) `shiftR` (8 * (w - 1 - i))) .&. 255
    where w = extent n
  {-# INLINE index #-}
198
instance Lexicographic B.ByteString where
  extent = B.length
  {-# INLINE extent #-}
  -- Bucket 0 means "past the end of the string", so a string sorts before
  -- its extensions. A byte @w@ goes to bucket @w + 1@.
  size = const 257
  {-# INLINE size #-}
  index i b
    | i < B.length b = 1 + fromIntegral (B.index b i)
    | otherwise      = 0
  {-# INLINE index #-}
208
-- | Pairs are inspected component-wise. Positions @0 .. extent a - 1@ come
-- from the first component, and later positions come from the second
-- component, re-based so that its own indexing starts at 0.
--
-- NOTE: the radix order matches the lexicographic order on pairs only when
-- every value of the first component has the same 'extent'. With a
-- variable-extent first component (e.g. 'B.ByteString'), a given position
-- can come from the second component of one pair and the first component of
-- another.
instance (Lexicographic a, Lexicographic b) => Lexicographic (a, b) where
  extent (a,b) = extent a + extent b
  {-# INLINE extent #-}
  size _ = size (Proxy :: Proxy a) `max` size (Proxy :: Proxy b)
  {-# INLINE size #-}
  index i (a,b)
    | i >= ea   = index (i - ea) b
    | otherwise = index i a
    where ea = extent a
  {-# INLINE index #-}
218
-- | 'Left' values sort before 'Right' values, matching the 'Ord' instance
-- for 'Either'. Pass 0 buckets the constructor (0 for 'Left', 1 for
-- 'Right'), and later passes index the payload.
instance (Lexicographic a, Lexicographic b) => Lexicographic (Either a b) where
  extent (Left  a) = 1 + extent a
  extent (Right b) = 1 + extent b
  {-# INLINE extent #-}
  -- Pass 0 uses buckets 0 and 1, so at least two buckets are required even
  -- if both payload types report a smaller size. Bucket numbers index the
  -- count/pile arrays without bounds checks.
  size _ = 2 `max` size (Proxy :: Proxy a) `max` size (Proxy :: Proxy b)
  {-# INLINE size #-}
  index 0 (Left  _) = 0
  index 0 (Right _) = 1
  index n (Left  a) = index (n-1) a
  index n (Right b) = index (n-1) b
  {-# INLINE index #-}
230
-- | The default stopping predicate. It is given a representative of a stripe
-- and a pass index, and reports that the stripe is finished once the index
-- has reached the representative's 'extent'.
terminate :: Lexicographic e => e -> Int -> Bool
terminate e i = extent e <= i
{-# INLINE terminate #-}
236
-- | Sorts an array in place using the default ordering.
--
-- Both 'Lexicographic' and 'Ord' are required. Large stripes are split by
-- radix using 'index', while short ones are finished with insertion sort
-- using 'compare', so the two orderings should agree.
sort :: forall e m v. (PrimMonad m, MVector v e, Lexicographic e, Ord e)
     => v (PrimState m) e -> m ()
sort v = sortBy compare terminate buckets index v
 where
  buckets = size (Proxy :: Proxy e)
{-# INLINABLE sort #-}
246
-- | A fully parameterized version of the sorting algorithm.
--
-- It needs both radix information and a comparison, because stripes that
-- are too short to be worth bucketing are finished with insertion sort. The
-- comparison should therefore agree with the order induced by the radix
-- function. An empty array is left untouched.
sortBy :: (PrimMonad m, MVector v e)
       => Comparison e       -- ^ comparison used by the insertion-sort fallback
       -> (e -> Int -> Bool) -- ^ called with a stripe's first element and the pass that produced the stripe (-1 for the whole array); 'True' stops further splitting
       -> Int                -- ^ number of buckets; every radix value must lie in @[0, buckets)@
       -> (Int -> e -> Int)  -- ^ big-endian radix function: pass number, then element
       -> v (PrimState m) e  -- ^ the array to sort in place
       -> m ()
sortBy cmp stop buckets radix v =
  unless (length v == 0) $ do
    count <- new buckets
    pile  <- new buckets
    countLoop (radix 0) v count
    flagLoop cmp stop radix count pile v
{-# INLINE sortBy #-}
264
-- | Recursive driver of the American flag sort.
--
-- Each stripe (initially the whole array) is handled as follows, where
-- @pass@ is the radix position the stripe is about to be split on:
--
--   * if @stop@ holds for the stripe's first element and @pass - 1@, the
--     stripe is left as it is;
--   * if it is shorter than 'threshold', it is finished with insertion sort
--     using @cmp@;
--   * otherwise its elements are permuted into the buckets of @radix pass@.
--     Every maximal run sharing the same @radix pass@ value is then sorted
--     recursively at @pass + 1@.
--
-- When a stripe is split, @count@ is expected to already hold that stripe's
-- bucket counts for @radix pass@. 'sortBy' prepares it for the whole array
-- via 'countLoop', and 'countStripe' prepares it for each run. Both scratch
-- arrays are reused at every level of the recursion.
flagLoop :: (PrimMonad m, MVector v e)
         => Comparison e                 -- insertion-sort comparison
         -> (e -> Int -> Bool)           -- stopping predicate
         -> (Int -> e -> Int)            -- radix function: pass, then element
         -> PV.MVector (PrimState m) Int -- auxiliary count array
         -> PV.MVector (PrimState m) Int -- auxiliary pile array
         -> v (PrimState m) e            -- source array
         -> m ()
flagLoop cmp stop radix count pile v = sortStripe 0 v
 where
  sortStripe pass stripe = do
    rep <- unsafeRead stripe 0
    unless (stop rep (pass - 1)) $ splitStripe pass stripe

  splitStripe pass stripe
    | n < threshold = I.sortByBounds cmp stripe 0 n
    | otherwise     = do
        accumulate count pile
        permute (radix pass) count pile stripe
        descend 0
   where
    n    = length stripe
    next = pass + 1

    -- Recurse into each run [lo, hi) of equal @radix pass@ values.
    -- 'countStripe' also fills @count@ with the @radix next@ counts of
    -- that run.
    descend lo
      | lo >= n   = return ()
      | otherwise = do
          hi <- countStripe (radix next) (radix pass) count stripe lo
          sortStripe next (unsafeSlice lo (hi - lo) stripe)
          descend hi
{-# INLINE flagLoop #-}
294
-- | Turns the per-bucket counts in @count@ into bucket boundaries, in place.
--
-- Afterwards, for every bucket @i@:
--
--   * @pile[i]@ holds the offset at which the bucket starts (the sum of the
--     counts of all earlier buckets);
--   * @count[i]@ holds the offset just past its end.
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int  -- in: counts; out: bucket ends
           -> PV.MVector (PrimState m) Int  -- out: bucket starts
           -> m ()
accumulate count pile = fill 0 0
 where
  buckets = length count

  fill !bucket !start
    | bucket >= buckets = return ()
    | otherwise         = do
        c <- unsafeRead count bucket
        let end = start + c
        unsafeWrite pile bucket start
        unsafeWrite count bucket end
        fill (bucket + 1) end
{-# INLINE accumulate #-}
311
-- | Rearranges @v@ in place so that every element lands in the region of its
-- bucket @rdx e@, with buckets laid out in ascending order.
--
-- @count@ and @pile@ are expected in the state left by 'accumulate'. For
-- bucket @r@:
--
--   * @pile[r]@ is the next slot of the bucket still to be filled
--     (initially the bucket's start);
--   * @count[r]@ is the bucket's exclusive end, so the bucket starts at
--     @count[r-1]@ (or 0 when @r == 0@).
--
-- @count@ is only read; @pile@ is advanced as slots are claimed. Bucket
-- numbers index both arrays without bounds checks, so @rdx@ must return
-- values in @[0, length pile)@.
permute :: (PrimMonad m, MVector v e)
        => (e -> Int)                       -- radix function
        -> PV.MVector (PrimState m) Int     -- count array
        -> PV.MVector (PrimState m) Int     -- pile array
        -> v (PrimState m) e                -- source array
        -> m ()
permute rdx count pile v = go 0
 where
 len = length v

 -- Scan the array left to right. @m@ is the start of the current element's
 -- bucket and @p@ is the next unfilled slot of that bucket.
 go i
   | i < len   = do e <- unsafeRead v i
                    let r = rdx e
                    p <- unsafeRead pile r
                    m <- if r > 0
                            then unsafeRead count (r-1)
                            else return 0
                    case () of
                      -- the current element lies in the already-filled part
                      -- [m, p) of its own bucket: skip ahead to the bucket's
                      -- next unfilled slot
                      _ | m <= i && i < p  -> go p
                      -- the current element sits exactly at its bucket's next
                      -- unfilled slot: claim the slot and go to the next element
                        | i == p           -> unsafeWrite pile r (p+1) >> go (i+1)
                      -- otherwise follow the chain of displaced elements,
                      -- starting by carrying e to slot p
                        | otherwise        -> follow i e p >> go (i+1)
   | otherwise = return ()

 -- @follow i e j@ carries element @e@, which was taken from slot @i@, to slot @j@:
 --
 --   * If the element found at @j@ is already at its own bucket's next
 --     slot, it is left in place and @e@ is tried at @j+1@.
 --   * Otherwise @e@ is written to @j@ and the displaced element is carried
 --     on to its own bucket's slot. This repeats until a displaced element's
 --     target is @i@, which closes the cycle.
 --
 -- NOTE(review): 'inc' comes from Data.Vector.Algorithms.Common; it is
 -- assumed here to return the counter's old value and then increment it.
 follow i e j = do en <- unsafeRead v j
                   let r = rdx en
                   p <- inc pile r
                   if p == j
                      -- if the target happens to be in the right pile, don't move it.
                      then follow i e (j+1)
                      else unsafeWrite v j e >> if i == p
                                             then unsafeWrite v i en
                                             else follow i en p
{-# INLINE permute #-}
350
-- | Finds the end of the stripe that begins at position @lo@. The stripe is
-- the maximal run of elements whose stripe key @str@ equals that of the
-- element at @lo@.
--
-- While scanning, it tallies in @count@ (reset to zero first) how many
-- elements of the run fall into each bucket of @rdx@. It returns the
-- exclusive end of the run. @lo@ must be a valid position, because it is
-- read without a bounds check.
countStripe :: (PrimMonad m, MVector v e)
            => (e -> Int)                   -- bucket function to tally
            -> (e -> Int)                   -- key that defines the stripe
            -> PV.MVector (PrimState m) Int -- tally array (overwritten)
            -> v (PrimState m) e            -- array being scanned
            -> Int                          -- first position of the stripe
            -> m Int                        -- exclusive end: stripe is [lo, hi)
countStripe rdx str count v lo = do
  set count 0
  first <- unsafeRead v lo
  scan (str first) first (lo + 1)
 where
  n = length v

  -- @cur@ is known to belong to the stripe: tally it, then examine slot @i@.
  scan !key cur i = do
    _ <- inc count (rdx cur)
    if i >= n
      then return n
      else do
        nxt <- unsafeRead v i
        if str nxt /= key
          then return i
          else scan key nxt (i + 1)
{-# INLINE countStripe #-}
372
-- | Stripes with fewer elements than this are finished with insertion sort
-- instead of being split further by radix.
threshold :: Int
threshold = 25
375
376