{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- ---------------------------------------------------------------------------
-- |
-- Module      : Data.Vector.Algorithms.Radix
-- Copyright   : (c) 2008-2011 Dan Doel
-- Maintainer  : Dan Doel <dan.doel@gmail.com>
-- Stability   : Experimental
-- Portability : Non-portable (scoped type variables, bang patterns)
--
-- This module provides a radix sort for a subclass of unboxed arrays. The
-- radix class gives information on:
--
--   * the number of passes needed for the data type
--
--   * the size of the auxiliary arrays
--
--   * how to compute the pass-k radix of a value
--
-- Radix sort is not a comparison sort, so it is able to achieve O(n) run
-- time, though it also uses O(n) auxiliary space. In addition, there is a
-- constant space overhead of 2*size*sizeOf(Int) for the sort, so it is not
-- advisable to use this sort for large numbers of very small arrays.
--
-- A standard example (upon which one could base their own Radix instance)
-- is Word32:
--
--   * We choose to sort on r = 8 bits at a time
--
--   * A Word32 has b = 32 bits total
--
-- Thus, b/r = 4 passes are required, 2^r = 256 elements are needed in an
-- auxiliary array, and the radix function is:
--
-- > radix k e = (e `shiftR` (k*8)) .&. 255

module Data.Vector.Algorithms.Radix (sort, sortBy, Radix(..)) where

import Prelude hiding (read, length)

import Control.Monad
import Control.Monad.Primitive

import qualified Data.Vector.Primitive.Mutable as PV
import Data.Vector.Generic.Mutable

import Data.Vector.Algorithms.Common

import Data.Bits
import Data.Int
import Data.Word


import Foreign.Storable

-- | Describes how to radix sort values of type @e@. The element argument of
-- 'passes' and 'size' only selects the instance: 'sort' calls both with
-- 'undefined', so implementations must not force it.
class Radix e where
  -- | The number of passes necessary to sort an array of es
  passes :: e -> Int
  -- | The size of an auxiliary array
  size   :: e -> Int
  -- | The radix function parameterized by the current pass
  radix  :: Int -> e -> Int

-- | One pass per byte of the machine 'Int', least significant byte first.
instance Radix Int where
  passes _ = sizeOf (undefined :: Int)
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = e .&. 255
  radix i e
    -- On the most significant byte the sign bit is flipped, so that negative
    -- values receive smaller radices than non-negative ones.
    | i == passes e - 1 = byteAt (e `xor` minBound)
    | otherwise         = byteAt e
    where
      -- Byte number i of x; (i `shiftL` 3) is i * 8 bits.
      byteAt x = (x `shiftR` (i `shiftL` 3)) .&. 255
  {-# INLINE radix #-}

-- | A single pass over the only byte, with the sign bit flipped.
instance Radix Int8 where
  passes _ = 1
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- Parenthesised exactly as operator precedence already groups it:
  -- (.&.) binds tighter than `xor`.
  radix _ e = (255 .&. fromIntegral e) `xor` 128
  {-# INLINE radix #-}

-- | Two byte-wide passes; the sign bit is flipped on the high byte.
instance Radix Int16 where
  passes _ = 2
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral (((e `xor` minBound) `shiftR` 8) .&. 255)
  {-# INLINE radix #-}

-- | Four byte-wide passes; the sign bit is flipped on the high byte.
instance Radix Int32 where
  passes _ = 4
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral (((e `xor` minBound) `shiftR` 24) .&. 255)
  {-# INLINE radix #-}

-- | Eight byte-wide passes; the sign bit is flipped on the high byte.
instance Radix Int64 where
  passes _ = 8
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  radix 4 e = fromIntegral ((e `shiftR` 32) .&. 255)
  radix 5 e = fromIntegral ((e `shiftR` 40) .&. 255)
  radix 6 e = fromIntegral ((e `shiftR` 48) .&. 255)
  radix 7 e = fromIntegral (((e `xor` minBound) `shiftR` 56) .&. 255)
  {-# INLINE radix #-}

-- | One pass per byte of the machine 'Word', least significant byte first.
instance Radix Word where
  passes _ = sizeOf (undefined :: Word)
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  -- (i `shiftL` 3) is i * 8 bits.
  radix i e = fromIntegral ((e `shiftR` (i `shiftL` 3)) .&. 255)
  {-# INLINE radix #-}

-- | A single pass; the value itself is the radix.
instance Radix Word8 where
  passes _ = 1
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix _ = fromIntegral
  {-# INLINE radix #-}

-- | Two byte-wide passes, least significant byte first.
instance Radix Word16 where
  passes _ = 2
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  {-# INLINE radix #-}

-- | Four byte-wide passes, least significant byte first.
instance Radix Word32 where
  passes _ = 4
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  {-# INLINE radix #-}

-- | Eight byte-wide passes, least significant byte first.
instance Radix Word64 where
  passes _ = 8
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  radix 4 e = fromIntegral ((e `shiftR` 32) .&. 255)
  radix 5 e = fromIntegral ((e `shiftR` 40) .&. 255)
  radix 6 e = fromIntegral ((e `shiftR` 48) .&. 255)
  radix 7 e = fromIntegral ((e `shiftR` 56) .&. 255)
  {-# INLINE radix #-}

-- | Pairs run every pass of the second component first and then every pass
-- of the first component, which makes the first component the more
-- significant one.
--
-- The patterns are lazy because 'sort' calls 'passes' and 'size' with
-- 'undefined'; a strict tuple pattern would force it.
instance (Radix i, Radix j) => Radix (i, j) where
  passes ~(i, j) = passes i + passes j
  {-# INLINE passes #-}
  size ~(i, j) = size i `max` size j
  {-# INLINE size #-}
  radix k ~(i, j)
    | k < passes j = radix k j
    | otherwise    = radix (k - passes j) i
  {-# INLINE radix #-}

-- | Sorts an array based on the Radix instance.
sort :: forall e m v. (PrimMonad m, MVector v e, Radix e)
     => v (PrimState m) e -> m ()
sort arr = sortBy (passes e) (size e) radix arr
  where
    -- Never evaluated: it only fixes the element type whose 'Radix'
    -- instance supplies 'passes' and 'size'.
    e :: e
    e = undefined
{-# INLINABLE sort #-}

-- | Radix sorts an array using custom radix information
-- requires the number of passes to fully sort the array,
-- the size of the auxiliary arrays necessary (should be
-- one greater than the maximum value returned by the radix
-- function), and a radix function, which takes the pass
-- and an element, and returns the relevant radix.
sortBy :: (PrimMonad m, MVector v e)
       => Int               -- ^ the number of passes
       -> Int               -- ^ the size of auxiliary arrays
       -> (Int -> e -> Int) -- ^ the radix function
       -> v (PrimState m) e -- ^ the array to be sorted
       -> m ()
sortBy passes size rdx arr = do
  tmp   <- new (length arr)  -- scratch buffer, same length as the input
  count <- new size          -- one counter per possible radix value
  radixLoop passes rdx arr tmp count
{-# INLINE sortBy #-}

-- Runs passes 0 .. passes-1, alternating the direction of the move between
-- the array being sorted and the temporary array on every pass.
radixLoop :: (PrimMonad m, MVector v e)
          => Int                          -- passes
          -> (Int -> e -> Int)            -- radix function
          -> v (PrimState m) e            -- array to sort
          -> v (PrimState m) e            -- temporary array
          -> PV.MVector (PrimState m) Int -- radix count array
          -> m ()
radixLoop passes rdx src dst count = go False 0
  where
    -- swap is True exactly when the most recent pass wrote into dst.
    go swap k
      | k < passes = do
          if swap
            then body rdx dst src count k
            else body rdx src dst count k
          go (not swap) (k + 1)
      -- The last pass left the result in dst, so copy it back into src
      -- (unsafeCopy takes the target first).
      | otherwise = when swap (unsafeCopy src dst)
{-# INLINE radixLoop #-}

-- A single pass: count the radices, turn the counts into start offsets, then
-- move every element to its slot in the destination. countLoop is defined
-- outside this section (presumably in Data.Vector.Algorithms.Common).
body :: (PrimMonad m, MVector v e)
     => (Int -> e -> Int)            -- radix function
     -> v (PrimState m) e            -- source array
     -> v (PrimState m) e            -- destination array
     -> PV.MVector (PrimState m) Int -- radix count
     -> Int                          -- current pass
     -> m ()
body rdx src dst count k = do
  countLoop (rdx k) src count
  accumulate count
  moveLoop k rdx src dst count
{-# INLINE body #-}

-- Replaces every entry of the count array by the sum of all entries before
-- it, so entry i ends up holding the running total of entries 0 .. i-1.
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int -> m ()
accumulate count = go 0 0
  where
    len = length count
    go i acc
      | i < len = do
          ci <- unsafeRead count i
          unsafeWrite count i acc
          go (i + 1) (acc + ci)
      | otherwise = return ()
{-# INLINE accumulate #-}

-- Walks the source front to back and writes each element into the
-- destination at the index produced by inc for that element's radix.
-- NOTE(review): inc is defined outside this section; presumably it returns
-- the slot's current value and then increments it -- confirm.
moveLoop :: (PrimMonad m, MVector v e)
         => Int -> (Int -> e -> Int) -> v (PrimState m) e
         -> v (PrimState m) e -> PV.MVector (PrimState m) Int -> m ()
moveLoop k rdx src dst prefix = go 0
  where
    len = length src
    go i
      | i < len = do
          srci <- unsafeRead src i
          pf   <- inc prefix (rdx k srci)
          unsafeWrite dst pf srci
          go (i + 1)
      | otherwise = return ()
{-# INLINE moveLoop #-}