1{-# LANGUAGE BangPatterns               #-}
2{-# LANGUAGE DeriveFoldable             #-}
3{-# LANGUAGE DeriveFunctor              #-}
4{-# LANGUAGE DeriveTraversable          #-}
5{-# LANGUAGE GeneralizedNewtypeDeriving #-}
6{-# LANGUAGE ScopedTypeVariables        #-}
7module Data.HashPSQ.Internal
8    ( -- * Type
9      Bucket (..)
10    , mkBucket
11    , HashPSQ (..)
12
13      -- * Query
14    , null
15    , size
16    , member
17    , lookup
18    , findMin
19
20      -- * Construction
21    , empty
22    , singleton
23
24      -- * Insertion
25    , insert
26
27      -- * Delete/update
28    , delete
29    , deleteMin
30    , alter
31    , alterMin
32
33      -- * Lists
34    , fromList
35    , toList
36    , keys
37
38      -- * Views
39    , insertView
40    , deleteView
41    , minView
42    , atMostView
43
44      -- * Traversal
45    , map
46    , unsafeMapMonotonic
47    , fold'
48
49      -- * Unsafe operations
50    , unsafeLookupIncreasePriority
51    , unsafeInsertIncreasePriority
52    , unsafeInsertIncreasePriorityView
53
54      -- * Validity check
55    , valid
56    ) where
57
58import           Control.DeepSeq      (NFData (..))
59import           Data.Foldable        (Foldable)
60import           Data.Hashable
61import qualified Data.List            as List
62import           Data.Maybe           (isJust)
63import           Data.Traversable
64import           Prelude              hiding (foldr, lookup, map, null)
65
66import qualified Data.IntPSQ.Internal as IntPSQ
67import qualified Data.OrdPSQ          as OrdPSQ
68
69------------------------------------------------------------------------------
70-- Types
71------------------------------------------------------------------------------
72
-- | A collision bucket. One key and value are stored directly in the
-- constructor; any further entries live in the attached 'OrdPSQ.OrdPSQ'.
-- The priority belonging to the directly stored key is not part of the
-- bucket: it is kept alongside the bucket by whoever holds it (see
-- 'mkBucket', which returns it as the first component of its result).
data Bucket k p v
    = B !k !v !(OrdPSQ.OrdPSQ k p v)
    deriving (Show, Functor, Foldable, Traversable)
75
-- | Smart constructor for a 'Bucket'. The given key, priority and value are
-- first inserted into the supplied 'OrdPSQ.OrdPSQ'; the minimum element of
-- the result is then pulled out and placed directly in the 'Bucket', and its
-- priority is returned next to the bucket.
{-# INLINABLE mkBucket #-}
mkBucket
    :: (Ord k, Ord p)
    => k -> p -> v -> OrdPSQ.OrdPSQ k p v -> (p, Bucket k p v)
mkBucket k p x opsq =
    -- TODO (jaspervdj): We could do an 'unsafeInsertNew' here for all call
    -- sites.
    --
    -- The queue we hand to 'toBucket' has just had an element inserted, so
    -- the 'Nothing' case signals an internal error.
    maybe (error "mkBucket: internal error") id
        (toBucket (OrdPSQ.insert k p x opsq))
88
-- | Turn an 'OrdPSQ.OrdPSQ' into a 'Bucket' by extracting its minimum
-- element. Yields 'Nothing' when the queue has no minimum to extract;
-- otherwise the priority of the extracted element is returned together with
-- the bucket holding that element and the remaining queue.
toBucket :: (Ord k, Ord p) => OrdPSQ.OrdPSQ k p v -> Maybe (p, Bucket k p v)
toBucket opsq = do
    (minKey, minPrio, minVal, rest) <- OrdPSQ.minView opsq
    return (minPrio, B minKey minVal rest)
93
-- | Fully evaluates the directly stored key and value, then the overflow
-- queue.
instance (NFData k, NFData p, NFData v) => NFData (Bucket k p v) where
    rnf (B key val overflow) = rnf key `seq` (rnf val `seq` rnf overflow)
96
-- | A priority search queue with keys of type @k@, priorities of type @p@
-- and values of type @v@. It is strict in keys, priorities and values.
--
-- The representation is an 'IntPSQ.IntPSQ' whose entries are 'Bucket's; the
-- operations in this module address a bucket by the 'hash' of the key.
newtype HashPSQ k p v
    = HashPSQ (IntPSQ.IntPSQ p (Bucket k p v))
    deriving (Show, NFData, Functor, Foldable, Traversable)
101
-- | Two queues are equal when repeatedly taking 'minView' of both yields the
-- same key, priority and value at every step and both run out together.
instance (Eq k, Eq p, Eq v, Hashable k, Ord k, Ord p) =>
            Eq (HashPSQ k p v) where
    lhs == rhs = sameViews (minView lhs) (minView rhs)
      where
        sameViews Nothing Nothing = True
        sameViews (Just (lk, lp, lv, lrest)) (Just (rk, rp, rv, rrest)) =
            lk == rk && lp == rp && lv == rv && lrest == rrest
        -- Exactly one side is exhausted.
        sameViews _ _ = False
110
111
112------------------------------------------------------------------------------
113-- Query
114------------------------------------------------------------------------------
115
-- | /O(1)/ 'True' if the queue contains no elements.
{-# INLINABLE null #-}
null :: HashPSQ k p v -> Bool
null queue = case queue of
    HashPSQ inner -> IntPSQ.null inner
120
-- | /O(n)/ The number of elements stored in the PSQ. Every bucket
-- contributes one for its directly stored element plus the size of its
-- overflow queue.
{-# INLINABLE size #-}
size :: HashPSQ k p v -> Int
size (HashPSQ ipsq) = IntPSQ.fold' countBucket 0 ipsq
  where
    countBucket _ _ (B _ _ overflow) total =
        1 + OrdPSQ.size overflow + total
128
-- | /O(min(n,W))/ Check if a key is present in the queue.
{-# INLINABLE member #-}
member :: (Hashable k, Ord k, Ord p) => k -> HashPSQ k p v -> Bool
member k queue = isJust (lookup k queue)
133
-- | /O(min(n,W))/ The priority and value of a given key, or 'Nothing' if the
-- key is not bound.
{-# INLINABLE lookup #-}
lookup :: (Ord k, Hashable k, Ord p) => k -> HashPSQ k p v -> Maybe (p, v)
lookup k (HashPSQ ipsq) = case IntPSQ.lookup (hash k) ipsq of
    -- No bucket is filed under this hash.
    Nothing -> Nothing
    Just (headPrio, B headKey headVal overflow)
        -- The key is the one stored directly in the bucket.
        | headKey == k -> Just (headPrio, headVal)
        -- Otherwise it can only be among the bucket's other entries.
        | otherwise    -> OrdPSQ.lookup k overflow
143
-- | /O(1)/ The element with the lowest priority, or 'Nothing' when the
-- underlying 'IntPSQ.IntPSQ' has no minimum.
findMin :: (Hashable k, Ord k, Ord p) => HashPSQ k p v -> Maybe (k, p, v)
findMin (HashPSQ ipsq) = do
    (_, minPrio, B minKey minVal _) <- IntPSQ.findMin ipsq
    return (minKey, minPrio, minVal)
149
150
151--------------------------------------------------------------------------------
152-- Construction
153--------------------------------------------------------------------------------
154
-- | /O(1)/ The queue without any elements: a wrapper around the empty
-- 'IntPSQ.IntPSQ'.
empty :: HashPSQ k p v
empty = HashPSQ IntPSQ.empty
158
-- | /O(1)/ Build a queue holding exactly the given key, priority and value.
-- Implemented as an 'insert' into the empty queue.
singleton :: (Hashable k, Ord k, Ord p) => k -> p -> v -> HashPSQ k p v
singleton k p v = insert k p v (HashPSQ IntPSQ.empty)
162
163
164--------------------------------------------------------------------------------
165-- Insertion
166--------------------------------------------------------------------------------
167
-- | /O(min(n,W))/ Insert a new key, priority and value into the queue. If the
-- key is already present in the queue, the associated priority and value are
-- replaced with the supplied priority and value.
{-# INLINABLE insert #-}
insert
    :: (Ord k, Hashable k, Ord p)
    => k -> p -> v -> HashPSQ k p v -> HashPSQ k p v
insert k p v (HashPSQ ipsq) =
    case IntPSQ.alter (\slot -> ((), place slot)) (hash k) ipsq of
        ((), ipsq') -> HashPSQ ipsq'
  where
    -- Nothing is filed under this hash yet: start a fresh bucket.
    place Nothing = Just (p, B k v OrdPSQ.empty)
    place (Just (hp, B hk hv overflow))
        -- The key is the bucket head. Its new priority may be worse than
        -- that of an overflow entry, so let 'mkBucket' pick the new head.
        | hk == k   = Just (mkBucket k p v overflow)
        -- The current head still comes first; the new entry goes into the
        -- overflow queue (replacing an older entry for the same key there).
        | headStays = Just (hp, B hk hv (OrdPSQ.insert k p v overflow))
        -- The new entry becomes the head and the old head is demoted.
        | otherwise = Just (p, B k v (OrdPSQ.insert hk hp hv overflowSansK))
      where
        headStays = hp < p || (p == hp && hk < k)

        -- The key might already be present in the overflow queue, and we
        -- don't want to end up with duplicate keys.
        overflowSansK
            | OrdPSQ.member k overflow = OrdPSQ.delete k overflow
            | otherwise                = overflow
192
193
194--------------------------------------------------------------------------------
195-- Delete/update
196--------------------------------------------------------------------------------
197
-- | /O(min(n,W))/ Remove a key, together with its priority and value, from
-- the queue. If the key is not a member, the original queue is returned.
{-# INLINE delete #-}
delete
    :: (Hashable k, Ord k, Ord p) => k -> HashPSQ k p v -> HashPSQ k p v
delete k t = maybe t (\(_, _, t') -> t') (deleteView k t)
206
-- | /O(min(n,W))/ Drop the binding with the least priority and return the
-- remaining queue. When 'minView' finds no binding, the queue is returned
-- unchanged.
{-# INLINE deleteMin #-}
deleteMin
    :: (Hashable k, Ord k, Ord p) => HashPSQ k p v -> HashPSQ k p v
deleteMin t = maybe t (\(_, _, _, t') -> t') (minView t)
216
-- | /O(min(n,W))/ The expression @alter f k queue@ alters the value @x@ at
-- @k@, or absence thereof. 'alter' can be used to insert, delete, or update a
-- value in a queue. It also allows you to calculate an additional value @b@.
--
-- The bucket filed under @hash k@ (if any) is first removed from the
-- underlying queue, then rebuilt according to the result of @f@ and put back.
{-# INLINABLE alter #-}
alter :: (Hashable k, Ord k, Ord p)
      => (Maybe (p, v) -> (b, Maybe (p, v)))
      -> k -> HashPSQ k p v -> (b, HashPSQ k p v)
alter f k (HashPSQ ipsq) = case IntPSQ.deleteView h ipsq of
    -- No bucket for this hash: the key is absent.
    Nothing -> case f Nothing of
        (b, Nothing)     -> (b, HashPSQ ipsq)
        (b, Just (p, x)) -> (b, store p (B k x OrdPSQ.empty) ipsq)
    Just (headPrio, B headKey headVal overflow, rest) ->
        if k == headKey
            -- The key is the bucket head.
            then case f (Just (headPrio, headVal)) of
                -- Head deleted: promote the next overflow entry, if any.
                (b, Nothing) -> case toBucket overflow of
                    Nothing           -> (b, HashPSQ rest)
                    Just (p', bucket) -> (b, store p' bucket rest)
                -- Head updated: it may no longer be the bucket minimum.
                (b, Just (p, x)) -> case mkBucket k p x overflow of
                    (p', bucket) -> (b, store p' bucket rest)
            -- The key, if present at all, sits in the overflow queue.
            else case OrdPSQ.alter f k overflow of
                (b, overflow') ->
                    case mkBucket headKey headPrio headVal overflow' of
                        (p', bucket) -> (b, store p' bucket rest)
  where
    h = hash k

    -- File a bucket under the hash of @k@ in a queue from which that hash
    -- has been removed (or in which it was never present).
    store p bucket q = HashPSQ (IntPSQ.unsafeInsertNew h p bucket q)
244
-- | /O(min(n,W))/ A variant of 'alter' which works on the element with the
-- minimum priority. Unlike 'alter', this variant also allows you to change the
-- key of the element.
--
-- The minimum (if any) is taken out with 'minView', handed to @f@, and
-- whatever @f@ returns is put back with 'insert'.
{-# INLINABLE alterMin #-}
alterMin
    :: (Hashable k, Ord k, Ord p)
     => (Maybe (k, p, v) -> (b, Maybe (k, p, v)))
     -> HashPSQ k p v
     -> (b, HashPSQ k p v)
alterMin f t0 = case f mbMin of
    (b, mbNew) -> (b, maybe rest reinsert mbNew)
  where
    -- The queue without its minimum, and the minimum itself.
    (rest, mbMin) = case minView t0 of
        Nothing             -> (t0, Nothing)
        Just (k, p, x, t0') -> (t0', Just (k, p, x))

    reinsert (k, p, x) = insert k p x rest
261
262
263--------------------------------------------------------------------------------
264-- Lists
265--------------------------------------------------------------------------------
266
-- | /O(n*min(n,W))/ Build a queue from a list of (key, priority, value)
-- tuples by inserting them from left to right. If the list contains more than
-- one priority and value for the same key, the last priority and value for
-- the key is retained.
{-# INLINABLE fromList #-}
fromList :: (Hashable k, Ord k, Ord p) => [(k, p, v)] -> HashPSQ k p v
fromList entries = List.foldl' addEntry empty entries
  where
    addEntry queue (k, p, x) = insert k p x queue
273
-- | /O(n)/ Convert a queue to a list of (key, priority, value) tuples. The
-- order of the list is not specified.
{-# INLINABLE toList #-}
toList :: (Hashable k, Ord k, Ord p) => HashPSQ k p v -> [(k, p, v)]
toList (HashPSQ ipsq) = concatMap bucketEntries (IntPSQ.toList ipsq)
  where
    -- The directly stored element followed by the overflow entries.
    bucketEntries (_, p, B k x overflow) = (k, p, x) : OrdPSQ.toList overflow
283
-- | /O(n)/ The keys present in the queue, in the order produced by 'toList'.
{-# INLINABLE keys #-}
keys :: (Hashable k, Ord k, Ord p) => HashPSQ k p v -> [k]
keys = List.map (\(k, _, _) -> k) . toList
288
289
290--------------------------------------------------------------------------------
291-- Views
292--------------------------------------------------------------------------------
293
-- | /O(min(n,W))/ Insert a new key, priority and value into the queue. If the
-- key is already present in the queue, then the evicted priority and value can
-- be found in the first element of the returned tuple.
{-# INLINABLE insertView #-}
insertView
    :: (Hashable k, Ord k, Ord p)
    => k -> p -> v -> HashPSQ k p v -> (Maybe (p, v), HashPSQ k p v)
insertView k p x t =
    -- TODO (jaspervdj): Can be optimized easily
    --
    -- 'deleteView' is only used to find the previous binding; the queue it
    -- returns is ignored and the insertion is done on the original queue.
    case deleteView k t of
        Just (oldPrio, oldVal, _) -> (Just (oldPrio, oldVal), inserted)
        Nothing                   -> (Nothing, inserted)
  where
    inserted = insert k p x t
306
-- | /O(min(n,W))/ Delete a key and its priority and value from the queue. If
-- the key was present, the associated priority and value are returned in
-- addition to the updated queue; otherwise the result is 'Nothing'.
{-# INLINABLE deleteView #-}
deleteView
    :: forall k p v. (Hashable k, Ord k, Ord p)
    => k -> HashPSQ k p v -> Maybe (p, v, HashPSQ k p v)
deleteView k (HashPSQ ipsq) = case IntPSQ.alter f (hash k) ipsq of
    (Just (p, x), ipsq') -> Just (p, x, HashPSQ ipsq')
    -- Nothing was removed, so the altered queue is not used.
    (Nothing,     _    ) -> Nothing
  where
    f :: Maybe (p, Bucket k p v) -> (Maybe (p, v), Maybe (p, Bucket k p v))
    f Nothing = (Nothing, Nothing)
    f (Just (headPrio, B headKey headVal overflow)) =
        if k == headKey
            -- Removing the bucket head: promote the overflow minimum, or
            -- drop the bucket when there is none.
            then case OrdPSQ.minView overflow of
                Nothing ->
                    (Just (headPrio, headVal), Nothing)
                Just (k', p', x', overflow') ->
                    (Just (headPrio, headVal), Just (p', B k' x' overflow'))
            -- The key can only be in the overflow queue; the head stays.
            else case OrdPSQ.deleteView k overflow of
                -- Key absent. The second component is irrelevant here: the
                -- caller above discards the queue when the first is
                -- 'Nothing'.
                Nothing ->
                    (Nothing, Nothing)
                Just (p', x', overflow') ->
                    (Just (p', x'), Just (headPrio, B headKey headVal overflow'))
327
-- | /O(min(n,W))/ Retrieve the binding with the least priority, and the
-- rest of the queue stripped of that binding. Returns 'Nothing' when there is
-- no such binding.
{-# INLINABLE minView #-}
minView
    :: (Hashable k, Ord k, Ord p)
    => HashPSQ k p v -> Maybe (k, p, v, HashPSQ k p v)
minView (HashPSQ ipsq) = case IntPSQ.alterMin popHead ipsq of
    (Just (k, p, x), ipsq') -> Just (k, p, x, HashPSQ ipsq')
    (Nothing       , _    ) -> Nothing
  where
    popHead Nothing = (Nothing, Nothing)
    popHead (Just (h, p, B k x overflow)) = case OrdPSQ.minView overflow of
        -- Another entry shares the bucket: it becomes the new head and the
        -- bucket stays filed under the same hash.
        Just (k', p', x', overflow') ->
            (popped, Just (h, p', B k' x' overflow'))
        -- The head was the only entry, so the bucket disappears.
        Nothing ->
            (popped, Nothing)
      where
        popped = Just (k, p, x)
345
-- | Return the elements whose priorities are at most @pt@, together with the
-- rest of the queue stripped of these elements. The returned list of elements
-- can be in any order: no guarantees there.
{-# INLINABLE atMostView #-}
atMostView
    :: (Hashable k, Ord k, Ord p)
    => p -> HashPSQ k p v -> ([(k, p, v)], HashPSQ k p v)
atMostView pt (HashPSQ t0) =
    (returns, HashPSQ t2)
  where
    -- 'IntPSQ.atMostView' hands us every bucket that has /AT LEAST/ one
    -- element with a low enough priority (its head), plus the remaining
    -- queue. Buckets will usually only contain a single element.
    (buckets, t1) = IntPSQ.atMostView pt t0

    -- Walk the extracted buckets, collecting the elements to return and the
    -- left-over buckets that have to be re-inserted.
    (returns, reinserts) = splitBuckets [] [] buckets

    -- Two accumulators: elements to return and buckets to re-insert.
    splitBuckets out keep [] = (out, keep)
    splitBuckets out keep ((_, p, B k v opsq) : more) =
        splitBuckets out' keep' more
      where
        -- 'hits' should be very small, ideally a null list.
        (hits, remaining) = OrdPSQ.atMostView pt opsq
        out'              = (k, p, v) : hits ++ out
        -- Whatever is left of the bucket gets a new head and is kept.
        keep'             = maybe keep (: keep) (toBucket remaining)

    -- The re-insertion pass: every left-over bucket is filed under the hash
    -- of its new head key.
    t2 = List.foldl' putBack t1 reinserts

    putBack t (p, b@(B k _ _)) = IntPSQ.unsafeInsertNew (hash k) p b t
381
382
383--------------------------------------------------------------------------------
384-- Traversals
385--------------------------------------------------------------------------------
386
-- | /O(n)/ Apply a function to every value in the queue. The function also
-- receives the key and priority of the value; keys and priorities themselves
-- are left untouched.
{-# INLINABLE map #-}
map :: (k -> p -> v -> w) -> HashPSQ k p v -> HashPSQ k p w
map f (HashPSQ ipsq) = HashPSQ (IntPSQ.map onBucket ipsq)
  where
    -- The hash is ignored; the bucket head and the overflow entries are
    -- mapped with the same function.
    onBucket _ p (B k v overflow) = B k (f k p v) (OrdPSQ.map f overflow)
393
-- | /O(n)/ Maps a function over the values and priorities of the queue.
-- The function @f@ must be monotonic with respect to the priorities. I.e. if
-- @x < y@, then @fst (f k x v) < fst (f k y v)@.
-- /The precondition is not checked./ If @f@ is not monotonic, then the result
-- will be invalid.
{-# INLINABLE unsafeMapMonotonic #-}
unsafeMapMonotonic
    :: (k -> p -> v -> (q, w))
    -> HashPSQ k p v
    -> HashPSQ k q w
unsafeMapMonotonic f (HashPSQ ipsq) =
    HashPSQ (IntPSQ.unsafeMapMonotonic onBucket ipsq)
  where
    -- The bucket head keeps its place; only its priority and value change.
    -- The overflow entries are mapped with the same function.
    onBucket _ p (B k v overflow) =
        (p', B k v' (OrdPSQ.unsafeMapMonotonic f overflow))
      where
        (p', v') = f k p v
410
-- | /O(n)/ Strict fold over every key, priority and value in the queue. The
-- order in which the fold is performed is not specified.
{-# INLINABLE fold' #-}
fold' :: (k -> p -> v -> a -> a) -> a -> HashPSQ k p v -> a
fold' f acc0 (HashPSQ ipsq) = IntPSQ.fold' foldBucket acc0 ipsq
  where
    -- Within a bucket, the head is folded in first (and forced) before the
    -- overflow entries.
    foldBucket _ p (B k v overflow) acc =
        let !afterHead = f k p v acc
        in OrdPSQ.fold' f afterHead overflow
421
422
423--------------------------------------------------------------------------------
424-- Unsafe operations
425--------------------------------------------------------------------------------
426
-- | Look up a key and, if it is present, set its priority to @p@. The result
-- holds the priority and value the key had /before/ the update (or 'Nothing'
-- if the key is absent) together with the updated queue.
--
-- NOTE(review): presumably the caller must guarantee that @p@ is not lower
-- than the key's current priority. When the key lives in a bucket's overflow
-- queue, the bucket head is left in place without being compared to @p@.
-- Confirm against 'IntPSQ.unsafeLookupIncreasePriority'.
{-# INLINABLE unsafeLookupIncreasePriority #-}
unsafeLookupIncreasePriority
    :: (Hashable k, Ord k, Ord p)
    => k -> p -> HashPSQ k p v -> (Maybe (p, v), HashPSQ k p v)
unsafeLookupIncreasePriority k p (HashPSQ ipsq) =
    (mbPV, HashPSQ ipsq')
  where
    (!mbPV, !ipsq') =
        IntPSQ.unsafeLookupIncreasePriority bump (hash k) ipsq

    bump bp b@(B bk bx opsq)
        -- The key is the bucket head: rebuild the bucket, since another
        -- entry may now have to become the head.
        | k == bk   =
            let (bp', b') = mkBucket k p bx opsq
            in (Just (bp, bx), bp', b')
        -- TODO (jaspervdj): Still a lookup-insert here: 3 traversals?
        | otherwise = case OrdPSQ.lookup k opsq of
            -- Key absent: the bucket is returned untouched.
            Nothing      -> (Nothing, bp, b)
            -- Key found in the overflow queue: re-insert it there with the
            -- new priority; the head and its priority are kept.
            Just (p', x) ->
                (Just (p', x), bp, B bk bx (OrdPSQ.insert k p x opsq))
447
-- | Insert a key with the given priority and value. A fresh single-element
-- bucket is offered to 'IntPSQ.unsafeInsertWithIncreasePriority'; the merge
-- function below decides what happens to a bucket that already exists.
--
-- NOTE(review): presumably the caller must guarantee that an existing binding
-- for the key does not have a priority higher than @p@. The details are
-- defined by 'IntPSQ.unsafeInsertWithIncreasePriority'; confirm there.
{-# INLINABLE unsafeInsertIncreasePriority #-}
unsafeInsertIncreasePriority
    :: (Hashable k, Ord k, Ord p)
    => k -> p -> v -> HashPSQ k p v -> HashPSQ k p v
unsafeInsertIncreasePriority k p x (HashPSQ ipsq) = HashPSQ $
    IntPSQ.unsafeInsertWithIncreasePriority
        mergeBucket (hash k) p (B k x OrdPSQ.empty) ipsq
  where
    -- The first two arguments are ignored; only the existing bucket and its
    -- priority are used.
    mergeBucket _ _ bp (B bk bx opsq)
        -- The key is the existing head: rebuild the bucket around the new
        -- binding so that the proper minimum ends up as head.
        | k == bk   = mkBucket k p x opsq
        -- Otherwise the head stays and the binding goes into the overflow
        -- queue.
        | otherwise = (bp, B bk bx (OrdPSQ.insert k p x opsq))
462
-- | Like 'unsafeInsertIncreasePriority', but additionally returns the
-- priority and value that were bound to the key according to the bucket
-- handed back by 'IntPSQ.unsafeInsertWithIncreasePriorityView' (or 'Nothing'
-- if there was no such binding).
--
-- NOTE(review): the bucket returned by the 'IntPSQ' function is assumed to be
-- the one that was stored before the insertion; confirm there.
{-# INLINABLE unsafeInsertIncreasePriorityView #-}
unsafeInsertIncreasePriorityView
    :: (Hashable k, Ord k, Ord p)
    => k -> p -> v -> HashPSQ k p v -> (Maybe (p, v), HashPSQ k p v)
unsafeInsertIncreasePriorityView k p x (HashPSQ ipsq) =
    (mbEvicted, HashPSQ ipsq')
  where
    (mbBucket, ipsq') = IntPSQ.unsafeInsertWithIncreasePriorityView
        mergeBucket (hash k) p (B k x OrdPSQ.empty) ipsq

    -- Same merge strategy as 'unsafeInsertIncreasePriority': rebuild the
    -- bucket when the key is its head, otherwise put the binding into the
    -- overflow queue.
    mergeBucket _ _ bp (B bk bx opsq)
        | k == bk   = mkBucket k p x opsq
        | otherwise = (bp, B bk bx (OrdPSQ.insert k p x opsq))

    -- Find the key's binding in the returned bucket: either the head itself
    -- or an entry of the overflow queue.
    mbEvicted = do
        (bp, B bk bv opsq) <- mbBucket
        if k == bk
            then Just (bp, bv)
            else OrdPSQ.lookup k opsq
485
486
487--------------------------------------------------------------------------------
488-- Validity check
489--------------------------------------------------------------------------------
490
-- | /O(n^2)/ Internal function to check if the 'HashPSQ' is valid, i.e. if all
-- invariants hold. This should always be the case.
--
-- The queue is valid when no key occurs twice and every bucket passes
-- 'validBucket'.
valid :: (Hashable k, Ord k, Ord p) => HashPSQ k p v -> Bool
valid t@(HashPSQ ipsq) =
    not (hasDuplicateKeys t) && all bucketOk (IntPSQ.toList ipsq)
  where
    bucketOk (h, p, bucket) = validBucket h p bucket
497
-- | 'True' if some key occurs more than once in 'keys'. The keys are sorted
-- so that equal keys end up adjacent and are grouped together.
hasDuplicateKeys :: (Hashable k, Ord k, Ord p) => HashPSQ k p v -> Bool
hasDuplicateKeys t =
    any (> 1) [length grp | grp <- List.group (List.sort (keys t))]
500
-- | Check the invariants of a single bucket. The arguments are the hash the
-- bucket is filed under, the priority of the bucket head and the bucket
-- itself. A bucket is valid when:
--
-- * its overflow queue is a valid 'OrdPSQ.OrdPSQ';
--
-- * the head key hashes to the hash the bucket is filed under;
--
-- * every overflow entry comes strictly after the head when comparing
--   (priority, key) pairs, and hashes to the same hash as well.
validBucket :: (Hashable k, Ord k, Ord p) => Int -> p -> Bucket k p v -> Bool
validBucket h p (B k _ opsq) =
    OrdPSQ.valid opsq &&
    -- The head key must be filed under its own hash. Without this check a
    -- bucket with an empty overflow queue would never have its hash verified.
    hash k == h &&
    -- Check that the first element of the bucket has lower priority than all
    -- the other elements, and that those elements belong in this bucket.
    and [(p, k) < (p', k') && hash k' == h | (k', p', _) <- OrdPSQ.toList opsq]
507