1{-# LANGUAGE BangPatterns #-}
2{-# LANGUAGE TypeFamilies #-}
3
4-- ---------------------------------------------------------------------------
5-- |
6-- Module      : Data.Vector.Algorithms.Merge
7-- Copyright   : (c) 2008-2011 Dan Doel
8-- Maintainer  : Dan Doel <dan.doel@gmail.com>
9-- Stability   : Experimental
10-- Portability : Portable
11--
-- This module implements a simple top-down merge sort. The temporary buffer
-- is preallocated once, at half the size of the input array (rounded up), and
-- is shared by every merge step to reduce the total amount of allocation
-- performed. This is a stable sort.
16
17module Data.Vector.Algorithms.Merge
18       ( sort
19       , sortBy
20       , Comparison
21       ) where
22
23import Prelude hiding (read, length)
24
25import Control.Monad.Primitive
26
27import Data.Bits
28import Data.Vector.Generic.Mutable
29
30import Data.Vector.Algorithms.Common (Comparison, copyOffset, midPoint)
31
32import qualified Data.Vector.Algorithms.Optimal   as O
33import qualified Data.Vector.Algorithms.Insertion as I
34
-- | Sorts a mutable vector in place into ascending order, using the
-- element type's 'Ord' instance ('compare'). The sort is stable.
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort vec = sortBy compare vec
{-# INLINABLE sort #-}
39
-- | Sorts a mutable vector in place using a custom comparison.
--
-- The strategy depends on the length of the vector:
--
--   * at most four elements: the fixed-size sorts from
--     "Data.Vector.Algorithms.Optimal";
--
--   * fewer than 'threshold' elements: insertion sort;
--
--   * otherwise: merge sort, with one scratch buffer allocated up front
--     and reused for every merge.
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp vec
  | n < 2         = return ()
  | n == 2        = O.sort2ByOffset cmp vec 0
  | n == 3        = O.sort3ByOffset cmp vec 0
  | n == 4        = O.sort4ByOffset cmp vec 0
  | n < threshold = I.sortByBounds cmp vec 0 n
  | otherwise     = new scratchLen >>= mergeSortWithBuf cmp vec
 where
  n = length vec
  -- The buffer only ever receives the lower run of a split. Rounding up
  -- (ceiling rather than floor) keeps it big enough even when that run is
  -- the larger half of an odd-length slice.
  scratchLen = (n + 1) `div` 2
{-# INLINE sortBy #-}
59
-- | Top-down merge sort of all of @src@, using @buf@ as scratch space for
-- merging.
--
-- A slice shorter than 'threshold' is finished with insertion sort. A longer
-- slice is split at its midpoint, each half is sorted recursively, and the
-- two sorted runs are then merged back into @src@.
--
-- NOTE(review): @buf@ must be able to hold the lower run of every split
-- (@mid - lo@ elements), because 'merge' copies that run into it unchecked.
mergeSortWithBuf :: (PrimMonad m, MVector v e)
                 => Comparison e -> v (PrimState m) e -> v (PrimState m) e -> m ()
mergeSortWithBuf cmp src buf = go 0 (length src)
 where
  go lo hi
    | size < threshold = I.sortByBounds cmp src lo hi
    | otherwise = do
        go lo mid
        go mid hi
        merge cmp (unsafeSlice lo size src) buf (mid - lo)
   where
    size = hi - lo
    mid  = midPoint hi lo
{-# INLINE mergeSortWithBuf #-}
72
-- | Merges two adjacent sorted runs of @src@ in place. The first run covers
-- indices @[0, mid)@ and the second covers @[mid, length src)@.
--
-- The lower run is first copied into the front of @buf@. The merged result
-- is then written back into @src@ starting at index 0. On ties, the element
-- from the lower run is written first, which keeps the merge stable.
--
-- Preconditions (not checked, since all accesses are unsafe): @mid >= 1@;
-- @mid < length src@, so that both runs are non-empty; and
-- @length buf >= mid@.
merge :: (PrimMonad m, MVector v e)
      => Comparison e -> v (PrimState m) e -> v (PrimState m) e
      -> Int -> m ()
merge cmp src buf mid = do
    unsafeCopy saved lower
    x0 <- unsafeRead saved 0
    y0 <- unsafeRead upper 0
    go 0 x0 0 y0 0
 where
  lower = unsafeSlice 0   mid                src
  upper = unsafeSlice mid (length src - mid) src
  -- Scratch copy of the lower run, so its slots in src may be overwritten.
  saved = unsafeSlice 0   mid                buf
  nLow  = length saved
  nHigh = length upper

  -- i, x : index and current element of the saved lower run
  -- j, y : index and current element of the upper run
  -- k    : next output slot in src
  -- While the lower run still has elements, k = i + j < mid + j. A write at
  -- k therefore never clobbers an upper element that has not been read yet.
  go !i !x !j !y !k = case cmp y x of
    LT -> do
      unsafeWrite src k y
      let j' = j + 1
      if j' >= nHigh
        then -- Upper run exhausted: the rest of the lower run goes at the end.
             unsafeCopy (unsafeSlice (k + 1) (nLow - i) src)
                        (unsafeSlice i (nLow - i) saved)
        else unsafeRead upper j' >>= \y' -> go i x j' y' (k + 1)
    _ -> do
      unsafeWrite src k x
      let i' = i + 1
      if i' >= nLow
        then -- Lower run exhausted: the remaining upper elements already sit
             -- at their final positions (k + 1 == mid + j).
             return ()
        else unsafeRead saved i' >>= \x' -> go i' x' j y (k + 1)
{-# INLINE merge #-}
102
{-# INLINE threshold #-}
-- | Length cutoff for insertion sort. Vectors (in 'sortBy') and slices (in
-- 'mergeSortWithBuf') shorter than this are insertion-sorted rather than
-- split and merged any further.
threshold :: Int
threshold = 25
106