1{-# LANGUAGE DeriveFunctor #-}
2{-# LANGUAGE DeriveGeneric #-}
3{-# LANGUAGE DeriveAnyClass #-}
4module Data.IMap
5    ( IMap
6    , Run(..)
7    , empty
8    , Data.IMap.null
9    , singleton
10    , insert
11    , delete
12    , restrict
13    , lookup
14    , splitLE
15    , intersectionWith
16    , mapMaybe
17    , addToKeys
18    , unsafeUnion
19    , fromList
20    , unsafeRuns
21    , unsafeToAscList
22    ) where
23
24import Data.List (foldl')
25import Data.Monoid
26import Data.IntMap.Strict (IntMap)
27import GHC.Generics
28import Control.DeepSeq
29import Prelude hiding (lookup)
30import qualified Data.IntMap.Strict as IM
31
-- | An 'IMap' means the same thing as an 'IntMap', but is the more efficient
-- choice when long stretches of contiguous keys all map to the same value.
-- Each stretch is stored as one or more 'Run's, keyed by the first key the
-- run covers.
newtype IMap a = IMap
    { _runs :: IntMap (Run a)
    }
    deriving (Show, Functor, Read, Generic, NFData)
36
{-# INLINE unsafeRuns #-}
-- | Expose the underlying map from starting keys to runs.
--
-- Unsafe because two 'IMap's that compare equal may chop their runs into
-- different chunks; callers must promise not to attach any meaning to where
-- one run ends and the next begins.
unsafeRuns :: IMap a -> IntMap (Run a)
unsafeRuns (IMap rs) = rs
43
-- | Equality looks through run boundaries: both maps are walked in ascending
-- key order and, whenever the two leading runs differ in length, the longer
-- one is trimmed by the length of the shorter one before continuing.
instance Eq a => Eq (IMap a) where
    IMap lhs == IMap rhs = sameRuns (IM.toAscList lhs) (IM.toAscList rhs) where
        sameRuns ((i, Run p x):xs) ((j, Run q y):ys)
            | not (i == j && x == y) = False
            -- left run is shorter: keep the unmatched tail of the right run
            | p < q     = sameRuns xs ((j+p, Run (q-p) y):ys)
            -- right run is shorter: keep the unmatched tail of the left run
            | p > q     = sameRuns ((i+q, Run (p-q) x):xs) ys
            | otherwise = sameRuns xs ys
        sameRuns [] [] = True
        sameRuns _ _ = False
53
-- | Ordering looks through run boundaries in the same way as equality: the
-- leading runs are compared by starting key, then by value, and the longer
-- run's unmatched tail is carried into the next step. A map that runs out of
-- entries first is the smaller one.
instance Ord a => Ord (IMap a) where
    compare (IMap lhs) (IMap rhs) = walk (IM.toAscList lhs) (IM.toAscList rhs) where
        walk [] [] = EQ
        walk [] _  = LT
        walk _  [] = GT
        walk ((i, Run p x):xs) ((j, Run q y):ys) = compare i j <> compare x y <> rest
            where
            rest
                | p < q     = walk xs ((j+p, Run (q-p) y):ys)
                | p > q     = walk ((i+q, Run (p-q) x):xs) ys
                | otherwise = walk xs ys
64
-- | Zippy: '(<*>)' pairs up the values found at equal keys and drops every
-- value whose key appears in only one of the two arguments.
--
-- 'pure' maps every 'Int' key to the given value. A single run cannot be
-- longer than 'maxBound', so the key space is covered by three runs: two of
-- length 'maxBound' starting at 'minBound' and at @-1@, and a final run of
-- length 2 starting at @maxBound-1@.
instance Applicative IMap where
    pure a = IMap (IM.fromDistinctAscList chunks)
        where
        chunks =
            [ (minBound  , Run maxBound a)
            , (-1        , Run maxBound a)
            , (maxBound-1, Run 2 a)
            ]
    fs <*> xs = intersectionWith (\f x -> f x) fs xs
74
-- | A run of repeated values: @Run n a@ stands for @n@ copies of @a@.
data Run a = Run
    { len :: !Int -- ^ how many copies the run holds
    , val :: !a   -- ^ the value being repeated
    }
    deriving (Eq, Ord, Read, Show, Functor, Generic, NFData)
80
-- | A 'Run' folds as a container of exactly one element (its 'val'); the
-- length is not taken into account.
instance Foldable Run where
    foldMap f = f . val

-- | Traversing a 'Run' acts on its single value and keeps the length as is.
instance Traversable Run where
    sequenceA (Run n v) = fmap (Run n) v
83
-- | The map with no keys.
empty :: IMap a
empty = IMap { _runs = IM.empty }
86
-- | 'True' exactly when the map stores no runs.
null :: IMap a -> Bool
null (IMap rs) = IM.null rs
89
-- | @singleton k r@ is the map holding just the run @r@ starting at key @k@.
-- A run of length less than 1 covers no keys, so the result is then 'empty'.
singleton :: Int -> Run a -> IMap a
singleton k r =
    if len r < 1
        then empty
        else IMap (IM.singleton k r)
94
-- | @insert k r m@ stores the run @r@ starting at key @k@, replacing whatever
-- @m@ previously held in the covered range. Runs of length less than 1 cover
-- no keys and leave @m@ untouched.
insert :: Int -> Run a -> IMap a -> IMap a
insert k r m
    | len r >= 1 = IMap (IM.insert k r (_runs cleared))
    | otherwise  = m
    where
    -- make room first so the new run cannot overlap an existing one
    cleared = delete k r m
99
{-# INLINE delete #-}
-- | @delete k r m@ removes from @m@ every key in the range that starts at @k@
-- and is @len r@ keys long; the value carried by @r@ is ignored. A run with
-- @len r < 1@ describes an empty range, so @m@ is returned unchanged.
delete :: Int -> Run ignored -> IMap a -> IMap a
delete k r m
    | len r < 1 = m
    | otherwise = m { _runs = IM.union (_runs lt) (_runs gt) }
    where
    -- Split off everything strictly below @k@. No key lies below 'minBound',
    -- and computing @k-1@ there would wrap around to 'maxBound' and put the
    -- whole map on the "keep" side, so that case is handled explicitly.
    (lt, ge)
        | k == minBound = (empty, m)
        | otherwise     = splitLE (k-1) m
    -- Of what remains, keep only the part past the last key of the range.
    (_ , gt) = splitLE (k+len r-1) ge
108
-- | Given a range of keys (as specified by a starting key and a length for
-- consistency with other functions in this module), restrict the map to keys
-- in that range. @restrict k r m@ is equivalent to @intersectionWith const m
-- (insert k r empty)@ but potentially more efficient.
--
-- A run with @len r < 1@ describes an empty range, so the result is 'empty'.
restrict :: Int -> Run ignored -> IMap a -> IMap a
restrict k r m
    | len r < 1     = empty
    -- No key lies below 'minBound', so there is nothing to cut off at the low
    -- end; computing @k-1@ here would wrap around to 'maxBound' and discard
    -- the entire map.
    | k == minBound = upToEnd
    | otherwise     = snd (splitLE (k-1) upToEnd)
    where
    -- everything at or below the last key of the range
    upToEnd = fst (splitLE (k+len r-1) m)
119
-- | @lookup k m@ returns the value associated with key @k@, if there is one.
--
-- The only candidate is the run with the greatest starting key @<= k@; the
-- key is present exactly when it does not lie past that run's last key.
lookup :: Int -> IMap a -> Maybe a
lookup k m = case IM.lookupLE k (_runs m) of
    -- Compare against the run's last key, @k'+n-1@, rather than the
    -- one-past-the-end key @k'+n@: for a run that ends at 'maxBound' (such as
    -- the final run built by 'pure') @k'+n@ wraps around to 'minBound', which
    -- would make every key of that run look absent.
    Just (k', Run n a) | k <= k'+n-1 -> Just a
    _ -> Nothing
124
-- | @splitLE n m@ produces a tuple @(le, gt)@ where @le@ holds all the
-- associations of @m@ whose keys are @<= n@ and @gt@ holds all the
-- associations of @m@ whose keys are @> n@. A run that straddles @n@ is cut
-- into two runs, one for each side.
splitLE :: Int -> IMap a -> (IMap a, IMap a)
splitLE k m = case IM.lookupLE k runs of
    -- no run starts at or before k: everything belongs to the upper half
    Nothing -> (empty, m)
    Just (start, r@(Run n _))
        -- the run extends past k: cut it into a lower and an upper piece
        | start + n - 1 > k ->
            ( IMap (IM.insert start (r { len = 1 + k - start }) below)
            , IMap (IM.insert (k+1) (r { len = n - 1 - k + start }) above)
            )
        -- the run starts exactly at k and ends there: 'IM.split' dropped
        -- it from both halves, so put it back into the lower one
        | start == k -> (IMap (IM.insert k r lt), IMap gt)
        -- the run ends at or before k: a plain split is enough
        | otherwise -> (IMap lt, IMap gt)
        where
        (below, above) = IM.split start runs
    where
    runs = _runs m
    (lt, gt) = IM.split k runs
141
-- | Shift every key by the given amount. This is 'IM.mapKeysMonotonic'
-- specialised to partially-applied addition.
addToKeys :: Int -> IMap a -> IMap a
addToKeys n (IMap rs) = IMap (IM.mapKeysMonotonic (n +) rs)
146
-- TODO: This is pretty inefficient. IntMap offers some splitting functions
-- that should make it possible to be more efficient here (though the
-- implementation would be significantly messier).

-- | Combine two maps key by key: the result has exactly the keys present in
-- both arguments, each mapped to the given function applied to the two
-- values. Both run lists are walked in ascending key order and one run is
-- emitted for every overlap between a run of the first map and a run of the
-- second.
intersectionWith :: (a -> b -> c) -> IMap a -> IMap b -> IMap c
intersectionWith f (IMap runsa) (IMap runsb)
    = IMap (IM.fromDistinctAscList (overlaps (IM.toAscList runsa) (IM.toAscList runsb)))
    where
    overlaps xs@((xLo, xr):xs') ys@((yLo, yr):ys')
        -- the left run ends before the right one starts: skip it
        | xHi < yLo = overlaps xs' ys
        -- the right run ends before the left one starts: skip it
        | yHi < xLo = overlaps xs ys'
        | otherwise = (lo, Run (hi - lo + 1) combined) : rest
        where
        -- last key covered by each run
        xHi = xLo + len xr - 1
        yHi = yLo + len yr - 1
        -- bounds of the overlap
        lo = max xLo yLo
        hi = min xHi yHi
        combined = f (val xr) (val yr)
        -- drop whichever run ends first; the other may still overlap with
        -- the next run on the opposite side
        rest = case compare xHi yHi of
            LT -> overlaps xs' ys
            EQ -> overlaps xs' ys'
            GT -> overlaps xs ys'
    overlaps _ _ = []
168
-- | Apply the function to the value of every run; runs for which it yields
-- 'Nothing' are dropped, the others keep their starting key and length.
mapMaybe :: (a -> Maybe b) -> IMap a -> IMap b
mapMaybe f = IMap . IM.mapMaybe (traverse f) . _runs
171
-- | Build a map by 'insert'ing the given runs one after another, starting
-- from 'empty'; where runs overlap, later list entries win.
fromList :: [(Int, Run a)] -> IMap a
fromList kvs = foldl' step empty kvs
    where
    step acc (k, r) = insert k r acc
174
-- | List the stored runs in ascending order of starting key.
--
-- Unsafe because two 'IMap's that compare equal may chop their runs into
-- different chunks; callers must promise not to attach any meaning to where
-- one run ends and the next begins.
unsafeToAscList :: IMap a -> [(Int, Run a)]
unsafeToAscList (IMap rs) = IM.toAscList rs
180
-- | Merge two maps by taking the union of their underlying run maps.
--
-- Unsafe because no overlap between the arguments is detected or repaired:
-- in the call @unsafeUnion a b@ the caller must guarantee that whenever
-- @lookup k a = Just v@ then @lookup k b = Nothing@, and vice versa.
unsafeUnion :: IMap a -> IMap a -> IMap a
unsafeUnion (IMap ra) (IMap rb) = IMap (IM.union ra rb)
186