1{-# LANGUAGE BangPatterns #-}
2{-# LANGUAGE TypeFamilies #-}
3{-# LANGUAGE TypeOperators #-}
4{-# LANGUAGE ScopedTypeVariables #-}
5
6-- ---------------------------------------------------------------------------
7-- |
8-- Module      : Data.Vector.Algorithms.Radix
9-- Copyright   : (c) 2008-2011 Dan Doel
10-- Maintainer  : Dan Doel <dan.doel@gmail.com>
11-- Stability   : Experimental
12-- Portability : Non-portable (scoped type variables, bang patterns)
13--
-- This module provides a radix sort for a subclass of unboxed arrays. The
-- 'Radix' class gives information on:
--
--   * the number of passes needed for the data type
17--
18--   * the size of the auxiliary arrays
19--
20--   * how to compute the pass-k radix of a value
21--
22-- Radix sort is not a comparison sort, so it is able to achieve O(n) run
23-- time, though it also uses O(n) auxiliary space. In addition, there is a
24-- constant space overhead of 2*size*sizeOf(Int) for the sort, so it is not
25-- advisable to use this sort for large numbers of very small arrays.
26--
27-- A standard example (upon which one could base their own Radix instance)
28-- is Word32:
29--
30--   * We choose to sort on r = 8 bits at a time
31--
32--   * A Word32 has b = 32 bits total
33--
34--   Thus, b/r = 4 passes are required, 2^r = 256 elements are needed in an
35--   auxiliary array, and the radix function is:
36--
37--    > radix k e = (e `shiftR` (k*8)) .&. 255
38
39module Data.Vector.Algorithms.Radix (sort, sortBy, Radix(..)) where
40
41import Prelude hiding (read, length)
42
43import Control.Monad
44import Control.Monad.Primitive
45
46import qualified Data.Vector.Primitive.Mutable as PV
47import Data.Vector.Generic.Mutable
48
49import Data.Vector.Algorithms.Common
50
51import Data.Bits
52import Data.Int
53import Data.Word
54
55
56import Foreign.Storable
57
-- | Describes how values of type @e@ are broken into digits for a radix
-- sort. 'sort' calls 'passes' and 'size' with 'undefined', so instances
-- must not evaluate the argument of those two methods.
class Radix e where
  -- | How many passes are required to sort an array of @e@s.
  passes :: e -> Int

  -- | How large the auxiliary array must be.
  size :: e -> Int

  -- | The radix function: given the current pass and an element, produce
  -- the element's digit for that pass.
  radix :: Int -> e -> Int
65
-- | Machine-sized 'Int': one 8-bit pass per byte reported by 'sizeOf'.
instance Radix Int where
  {-# INLINE passes #-}
  passes _ = sizeOf (undefined :: Int)

  {-# INLINE size #-}
  size _ = 256

  {-# INLINE radix #-}
  radix pass x
    | pass == 0        = x .&. 255
      -- The sign bit is flipped for the final (most significant) pass.
    | pass == lastPass = digitOf (x `xor` minBound)
    | otherwise        = digitOf x
   where
    lastPass  = passes x - 1
    -- Byte number @pass@ of the argument; @pass `shiftL` 3@ is @pass * 8@.
    digitOf v = (v `shiftR` (pass `shiftL` 3)) .&. 255
77
-- | 'Int8' is sorted in a single pass over its only byte.
instance Radix Int8 where
  {-# INLINE passes #-}
  passes _ = 1

  {-# INLINE size #-}
  size _ = 256

  {-# INLINE radix #-}
  -- Keep the low 8 bits of the widened value, then flip bit 7 (the sign
  -- bit): -128 maps to 0, -1 to 127, 0 to 128 and 127 to 255.
  radix _ x = (fromIntegral x .&. 255) `xor` 128
85
-- | 'Int16' is sorted in two 8-bit passes, least significant byte first.
instance Radix Int16 where
  passes _ = 2
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  -- The sign bit is flipped before the top byte is extracted, so negative
  -- values receive smaller digits than non-negative ones.
  radix 1 e = fromIntegral (((e `xor` minBound) `shiftR` 8) .&. 255)
  -- Makes the function total: an out-of-range pass reports what went wrong
  -- instead of failing with an anonymous pattern-match error.
  radix k _ = error ("Data.Vector.Algorithms.Radix.radix (Int16): pass "
                     ++ show k ++ " is outside [0, 1]")
  {-# INLINE radix #-}
94
-- | 'Int32' is sorted in four 8-bit passes, least significant byte first.
instance Radix Int32 where
  passes _ = 4
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  -- The sign bit is flipped before the top byte is extracted, so negative
  -- values receive smaller digits than non-negative ones.
  radix 3 e = fromIntegral (((e `xor` minBound) `shiftR` 24) .&. 255)
  -- Makes the function total: an out-of-range pass reports what went wrong
  -- instead of failing with an anonymous pattern-match error.
  radix k _ = error ("Data.Vector.Algorithms.Radix.radix (Int32): pass "
                     ++ show k ++ " is outside [0, 3]")
  {-# INLINE radix #-}
105
-- | 'Int64' is sorted in eight 8-bit passes, least significant byte first.
instance Radix Int64 where
  passes _ = 8
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  radix 4 e = fromIntegral ((e `shiftR` 32) .&. 255)
  radix 5 e = fromIntegral ((e `shiftR` 40) .&. 255)
  radix 6 e = fromIntegral ((e `shiftR` 48) .&. 255)
  -- The sign bit is flipped before the top byte is extracted, so negative
  -- values receive smaller digits than non-negative ones.
  radix 7 e = fromIntegral (((e `xor` minBound) `shiftR` 56) .&. 255)
  -- Makes the function total: an out-of-range pass reports what went wrong
  -- instead of failing with an anonymous pattern-match error.
  radix k _ = error ("Data.Vector.Algorithms.Radix.radix (Int64): pass "
                     ++ show k ++ " is outside [0, 7]")
  {-# INLINE radix #-}
120
-- | Machine-sized 'Word': one 8-bit pass per byte reported by 'sizeOf'.
instance Radix Word where
  {-# INLINE passes #-}
  passes _ = sizeOf (undefined :: Word)

  {-# INLINE size #-}
  size _ = 256

  {-# INLINE radix #-}
  -- Byte number @pass@ of the value; @pass `shiftL` 3@ is @pass * 8@, and
  -- for pass 0 the shift is by zero, which leaves the value untouched.
  radix pass w = fromIntegral ((w `shiftR` (pass `shiftL` 3)) .&. 255)
129
-- | 'Word8' is sorted in a single pass; the value itself is the digit.
instance Radix Word8 where
  {-# INLINE passes #-}
  passes _ = 1

  {-# INLINE size #-}
  size _ = 256

  {-# INLINE radix #-}
  -- The pass number is ignored; the byte is simply widened to 'Int'.
  radix _ = fromIntegral
137
-- | 'Word16' is sorted in two 8-bit passes, least significant byte first.
instance Radix Word16 where
  passes _ = 2
  {-# INLINE passes #-}
  size   _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  -- Makes the function total: an out-of-range pass reports what went wrong
  -- instead of failing with an anonymous pattern-match error.
  radix k _ = error ("Data.Vector.Algorithms.Radix.radix (Word16): pass "
                     ++ show k ++ " is outside [0, 1]")
  {-# INLINE radix #-}
146
-- | 'Word32' is sorted in four 8-bit passes, least significant byte first.
instance Radix Word32 where
  passes _ = 4
  {-# INLINE passes #-}
  size   _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  -- Makes the function total: an out-of-range pass reports what went wrong
  -- instead of failing with an anonymous pattern-match error.
  radix k _ = error ("Data.Vector.Algorithms.Radix.radix (Word32): pass "
                     ++ show k ++ " is outside [0, 3]")
  {-# INLINE radix #-}
157
-- | 'Word64' is sorted in eight 8-bit passes, least significant byte first.
instance Radix Word64 where
  passes _ = 8
  {-# INLINE passes #-}
  size   _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  radix 4 e = fromIntegral ((e `shiftR` 32) .&. 255)
  radix 5 e = fromIntegral ((e `shiftR` 40) .&. 255)
  radix 6 e = fromIntegral ((e `shiftR` 48) .&. 255)
  radix 7 e = fromIntegral ((e `shiftR` 56) .&. 255)
  -- Makes the function total: an out-of-range pass reports what went wrong
  -- instead of failing with an anonymous pattern-match error.
  radix k _ = error ("Data.Vector.Algorithms.Radix.radix (Word64): pass "
                     ++ show k ++ " is outside [0, 7]")
  {-# INLINE radix #-}
172
-- | Pairs are sorted by running every pass of the second component first,
-- followed by every pass of the first component.
instance (Radix i, Radix j) => Radix (i, j) where
  {-# INLINE passes #-}
  -- The pair is only taken apart lazily (via 'fst' / 'snd'), so an
  -- undefined pair is fine as long as the component instances ignore
  -- their argument.
  passes pair = passes (fst pair) + passes (snd pair)

  {-# INLINE size #-}
  size pair = max (size (fst pair)) (size (snd pair))

  {-# INLINE radix #-}
  radix k pair
    | k < sndPasses = radix k (snd pair)
    | otherwise     = radix (k - sndPasses) (fst pair)
   where
    -- Number of leading passes that belong to the second component.
    sndPasses = passes (snd pair)
181
-- | Sorts an array, taking the number of passes, the auxiliary array size
-- and the radix function from the element type's 'Radix' instance.
sort :: forall e m v. (PrimMonad m, MVector v e, Radix e)
     => v (PrimState m) e -> m ()
sort arr = sortBy nPasses auxSize radix arr
 where
  -- Only there to select the 'Radix' instance for 'passes' and 'size'.
  witness :: e
  witness = undefined
  nPasses = passes witness
  auxSize = size witness
{-# INLINABLE sort #-}
190
-- | Radix sorts an array using custom radix information. It requires:
--
--   * the number of passes needed to fully sort the array,
--
--   * the size of the auxiliary arrays (this should be one greater than the
--     maximum value returned by the radix function), and
--
--   * a radix function, which takes the pass and an element and returns
--     the relevant radix.
sortBy :: (PrimMonad m, MVector v e)
       => Int               -- ^ the number of passes
       -> Int               -- ^ the size of auxiliary arrays
       -> (Int -> e -> Int) -- ^ the radix function
       -> v (PrimState m) e -- ^ the array to be sorted
       -> m ()
sortBy nPasses auxSize rdx arr =
  -- Scratch array of the same length as the input, allocated first ...
  new (length arr) >>= \scratch ->
    -- ... then the per-digit count array, and finally the passes proper.
    new auxSize >>= \tallies ->
      radixLoop nPasses rdx arr scratch tallies
{-# INLINE sortBy #-}
208
-- Runs passes @0 .. nPasses - 1@ of the sort, alternating the roles of the
-- two arrays: even passes move elements from the array being sorted into
-- the temporary array, odd passes move them back. After an odd number of
-- passes the latest result sits in the temporary array and is copied back.
radixLoop :: (PrimMonad m, MVector v e)
          => Int                          -- passes
          -> (Int -> e -> Int)            -- radix function
          -> v (PrimState m) e            -- array to sort
          -> v (PrimState m) e            -- temporary array
          -> PV.MVector (PrimState m) Int -- radix count array
          -> m ()
radixLoop nPasses rdx src dst count = go False 0
 where
  -- @swapped@ is True when the most recent pass wrote into @dst@.
  go swapped k
    | k >= nPasses = when swapped (unsafeCopy src dst)
    | otherwise    = pass >> go (not swapped) (k + 1)
   where
    pass
      | swapped   = body rdx dst src count k
      | otherwise = body rdx src dst count k
{-# INLINE radixLoop #-}
225
-- One pass of the sort, in three steps: 'countLoop' over the source with
-- this pass's radix function (presumably tallying digits into the count
-- array -- it is defined in another module, confirm there), 'accumulate'
-- on the count array, then 'moveLoop' from source to destination.
body :: (PrimMonad m, MVector v e)
     => (Int -> e -> Int)            -- radix function
     -> v (PrimState m) e            -- source array
     -> v (PrimState m) e            -- destination array
     -> PV.MVector (PrimState m) Int -- radix count
     -> Int                          -- current pass
     -> m ()
body rdx src dst count k =
  countLoop digit src count
    >> accumulate count
    >> moveLoop k rdx src dst count
 where
  -- The radix function specialised to the current pass.
  digit = rdx k
{-# INLINE body #-}
238
-- Replaces every entry of the count array, in place, with the sum of all
-- entries that preceded it (so the first entry becomes 0).
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int -> m ()
accumulate count = loop 0 0
 where
  n = length count
  -- @running@ is the total of the entries already visited.
  loop ix running = when (ix < n) $ do
    here <- unsafeRead count ix
    unsafeWrite count ix running
    loop (ix + 1) (running + here)
{-# INLINE accumulate #-}
250
-- Walks the source array front to back and writes each element into the
-- destination at the index that 'inc' yields for the element's digit.
-- ('inc' is defined in another module; presumably it returns the stored
-- offset and bumps it by one -- confirm there.)
moveLoop :: (PrimMonad m, MVector v e)
         => Int -> (Int -> e -> Int) -> v (PrimState m) e
         -> v (PrimState m) e -> PV.MVector (PrimState m) Int -> m ()
moveLoop k rdx src dst prefix = loop 0
 where
  n = length src
  -- The radix function specialised to the current pass.
  digit = rdx k
  loop ix = when (ix < n) $ do
    x    <- unsafeRead src ix
    slot <- inc prefix (digit x)
    unsafeWrite dst slot x
    loop (ix + 1)
{-# INLINE moveLoop #-}
264
265