1{-# LANGUAGE BangPatterns, CPP, MagicHash, Rank2Types, UnboxedTuples, ScopedTypeVariables #-}
2{-# OPTIONS_GHC -fno-full-laziness -funbox-strict-fields #-}
3
4-- | Zero based arrays.
5--
-- Note that no bounds checking is performed unless the module is built with ASSERTS defined.
7module Data.HashMap.Array
8    ( Array
9    , MArray
10
11      -- * Creation
12    , new
13    , new_
14    , singleton
15    , singletonM
16    , pair
17
18      -- * Basic interface
19    , length
20    , lengthM
21    , read
22    , write
23    , index
24    , indexM
25    , index#
26    , update
27    , updateWith'
28    , unsafeUpdateM
29    , insert
30    , insertM
31    , delete
32    , sameArray1
33    , trim
34
35    , unsafeFreeze
36    , unsafeThaw
37    , unsafeSameArray
38    , run
39    , run2
40    , copy
41    , copyM
42
43      -- * Folds
44    , foldl'
45    , foldr
46
47    , thaw
48    , map
49    , map'
50    , traverse
51    , traverse'
52    , toList
53    , fromList
54    ) where
55
56#if !MIN_VERSION_base(4,8,0)
57import Control.Applicative (Applicative (..), (<$>))
58#endif
59import Control.Applicative (liftA2)
60import Control.DeepSeq
61import GHC.Exts(Int(..), Int#, reallyUnsafePtrEquality#, tagToEnum#, unsafeCoerce#, State#)
62import GHC.ST (ST(..))
63import Control.Monad.ST (stToIO)
64
65#if __GLASGOW_HASKELL__ >= 709
66import Prelude hiding (filter, foldr, length, map, read, traverse)
67#else
68import Prelude hiding (filter, foldr, length, map, read)
69#endif
70
71#if __GLASGOW_HASKELL__ >= 710
72import GHC.Exts (SmallArray#, newSmallArray#, readSmallArray#, writeSmallArray#,
73                 indexSmallArray#, unsafeFreezeSmallArray#, unsafeThawSmallArray#,
74                 SmallMutableArray#, sizeofSmallArray#, copySmallArray#, thawSmallArray#,
75                 sizeofSmallMutableArray#, copySmallMutableArray#, cloneSmallMutableArray#)
76
77#else
78import GHC.Exts (Array#, newArray#, readArray#, writeArray#,
79                 indexArray#, unsafeFreezeArray#, unsafeThawArray#,
80                 MutableArray#, sizeofArray#, copyArray#, thawArray#,
81                 sizeofMutableArray#, copyMutableArray#, cloneMutableArray#)
82#endif
83
84#if defined(ASSERTS)
85import qualified Prelude
86#endif
87
88import Data.HashMap.Unsafe (runST)
89import Control.Monad ((>=>))
90
91
-- Primop compatibility shims.  The rest of this module is written against
-- the names of the boxed-array primops (newArray#, readArray#, ...).  On
-- GHC >= 7.10 the definitions below redirect those names to the
-- corresponding SmallArray# primops.  On older compilers the boxed Array#
-- primops themselves are imported instead (see the matching import above).
#if __GLASGOW_HASKELL__ >= 710
-- Type-level aliases so that signatures elsewhere can keep saying Array#.
type Array# a = SmallArray# a
type MutableArray# a = SmallMutableArray# a

newArray# :: Int# -> a -> State# d -> (# State# d, SmallMutableArray# d a #)
newArray# = newSmallArray#

unsafeFreezeArray# :: SmallMutableArray# d a
                   -> State# d -> (# State# d, SmallArray# a #)
unsafeFreezeArray# = unsafeFreezeSmallArray#

readArray# :: SmallMutableArray# d a
           -> Int# -> State# d -> (# State# d, a #)
readArray# = readSmallArray#

writeArray# :: SmallMutableArray# d a
            -> Int# -> a -> State# d -> State# d
writeArray# = writeSmallArray#

indexArray# :: SmallArray# a -> Int# -> (# a #)
indexArray# = indexSmallArray#

unsafeThawArray# :: SmallArray# a
                 -> State# d -> (# State# d, SmallMutableArray# d a #)
unsafeThawArray# = unsafeThawSmallArray#

sizeofArray# :: SmallArray# a -> Int#
sizeofArray# = sizeofSmallArray#

-- Arguments: source, source offset, destination, destination offset, count.
copyArray# :: SmallArray# a
           -> Int#
           -> SmallMutableArray# d a
           -> Int#
           -> Int#
           -> State# d
           -> State# d
copyArray# = copySmallArray#

-- Arguments: source, offset, length of the slice to clone.
cloneMutableArray# :: SmallMutableArray# s a
                   -> Int#
                   -> Int#
                   -> State# s
                   -> (# State# s, SmallMutableArray# s a #)
cloneMutableArray# = cloneSmallMutableArray#

-- Arguments: source, offset, length of the slice to copy into a new
-- mutable array.
thawArray# :: SmallArray# a
           -> Int#
           -> Int#
           -> State# d
           -> (# State# d, SmallMutableArray# d a #)
thawArray# = thawSmallArray#

sizeofMutableArray# :: SmallMutableArray# s a -> Int#
sizeofMutableArray# = sizeofSmallMutableArray#

-- Arguments: source, source offset, destination, destination offset, count.
copyMutableArray# :: SmallMutableArray# d a
                  -> Int#
                  -> SmallMutableArray# d a
                  -> Int#
                  -> Int#
                  -> State# d
                  -> State# d
copyMutableArray# = copySmallMutableArray#
#endif
156
157------------------------------------------------------------------------
158
#if defined(ASSERTS)
-- This fugly hack is brought by GHC's apparent reluctance to deal
-- with MagicHash and UnboxedTuples when inferring types. Eek!
--
-- Each check below expands to the prefix @if failed then error ... else@,
-- so a check is written directly in front of the expression it guards,
-- and several checks can be chained one per line.  Failures report the
-- offending values via show.
# define CHECK_BOUNDS(_func_,_len_,_k_) \
if (_k_) < 0 || (_k_) >= (_len_) then error ("Data.HashMap.Array." ++ (_func_) ++ ": bounds error, offset " ++ show (_k_) ++ ", length " ++ show (_len_)) else
# define CHECK_OP(_func_,_op_,_lhs_,_rhs_) \
if not ((_lhs_) _op_ (_rhs_)) then error ("Data.HashMap.Array." ++ (_func_) ++ ": Check failed: _lhs_ _op_ _rhs_ (" ++ show (_lhs_) ++ " vs. " ++ show (_rhs_) ++ ")") else
# define CHECK_GT(_func_,_lhs_,_rhs_) CHECK_OP(_func_,>,_lhs_,_rhs_)
# define CHECK_LE(_func_,_lhs_,_rhs_) CHECK_OP(_func_,<=,_lhs_,_rhs_)
# define CHECK_EQ(_func_,_lhs_,_rhs_) CHECK_OP(_func_,==,_lhs_,_rhs_)
#else
-- Without assertions every check expands to nothing, so no checking
-- code is generated at all.
# define CHECK_BOUNDS(_func_,_len_,_k_)
# define CHECK_OP(_func_,_op_,_lhs_,_rhs_)
# define CHECK_GT(_func_,_lhs_,_rhs_)
# define CHECK_LE(_func_,_lhs_,_rhs_)
# define CHECK_EQ(_func_,_lhs_,_rhs_)
#endif
176
-- | An immutable, zero-based array: a wrapper around the primitive array
-- type (SmallArray# on GHC >= 7.10, Array# otherwise).  The field is
-- strict.
data Array a = Array {
      unArray :: !(Array# a)
    }
180
-- | An array is displayed as the list of its elements in index order.
instance Show a => Show (Array a) where
    show ary = show (toList ary)
183
-- Determines whether two arrays have the same memory address.
-- This is more reliable than testing pointer equality on the
-- Array wrappers, but it's still slightly bogus.
--
-- The primitive arrays may have different element types, so the
-- equality primop is coerced with unsafeCoerce# to accept them; the
-- resulting Int# flag is turned into a Bool with tagToEnum#.
-- NOTE(review): presumably a True answer is reliable while False may be a
-- false negative; confirm against the reallyUnsafePtrEquality# docs.
unsafeSameArray :: Array a -> Array b -> Bool
unsafeSameArray (Array xs) (Array ys) =
  tagToEnum# (unsafeCoerce# reallyUnsafePtrEquality# xs ys)
190
-- | Element-wise comparison of two arrays with a caller-supplied
-- predicate.  Arrays of different lengths are never equal; otherwise
-- the elements are compared from index 0 upwards, stopping at the first
-- pair for which the predicate returns 'False'.
sameArray1 :: (a -> b -> Bool) -> Array a -> Array b -> Bool
sameArray1 eq !xs0 !ys0 = sz == length ys0 && loop 0
  where
    !sz = length xs0

    loop !i
      | i == sz = True
      | (# x #) <- index# xs0 i
      , (# y #) <- index# ys0 i
      = eq x y && loop (i + 1)
204
-- | /O(1)/ Number of elements, read from the primitive array.
length :: Array a -> Int
length (Array ary#) = I# (sizeofArray# ary#)
{-# INLINE length #-}
208
-- | Smart constructor: wrap a primitive array.  The length argument is
-- currently ignored; 'length' reads the size from the primitive array.
array :: Array# a -> Int -> Array a
array ary# _ = Array ary#
{-# INLINE array #-}
213
-- | A mutable, zero-based array in state thread @s@; the mutable
-- counterpart of 'Array'.  The field is strict.
data MArray s a = MArray {
      unMArray :: !(MutableArray# s a)
    }
217
-- | /O(1)/ Number of slots in a mutable array.
lengthM :: MArray s a -> Int
lengthM (MArray mary#) = I# (sizeofMutableArray# mary#)
{-# INLINE lengthM #-}
221
-- | Smart constructor: wrap a primitive mutable array.  The length
-- argument is currently ignored; 'lengthM' reads the size from the
-- primitive array.
marray :: MutableArray# s a -> Int -> MArray s a
marray mary# _ = MArray mary#
{-# INLINE marray #-}
226
227------------------------------------------------------------------------
228
-- | Fully evaluates every element (see 'rnfArray').
instance NFData a => NFData (Array a) where
    rnf ary = rnfArray ary
231
-- | Deeply evaluate every element, from index 0 upwards.
--
-- index# is used so that the element itself is not forced by the
-- lookup; only rnf forces it.  This matters in case GHC cannot see that
-- the relevant rnf is strict, or in case it actually is not.
rnfArray :: NFData a => Array a -> ()
rnfArray ary0 = loop 0
  where
    !sz = length ary0

    loop !i
      | i >= sz = ()
      | (# x #) <- index# ary0 i = rnf x `seq` loop (i + 1)
{-# INLINE rnfArray #-}
243
-- | Allocate a mutable array of @n@ slots in the given state thread,
-- every slot initialised to the supplied value.  When built with
-- assertions, @n@ must be strictly positive.
new :: Int -> a -> ST s (MArray s a)
new n@(I# n#) b =
    CHECK_GT("new",n,(0 :: Int))
    ST $ \s0 -> case newArray# n# b s0 of
        (# s1, mary# #) -> (# s1, marray mary# n #)
{-# INLINE new #-}
254
-- | Allocate a mutable array of the given size.  Every slot holds a
-- placeholder that raises an error when forced, so each slot must be
-- written before its contents are demanded.
new_ :: Int -> ST s (MArray s a)
new_ size = new size undefinedElem
257
-- | A one-element array.
singleton :: a -> Array a
singleton x = run (new 1 x)
{-# INLINE singleton #-}
261
-- | A one-element array, built inside an existing state thread.
singletonM :: a -> ST s (Array a)
singletonM x = unsafeFreeze =<< new 1 x
{-# INLINE singletonM #-}
265
-- | A two-element array holding @x@ at index 0 and @y@ at index 1.
pair :: a -> a -> Array a
pair x y = run (new 2 x >>= \ mary -> write mary 1 y >> return mary)
{-# INLINE pair #-}
272
-- | Read the slot at the given index of a mutable array.  The index is
-- only checked when built with assertions.
read :: MArray s a -> Int -> ST s a
read ary _i@(I# i#) = ST $ \ s0 ->
    CHECK_BOUNDS("read", lengthM ary, _i)
        case unMArray ary of
            mary# -> readArray# mary# i# s0
{-# INLINE read #-}
278
-- | Overwrite the slot at the given index of a mutable array.  The
-- index is only checked when built with assertions.
write :: MArray s a -> Int -> a -> ST s ()
write ary _i@(I# i#) b = ST $ \ s0 ->
    CHECK_BOUNDS("write", lengthM ary, _i)
        case writeArray# (unMArray ary) i# b s0 of
            s1 -> (# s1, () #)
{-# INLINE write #-}
285
-- | Element at the given index.  The index is only checked when built
-- with assertions.
index :: Array a -> Int -> a
index ary _i@(I# i#) =
    CHECK_BOUNDS("index", length ary, _i)
        case indexArray# (unArray ary) i# of
            (# elt #) -> elt
{-# INLINE index #-}
291
-- | Element at the given index, wrapped in an unboxed 1-tuple so that
-- the lookup can happen without forcing the element itself.
index# :: Array a -> Int -> (# a #)
index# ary _i@(I# i#) =
    CHECK_BOUNDS("index#", length ary, _i)
        case unArray ary of
            arr# -> indexArray# arr# i#
{-# INLINE index# #-}
297
-- | Element at the given index, returned in 'ST'.  The lookup happens
-- when the action runs, but the element itself is not forced.
indexM :: Array a -> Int -> ST s a
indexM ary _i@(I# i#) =
    CHECK_BOUNDS("indexM", length ary, _i)
        case indexArray# (unArray ary) i# of
            (# elt #) -> return elt
{-# INLINE indexM #-}
303
-- | Reinterpret a mutable array as immutable without copying.  The
-- mutable array must not be written afterwards.
unsafeFreeze :: MArray s a -> ST s (Array a)
unsafeFreeze mary = ST $ \ s0 ->
    case unsafeFreezeArray# (unMArray mary) s0 of
        (# s1, frozen# #) -> (# s1, array frozen# (lengthM mary) #)
{-# INLINE unsafeFreeze #-}
309
-- | Reinterpret an immutable array as mutable without copying.  Writes
-- through the result are visible through the original 'Array'.
unsafeThaw :: Array a -> ST s (MArray s a)
unsafeThaw ary = ST $ \ s0 ->
    case unsafeThawArray# (unArray ary) s0 of
        (# s1, thawed# #) -> (# s1, marray thawed# (length ary) #)
{-# INLINE unsafeThaw #-}
315
-- | Run an action that builds a mutable array, then freeze the result
-- in place.
run :: (forall s . ST s (MArray s e)) -> Array e
run act = runST (unsafeFreeze =<< act)
{-# INLINE run #-}
319
-- | Like 'run', but the action also returns an extra value, which is
-- passed through next to the frozen array.
run2 :: (forall s. ST s (MArray s e, a)) -> (Array e, a)
run2 k = runST (k >>= \ (marr, extra) ->
                  unsafeFreeze marr >>= \ arr -> return (arr, extra))
325
-- | Unsafely copy @n@ elements of an immutable array, starting at offset
-- @sidx@, into a mutable array starting at offset @didx@.  Array bounds
-- are only checked when built with assertions.
copy :: Array e -> Int -> MArray s e -> Int -> Int -> ST s ()
copy !src !_sidx@(I# sidx#) !dst !_didx@(I# didx#) _n@(I# n#) =
    CHECK_LE("copy", _sidx + _n, length src)
    CHECK_LE("copy", _didx + _n, lengthM dst)
        ST $ \ s0 ->
            case copyArray# (unArray src) sidx# (unMArray dst) didx# n# s0 of
                s1 -> (# s1, () #)
334
-- | Unsafely copy @n@ elements of a mutable array, starting at offset
-- @sidx@, into another mutable array starting at offset @didx@.
--
-- Array bounds are only checked when built with assertions.  The checked
-- preconditions are @sidx + n <= lengthM src@ and
-- @didx + n <= lengthM dst@, the same as for 'copy'.  Checking the end
-- offset with a comparison, rather than bounds-checking the last copied
-- index, means zero-length copies are accepted.
copyM :: MArray s e -> Int -> MArray s e -> Int -> Int -> ST s ()
copyM !src !_sidx@(I# sidx#) !dst !_didx@(I# didx#) _n@(I# n#) =
    CHECK_LE("copyM: src", _sidx + _n, lengthM src)
    CHECK_LE("copyM: dst", _didx + _n, lengthM dst)
    ST $ \ s# ->
    case copyMutableArray# (unMArray src) sidx# (unMArray dst) didx# n# s# of
        s2 -> (# s2, () #)
343
-- | Copy the slice @[off, off + len)@ of a mutable array into a fresh
-- mutable array of @len@ slots.
--
-- Bounds are only checked when built with assertions.  The checked
-- preconditions are @0 <= off@, @0 <= len@ and
-- @off + len <= lengthM mary@.  Comparisons are used instead of a
-- bounds check on @off - 1@, so that a slice starting at offset 0 (as
-- used by 'trim') and an empty slice are both accepted.
cloneM :: MArray s a -> Int -> Int -> ST s (MArray s a)
cloneM _mary@(MArray mary#) _off@(I# off#) _len@(I# len#) =
    CHECK_LE("cloneM_off", (0 :: Int), _off)
    CHECK_LE("cloneM_len", (0 :: Int), _len)
    CHECK_LE("cloneM_end", _off + _len, lengthM _mary)
    ST $ \ s ->
    case cloneMutableArray# mary# off# len# s of
      (# s', mary'# #) -> (# s', MArray mary'# #)
351
-- | Copy the first @n@ elements of @mary@ into a new, frozen array.
trim :: MArray s a -> Int -> ST s (Array a)
trim mary n = unsafeFreeze =<< cloneM mary 0 n
{-# INLINE trim #-}
356
-- | /O(n)/ Insert an element at the given position, producing an array
-- one element longer.  Elements at and after the position shift right.
insert :: Array e -> Int -> e -> Array e
insert ary idx b = runST $ insertM ary idx b
{-# INLINE insert #-}
362
-- | /O(n)/ Insert an element at the given position, producing an array
-- one element longer, inside the 'ST' monad.  Valid positions are
-- @0 .. length ary@ (checked only when built with assertions).
insertM :: Array e -> Int -> e -> ST s (Array e)
insertM ary idx b =
    CHECK_BOUNDS("insertM", count + 1, idx)
        do mary <- new_ (count + 1)
           -- the prefix keeps its indices
           copy ary 0 mary 0 idx
           -- the suffix moves one slot to the right
           copy ary idx mary (idx + 1) (count - idx)
           -- fill the gap left between the two copies
           write mary idx b
           unsafeFreeze mary
  where !count = length ary
{-# INLINE insertM #-}
375
-- | /O(n)/ A copy of the array with the element at the given position
-- replaced.
update :: Array e -> Int -> e -> Array e
update ary idx b = runST $ updateM ary idx b
{-# INLINE update #-}
380
-- | /O(n)/ A copy of the array with the element at the given position
-- replaced, built inside the 'ST' monad.
updateM :: Array e -> Int -> e -> ST s (Array e)
updateM ary idx b =
    CHECK_BOUNDS("updateM", count, idx)
        thaw ary 0 count >>= \ mary ->
            write mary idx b >> unsafeFreeze mary
  where !count = length ary
{-# INLINE updateM #-}
390
-- | /O(n)/ Update the element at the given position in this array by
-- applying a function to it.  The new element is evaluated to WHNF
-- before it is stored in the copy.
updateWith' :: Array e -> Int -> (e -> e) -> Array e
updateWith' ary idx f =
    case index# ary idx of
        (# x #) -> let !y = f x in update ary idx y
{-# INLINE updateWith' #-}
399
-- | /O(1)/ Overwrite the element at the given position in place, without
-- copying.  The change is visible through every reference to the
-- array, which is why this is unsafe.
unsafeUpdateM :: Array e -> Int -> e -> ST s ()
unsafeUpdateM ary idx b =
    CHECK_BOUNDS("unsafeUpdateM", length ary, idx)
        unsafeThaw ary >>= \ mary ->
            write mary idx b >> unsafeFreeze mary >> return ()
{-# INLINE unsafeUpdateM #-}
410
-- | Strict left fold over the elements in index order.  The accumulator
-- is forced to WHNF at every step.
foldl' :: (b -> a -> b) -> b -> Array a -> b
foldl' f = \ z0 ary0 -> loop ary0 (length ary0) 0 z0
  where
    loop ary n i !acc
      | i < n, (# x #) <- index# ary i = loop ary n (i + 1) (f acc x)
      | otherwise = acc
{-# INLINE foldl' #-}
420
-- | Lazy right fold over the elements in index order.
foldr :: (a -> b -> b) -> b -> Array a -> b
foldr f = \ z0 ary0 -> loop ary0 (length ary0) 0 z0
  where
    loop ary n i z
      | i < n, (# x #) <- index# ary i = f x (loop ary n (i + 1) z)
      | otherwise = z
{-# INLINE foldr #-}
430
-- | Placeholder stored in freshly allocated slots by 'new_'; forcing it
-- raises an error.
undefinedElem :: a
undefinedElem = error msg
  where
    msg = "Data.HashMap.Array: Undefined element"
{-# NOINLINE undefinedElem #-}
434
-- | Copy the slice @[o, o + n)@ of an immutable array into a new mutable
-- array of @n@ slots.  Only checked when built with assertions.
thaw :: Array e -> Int -> Int -> ST s (MArray s e)
thaw !ary !_o@(I# o#) !n@(I# n#) =
    CHECK_LE("thaw", _o + n, length ary)
        ST $ \ s0 ->
            case thawArray# (unArray ary) o# n# s0 of
                (# s1, thawed# #) -> (# s1, marray thawed# n #)
{-# INLINE thaw #-}
441
-- | /O(n)/ Remove the element at the given position, producing an array
-- one element shorter.
delete :: Array e -> Int -> Array e
delete ary idx = runST $ deleteM ary idx
{-# INLINE delete #-}
447
-- | /O(n)/ Remove the element at the given position, producing an array
-- one element shorter, inside the 'ST' monad.  Valid positions are
-- @0 .. length ary - 1@ (checked only when built with assertions).
--
-- The check prefixes a plain do-block, as in 'insertM' and 'updateM';
-- no extra outer do is needed.
deleteM :: Array e -> Int -> ST s (Array e)
deleteM ary idx =
    CHECK_BOUNDS("deleteM", count, idx)
        do mary <- new_ (count-1)
           -- elements before idx keep their positions
           copy ary 0 mary 0 idx
           -- elements after idx move one slot to the left
           copy ary (idx+1) mary idx (count-(idx+1))
           unsafeFreeze mary
  where !count = length ary
{-# INLINE deleteM #-}
459
-- | /O(n)/ Apply a function to every element.  Results are stored lazily:
-- each slot receives the unevaluated application of @f@.
map :: (a -> b) -> Array a -> Array b
map f = \ ary ->
    let !n = length ary
    in run (new_ n >>= \ mary -> fill ary mary 0 n)
  where
    fill ary mary i n
      | i < n = do
          x <- indexM ary i
          write mary i (f x)
          fill ary mary (i + 1) n
      | otherwise = return mary
{-# INLINE map #-}
474
-- | Strict version of 'map': each result is evaluated to WHNF before it
-- is written.
map' :: (a -> b) -> Array a -> Array b
map' f = \ ary ->
    let !n = length ary
    in run (new_ n >>= \ mary -> fill ary mary 0 n)
  where
    fill ary mary i n
      | i < n = do
          x <- indexM ary i
          write mary i $! f x
          fill ary mary (i + 1) n
      | otherwise = return mary
{-# INLINE map' #-}
490
-- | Build an array of @n@ elements from a list.  The list length must
-- equal @n@; this is only checked when built with assertions.
fromList :: Int -> [a] -> Array a
fromList n xs0 =
    CHECK_EQ("fromList", n, Prelude.length xs0)
        run (new_ n >>= fill xs0 0)
  where
    fill [] !_ !mary = return mary
    fill (x:rest) i mary = write mary i x >> fill rest (i + 1) mary
501
-- | The elements in index order, produced lazily.
toList :: Array a -> [a]
toList ary = foldr (:) [] ary
504
-- | A pending array-filling computation: given the primitive mutable
-- array to fill, it performs its writes and returns the frozen result.
-- The inner @forall s@ lets 'runSTA' run it under 'runST'.
newtype STA a = STA {_runSTA :: forall s. MutableArray# s a -> ST s (Array a)}
506
-- | Allocate an array of @n@ placeholder slots and run the pending
-- filling computation on it.
runSTA :: Int -> STA a -> Array a
runSTA !n (STA m) = runST (new_ n >>= \ mary -> m (unMArray mary))
509
-- | Traverse the array with an applicative action, from index 0 upwards.
--
-- A mutable array cannot be allocated in an arbitrary 'Applicative'.
-- Instead, the effects of @f@ are combined with 'liftA2', and alongside
-- them an 'STA' computation is built that writes each result into its
-- slot and finally freezes the array.  'runSTA' then runs that
-- computation on a fresh array of @len@ slots.
--
-- NOTE(review): the INLINE [1] pragma presumably keeps this definition
-- from being inlined too early, so the traverse/ST and traverse/IO
-- rewrite rules at the bottom of this module get a chance to fire.
traverse :: Applicative f => (a -> f b) -> Array a -> f (Array b)
traverse f = \ !ary ->
  let
    !len = length ary
    -- once every index has been visited, the remaining work is to freeze
    go !i
      | i == len = pure $ STA $ \mary -> unsafeFreeze (MArray mary)
      | (# x #) <- index# ary i
      -- run f x, then the rest of the traversal; the STA writes slot i
      -- before continuing with the writes for the later indices
      = liftA2 (\b (STA m) -> STA $ \mary ->
                  write (MArray mary) i b >> m mary)
               (f x) (go (i + 1))
  in runSTA len <$> go 0
{-# INLINE [1] traverse #-}
522
-- TODO: Would it be better to just use a lazy traversal
-- and then force the elements of the result? My guess is
-- yes.
--
-- | Like 'traverse', but each result is forced to WHNF (note the bang
-- on @b@) as part of building the array.  The effects of @f@ are still
-- sequenced with 'liftA2' from index 0 upwards.
traverse' :: Applicative f => (a -> f b) -> Array a -> f (Array b)
traverse' f = \ !ary ->
  let
    !len = length ary
    go !i
      | i == len = pure $ STA $ \mary -> unsafeFreeze (MArray mary)
      | (# x #) <- index# ary i
      = liftA2 (\ !b (STA m) -> STA $ \mary ->
                    write (MArray mary) i b >> m mary)
               (f x) (go (i + 1))
  in runSTA len <$> go 0
{-# INLINE [1] traverse' #-}
538
-- | Traversal specialised to 'ST'.  Here the result array can be
-- allocated up front and filled directly, from index 0 upwards.
traverseST :: (a -> ST s b) -> Array a -> ST s (Array b)
traverseST f = \ ary0 ->
  let
    !len = length ary0
    fill !mary k
      | k == len = return mary
      | otherwise = do
          y <- indexM ary0 k >>= f
          write mary k y
          fill mary (k + 1)
  in new_ len >>= \ mary0 -> fill mary0 0 >>= unsafeFreeze
{-# INLINE traverseST #-}
554
-- | Traversal specialised to 'IO'.  The array operations run through
-- 'stToIO' and the result array is filled directly, from index 0
-- upwards.
traverseIO :: (a -> IO b) -> Array a -> IO (Array b)
traverseIO f = \ ary0 ->
  let
    !len = length ary0
    fill !mary k
      | k == len = return mary
      | otherwise = do
          y <- stToIO (indexM ary0 k) >>= f
          stToIO (write mary k y)
          fill mary (k + 1)
  in do
    mary0 <- stToIO (new_ len)
    filled <- fill mary0 0
    stToIO (unsafeFreeze filled)
{-# INLINE traverseIO #-}
568
569
-- Why don't we have similar RULES for traverse'? The efficient
-- way to traverse strictly in IO or ST is to force results as
-- they come in, which leads to different semantics. In particular,
-- we need to ensure that
--
--  traverse' (\x -> print x *> pure undefined) xs
--
-- will actually print all the values and then return undefined.
-- We could add a strict mapMWithIndex, operating in an arbitrary
-- Monad, that supported such rules, but we don't have that right now.
--
-- These rules replace the generic applicative traverse with the direct
-- versions when the applicative is ST or IO.  traverse is marked
-- INLINE [1]; NOTE(review): presumably this is what keeps it
-- un-inlined long enough for the rules to match.
{-# RULES
"traverse/ST" forall f. traverse f = traverseST f
"traverse/IO" forall f. traverse f = traverseIO f
 #-}
584