1{-# LANGUAGE ForeignFunctionInterface #-}
2{-# LANGUAGE RecordWildCards #-}
3{-# LANGUAGE BangPatterns #-}
4{-# LANGUAGE FlexibleContexts, CPP #-}
5
6module BinaryHeapSTM (
7    Entry
8  , newEntry
9  , renewEntry
10  , item
11  , PriorityQueue(..)
12  , new
13  , enqueue
14  , dequeue
15  , delete
16  ) where
17
18#if __GLASGOW_HASKELL__ < 709
19import Data.Word (Word)
20#endif
21import Control.Concurrent.STM
22import Control.Monad (when, void)
23import Data.Array (Array, listArray, (!))
24import Data.Array.MArray (newArray_, readArray, writeArray)
25
26----------------------------------------------------------------
27
-- | Scheduling weight of an entry. 'weightToDeficit' only accepts 1..256.
type Weight = Int
-- | Per-entry ordering key. It is advanced by 'weightToDeficit' on each
-- 'enqueue'; entries with a smaller deficit are dequeued first.
type Deficit = Word
30
-- | Abstract data type of entries for priority queues.
data Entry a = Entry {
    weight  :: {-# UNPACK #-} !Weight
    -- ^ Scheduling weight, fixed when the entry is created ('newEntry').
  , item    :: {-# UNPACK #-} !(TVar a) -- ^ Extracting an item from an entry.
  , deficit :: {-# UNPACK #-} !(TVar Deficit)
    -- ^ Heap ordering key; starts at 'magicDeficit' and is advanced by
    -- 'enqueue'.
  , index   :: {-# UNPACK #-} !(TVar Index)
    -- ^ Slot the entry occupies in a queue's array (-1 for a fresh entry).
    -- Kept up to date by 'write'; it is not reset when the entry is removed.
  }
38
-- | Creating a fresh entry for an item with the given weight.
--
-- The deficit starts at 'magicDeficit', so the first 'enqueue' bases it on
-- the queue's current base deficit. The index starts at -1 because the
-- entry is not stored in any queue yet.
newEntry :: a -> Weight -> STM (Entry a)
newEntry x w = do
    itemVar    <- newTVar x
    deficitVar <- newTVar magicDeficit
    indexVar   <- newTVar (-1)
    return $ Entry w itemVar deficitVar indexVar
41
-- | Changing the item of an entry.
--
-- Only the item is replaced. The weight, deficit and queue position of the
-- entry are left untouched.
renewEntry :: Entry a -> a -> STM ()
renewEntry ent = writeTVar (item ent)
45
46----------------------------------------------------------------
47
-- | Slot number in a queue's backing array. Slots are numbered from 1;
-- 'newEntry' uses -1 for "not in any queue".
type Index = Int
-- | Mutable array backing the binary heap. Slot @i@ has children @2*i@ and
-- @2*i+1@, and parent @i `div` 2@.
type MA a = TArray Index (Entry a)
50
-- | A binary min-heap of entries ordered by deficit, kept in STM.
--
-- Fields, in order:
--
--   * the base deficit: the deficit of the most recently dequeued entry, or
--     0 initially and whenever the queue becomes empty;
--   * the next free slot: occupied slots are @1 .. next-1@;
--   * the backing array.
--
-- FIXME: The base (a 'Deficit', i.e. a 'Word') can overflow.
--        In that case, the heap must be re-constructed.
data PriorityQueue a = PriorityQueue (TVar Deficit)
                                     (TVar Index)
                                     (MA a)
56
57----------------------------------------------------------------
58
-- | Deficit given to a freshly created entry. When 'enqueue' sees this
-- value, it starts from the queue's base deficit instead.
magicDeficit :: Deficit
magicDeficit = minBound
61
-- | Scale for turning weights into deficit increments: an entry of weight
-- @w@ advances by about @deficitSteps / w@ per enqueue (see 'deficitList').
deficitSteps :: Int
deficitSteps = 2 ^ (16 :: Int)
64
-- | 'deficitSteps' as a 'Word', for direct arithmetic on deficits.
deficitStepsW :: Word
deficitStepsW = toEnum deficitSteps
67
-- | Deficit increment for each weight 1..256, i.e. @deficitSteps / w@
-- rounded to the nearest integer (in weight order).
deficitList :: [Deficit]
deficitList = [round (steps / w) | w <- [1 .. 256]]
  where
    steps :: Double
    steps = fromIntegral deficitSteps
73
-- | Lookup table from weight (1..256) to its deficit increment.
deficitTable :: Array Index Deficit
deficitTable = listArray weightRange deficitList
  where
    weightRange = (1, 256)
76
-- | Deficit increment for one enqueue of an entry with the given weight.
-- The weight must lie in 1..256 (the bounds of 'deficitTable'); any other
-- value makes the array lookup fail.
weightToDeficit :: Weight -> Deficit
weightToDeficit = (deficitTable !)
79
80----------------------------------------------------------------
81
-- | Creating an empty priority queue with room for @n@ entries.
--
-- The base deficit starts at 0 and the next free slot at 1. Slots are
-- numbered from 1 so that parent/child slots are @i `div` 2@ and @2*i@.
new :: Int -> STM (PriorityQueue a)
new n = do
    baseVar <- newTVar 0
    nextVar <- newTVar 1
    slots   <- newArray_ (1, n)
    return $ PriorityQueue baseVar nextVar slots
86
-- | Enqueuing an entry. PriorityQueue is updated.
--
-- The entry's deficit is advanced by the increment for its weight. It
-- starts from the queue's base deficit if the entry still carries
-- 'magicDeficit', otherwise from its own previous deficit. The entry is
-- then stored in the next free slot and sifted towards the root.
enqueue :: Entry a -> PriorityQueue a -> STM ()
enqueue ent (PriorityQueue bref idx arr) = do
    slot <- readTVar idx
    base <- readTVar bref
    prev <- readTVar (deficit ent)
    let !start    = if prev == magicDeficit then base else prev
        !advanced = start + weightToDeficit (weight ent)
    writeTVar (deficit ent) advanced
    write arr slot ent
    shiftUp arr slot
    writeTVar idx $! slot + 1
101
-- | Dequeuing an entry. PriorityQueue is updated.
--
-- The root (smallest deficit) is removed, the last element takes its place
-- and is sifted down. The base deficit becomes the removed entry's deficit,
-- or 0 when the queue is now empty. An empty queue is not guarded against:
-- 'shrink' then reads slot 0, which is outside the array.
dequeue :: PriorityQueue a -> STM (Entry a)
dequeue (PriorityQueue bref idx arr) = do
    top <- shrink arr 1 idx
    next <- readTVar idx
    -- 'next' is one past the last occupied slot, so shiftDown may inspect
    -- that slot as a child. It still holds a copy of the element being
    -- sifted, so it never causes a swap with a foreign element.
    shiftDown arr 1 next
    topDeficit <- readTVar (deficit top)
    writeTVar bref (if next == 1 then 0 else topDeficit)
    pure top
111
-- | Removing the element at slot @r@ and returning it.
--
-- The last element is moved into slot @r@ and the size counter is
-- decremented. Heap order is NOT restored here; callers sift afterwards.
shrink :: MA a -> Index -> TVar Index -> STM (Entry a)
shrink arr r idx = do
    removed <- readArray arr r
    next <- readTVar idx
    -- fixme: checking if i == 0 (on an empty queue lastSlot is 0, which is
    -- outside the array).
    let lastSlot = next - 1
    readArray arr lastSlot >>= write arr r
    writeTVar idx lastSlot
    return removed
121
-- | Sifting the element at slot @c@ towards the root for as long as 'swap'
-- exchanges it with its parent. Stops at the root (slot 1).
shiftUp :: MA a -> Int -> STM ()
shiftUp arr = go
  where
    go 1 = return ()
    go c = do
        let parent = c `div` 2
        moved <- swap arr parent c
        when moved $ go parent
129
-- | Sifting the element at slot @p@ towards the leaves while a child has a
-- smaller deficit. Child slots greater than @n@ are never inspected.
shiftDown :: MA a -> Int -> Int -> STM ()
shiftDown arr p n
  -- No child within range: stop.
  | c1 > n    = return ()
  -- Exactly one child within range: one conditional swap finishes the job.
  | c1 == n   = void $ swap arr p c1
  | otherwise = do
      let !c2 = c1 + 1
      xc1 <- readArray arr c1
      xc2 <- readArray arr c2
      d1 <- readTVar $ deficit xc1
      d2 <- readTVar $ deficit xc2
      -- Pick the child to compare with the parent. The subtraction is on
      -- 'Word' and wraps when d2 < d1. So c1 is picked only when d2 is
      -- strictly ahead of d1 by at most deficitStepsW; otherwise (including
      -- ties) c2 is picked.
      -- NOTE(review): this wrap-aware test seems to assume sibling deficits
      -- lie within deficitStepsW of each other, while 'swap' uses a plain
      -- '<'. Confirm the two orderings cannot disagree.
      let !c = if d1 /= d2 && d2 - d1 <= deficitStepsW then c1 else c2
      swapped <- swap arr p c
      when swapped $ shiftDown arr c n
  where
    -- Left child of p; the right child is c1 + 1.
    c1 = 2 * p
145
{-# INLINE swap #-}
-- | Exchanging parent slot @p@ and child slot @c@ when the child's deficit
-- is strictly smaller. Returns whether the exchange happened.
swap :: MA a -> Index -> Index -> STM Bool
swap arr p c = do
    parent <- readArray arr p
    child  <- readArray arr c
    pd <- readTVar (deficit parent)
    cd <- readTVar (deficit child)
    let needed = cd < pd
    when needed $ do
        write arr c parent
        write arr p child
    return needed
159
{-# INLINE write #-}
-- | Storing an entry in slot @i@ and recording @i@ in the entry's own
-- 'index', so the entry can later be located directly (as 'delete' does).
write :: MA a -> Index -> Entry a -> STM ()
write arr i ent = writeArray arr i ent >> writeTVar (index ent) i
165
-- | Deleting an entry from the queue. PriorityQueue is updated.
--
-- The entry is located through its own 'index', so it must currently be
-- stored in this queue. 'index' is only written by 'write' and is not
-- reset when an entry is removed, so a stale value would remove some other
-- slot.
delete :: Entry a -> PriorityQueue a -> STM ()
delete ent pq@(PriorityQueue _ idx arr) = do
    i <- readTVar $ index ent
    if i == 1 then
        -- Removing the root is exactly a dequeue (which also updates the base).
        void $ dequeue pq
      else do
        -- Move the last element into slot i and shrink the heap by one.
        _ <- shrink arr i idx
        -- The heap now occupies slots 1 .. n-1. The bound for shiftDown must
        -- come from the queue size, not from i: with i - 1 as the bound,
        -- 2*i > i - 1 always held and the sift-down never ran.
        n <- readTVar idx
        -- If the deleted entry was the last element, nothing was moved into
        -- the heap and it is already valid. Otherwise the moved element may
        -- be too large for its new children or too small for its new parent,
        -- so sift in both directions; normally only one of them moves it.
        when (i < n) $ do
            shiftDown arr i (n - 1)
            shiftUp arr i
176