1-- |
2-- Module      : Basement.Block
3-- License     : BSD-style
4-- Maintainer  : Haskell Foundation
5--
6-- A block of memory that contains elements of a type,
7-- very similar to an unboxed array but with the key difference:
8--
9-- * It doesn't have slicing capability (no cheap take or drop)
-- * It consumes less memory: it saves 1 Offset and 1 CountOf
11-- * It's unpackable in any constructor
12-- * It uses unpinned memory by default
13--
14{-# LANGUAGE MagicHash           #-}
15{-# LANGUAGE ScopedTypeVariables #-}
16{-# LANGUAGE UnboxedTuples       #-}
17{-# LANGUAGE MultiParamTypeClasses #-}
18{-# LANGUAGE FlexibleInstances #-}
19module Basement.Block
20    ( Block(..)
21    , MutableBlock(..)
22    -- * Properties
23    , length
24    -- * Lowlevel functions
25    , unsafeThaw
26    , unsafeFreeze
27    , unsafeIndex
28    , thaw
29    , freeze
30    , copy
31    , unsafeCast
32    , cast
33    -- * safer api
34    , empty
35    , create
36    , isPinned
37    , isMutablePinned
38    , singleton
39    , replicate
40    , index
41    , map
42    , foldl'
43    , foldr
44    , foldl1'
45    , foldr1
46    , cons
47    , snoc
48    , uncons
49    , unsnoc
50    , sub
51    , splitAt
52    , revSplitAt
53    , splitOn
54    , break
55    , breakEnd
56    , span
57    , elem
58    , all
59    , any
60    , find
61    , filter
62    , reverse
63    , sortBy
64    , intersperse
65    -- * Foreign interfaces
66    , createFromPtr
67    , unsafeCopyToPtr
68    , withPtr
69    ) where
70
71import           GHC.Prim
72import           GHC.Types
73import           GHC.ST
74import qualified Data.List
75import           Basement.Compat.Base
76import           Data.Proxy
77import           Basement.Compat.Primitive
78import           Basement.NonEmpty
79import           Basement.Types.OffsetSize
80import           Basement.Monad
81import           Basement.Exception
82import           Basement.PrimType
83import qualified Basement.Block.Mutable as M
84import           Basement.Block.Mutable (Block(..), MutableBlock(..), new, unsafeThaw, unsafeFreeze)
85import           Basement.Block.Base
86import           Basement.Numerical.Additive
87import           Basement.Numerical.Subtractive
88import           Basement.Numerical.Multiplicative
89import qualified Basement.Alg.Mutable as MutAlg
90import qualified Basement.Alg.Class as Alg
91import qualified Basement.Alg.PrimArray as Alg
92
-- | Random access on a 'MutableBlock': reads and writes unwrap the block
-- and delegate to 'primMbaRead' / 'primMbaWrite' on the underlying
-- mutable byte array.
instance (PrimMonad prim, st ~ PrimState prim, PrimType ty)
         => Alg.RandomAccess (MutableBlock ty st) prim ty where
    read mblk = case mblk of
        MutableBlock mba -> primMbaRead mba
    write mblk = case mblk of
        MutableBlock mba -> primMbaWrite mba
97
-- | Indexing a 'Block' of @ty@ yields elements of @ty@, by delegating to
-- 'primBaIndex' on the underlying byte array.
instance (PrimType ty) => Alg.Indexable (Block ty) ty where
    index blk = case blk of
        Block ba -> primBaIndex ba
    {-# INLINE index #-}
101
-- | Read 'Word64' values out of a block of 'Word8', by delegating to
-- 'primBaIndex' at the 'Word64' type.
--
-- NOTE(review): the offset presumably counts 'Word64'-sized cells rather
-- than bytes -- confirm against 'primBaIndex'.
instance Alg.Indexable (Block Word8) Word64 where
    index blk = case blk of
        Block ba -> primBaIndex ba
    {-# INLINE index #-}
105
-- | Copy the whole content of a block to the memory starting at the given
-- destination address.
--
-- Exactly @sizeofByteArray#@ bytes of the source are copied, starting at
-- byte 0. No check is made on the destination, hence the @unsafe@ prefix.
unsafeCopyToPtr :: forall ty prim . PrimMonad prim
                => Block ty -- ^ the source block to copy
                -> Ptr ty   -- ^ The destination address where the copy is going to start
                -> prim ()
unsafeCopyToPtr (Block src) (Ptr dst) = primitive copyAll
  where
    -- thread the state token through the primop and return unit
    copyAll st =
        case copyByteArrayToAddr# src 0# dst (sizeofByteArray# src) st of
            st' -> (# st', () #)
113
-- | Create a new block of @n@ elements, setting each cell through the
-- @initializer@ function.
--
-- A size of 0 yields 'mempty' without allocating through 'new'.
create :: forall ty . PrimType ty
       => CountOf ty           -- ^ the size of the block (in element of ty)
       -> (Offset ty -> ty) -- ^ the function that set the value at the index
       -> Block ty          -- ^ the array created
create n initializer =
    if n == 0
        then mempty
        else runST (new n >>= \fresh ->
                        M.iterSet initializer fresh >> unsafeFreeze fresh)
126
-- | Build a new 'Block' from a chunk of memory: @sz@ elements are copied
-- from the pointer into a freshly allocated block, which is then frozen.
createFromPtr :: PrimType ty
              => Ptr ty
              -> CountOf ty
              -> IO (Block ty)
createFromPtr p sz =
    new sz >>= \dst ->
        M.copyFromPtr p dst 0 sz >> unsafeFreeze dst
136
-- | Create a block holding exactly one element.
singleton :: PrimType ty => ty -> Block ty
singleton ty = create 1 (\_ -> ty)
139
-- | Create a block of the given size where every cell holds the same value.
replicate :: PrimType ty => CountOf ty -> ty -> Block ty
replicate sz ty = create sz (\_ -> ty)
142
-- | Thaw a 'Block' into a 'MutableBlock'.
--
-- The source block is left untouched: a new unpinned mutable block of the
-- same byte length is allocated and the content is copied into it.
thaw :: (PrimMonad prim, PrimType ty) => Block ty -> prim (MutableBlock ty (PrimState prim))
thaw array =
    M.unsafeNew Unpinned nbBytes >>= \fresh ->
        M.unsafeCopyBytesRO fresh 0 array 0 nbBytes >> pure fresh
  where
    nbBytes = lengthBytes array
{-# INLINE thaw #-}
153
-- | Freeze a 'MutableBlock' into a 'Block', copying all the data.
--
-- The bytes are copied into a fresh unpinned block before freezing, so
-- later modifications of the mutable block do not affect the returned
-- immutable 'Block'.
freeze :: (PrimType ty, PrimMonad prim) => MutableBlock ty (PrimState prim) -> prim (Block ty)
freeze ma =
    unsafeNew Unpinned nbBytes >>= \fresh ->
        M.unsafeCopyBytes fresh 0 ma 0 nbBytes >> unsafeFreeze fresh
  where
    nbBytes = M.mutableLengthBytes ma
166
-- | Copy every cell of an existing 'Block' into a new 'Block'.
copy :: PrimType ty => Block ty -> Block ty
copy array = runST $ do
    duplicate <- thaw array
    unsafeFreeze duplicate
170
-- | Return the element at a specific index of a block.
--
-- If the index @n@ is out of bounds, 'outOfBound' is called with
-- 'OOB_Index', the offending index and the block length.
index :: PrimType ty => Block ty -> Offset ty -> ty
index array n =
    if isOutOfBound n nbElems
        then outOfBound OOB_Index n nbElems
        else unsafeIndex array n
  where
    !nbElems = length array
{-# INLINE index #-}
181
-- | Map every element of type @a@ of a block into a new block of @b@.
--
-- The result has as many elements as the input; the count and offsets are
-- only re-tagged from @a@ to @b@ with 'sizeCast' / 'offsetCast'.
map :: (PrimType a, PrimType b) => (a -> b) -> Block a -> Block b
map f a = create lenB transform
  where
    !lenB = sizeCast (Proxy :: Proxy (a -> b)) (length a)
    -- read the source cell at the same position and apply the function
    transform i = f (unsafeIndex a (offsetCast Proxy i))
186
-- | Lazy right fold over the elements of a block, from the first element
-- to the last.
foldr :: PrimType ty => (ty -> a -> a) -> a -> Block ty -> a
foldr f initialAcc vec = go 0
  where
    !nbElems = length vec
    go !ofs =
        if ofs .==# nbElems
            then initialAcc
            else f (unsafeIndex vec ofs) (go (ofs + 1))
{-# SPECIALIZE [2] foldr :: (Word8 -> a -> a) -> a -> Block Word8 -> a #-}
195
-- | Strict left fold over the elements of a block; the accumulator is
-- forced at every step.
foldl' :: PrimType ty => (a -> ty -> a) -> a -> Block ty -> a
foldl' f initialAcc vec = go 0 initialAcc
  where
    !nbElems = length vec
    go !ofs !acc =
        if ofs .==# nbElems
            then acc
            else go (ofs + 1) (f acc (unsafeIndex vec ofs))
{-# SPECIALIZE [2] foldl' :: (a -> Word8 -> a) -> a -> Block Word8 -> a #-}
204
-- | Strict left fold over a non-empty block, using the first element as
-- the initial accumulator.
foldl1' :: PrimType ty => (ty -> ty -> ty) -> NonEmpty (Block ty) -> ty
foldl1' f (NonEmpty arr) = go 1 (unsafeIndex arr 0)
  where
    !nbElems = length arr
    go !ofs !acc =
        if ofs .==# nbElems
            then acc
            else go (ofs + 1) (f acc (unsafeIndex arr ofs))
{-# SPECIALIZE [3] foldl1' :: (Word8 -> Word8 -> Word8) -> NonEmpty (Block Word8) -> Word8 #-}
213
-- | Lazy right fold over a non-empty block, using the last element as the
-- initial accumulator:
--
-- > foldr1 f [x0, x1, .., xn] == x0 `f` (x1 `f` (.. `f` xn))
--
-- The block is traversed in place by index; no intermediate block is
-- allocated (splitting the block first would copy all of its elements
-- only to read the last one).
foldr1 :: PrimType ty => (ty -> ty -> ty) -> NonEmpty (Block ty) -> ty
foldr1 f (NonEmpty arr) = loop 0
  where
    !len = length arr
    loop !i
        -- @i@ is the last valid offset: that element seeds the fold
        | (i + 1) .==# len = unsafeIndex arr i
        | otherwise        = unsafeIndex arr i `f` loop (i + 1)
217
-- | Prepend an element to a block, producing a new block one element
-- longer.
cons :: PrimType ty => ty -> Block ty -> Block ty
cons e vec =
    if nbElems == 0
        then singleton e
        else runST $ do
            out <- new (nbElems + 1)
            -- existing elements are shifted by one cell, leaving cell 0 free
            M.unsafeCopyElementsRO out 1 vec 0 nbElems
            M.unsafeWrite out 0 e
            unsafeFreeze out
  where
    !nbElems = length vec
228
-- | Append an element to a block, producing a new block one element
-- longer.
snoc :: PrimType ty => Block ty -> ty -> Block ty
snoc vec e =
    if nbElems == 0
        then singleton e
        else runST $ do
            out <- new (nbElems + 1)
            M.unsafeCopyElementsRO out 0 vec 0 nbElems
            -- the new element goes right after the copied ones
            M.unsafeWrite out (0 `offsetPlusE` nbElems) e
            unsafeFreeze out
  where
    !nbElems = length vec
239
-- | Extract the elements between offsets @start@ (inclusive) and @end@
-- (exclusive) into a new block.
--
-- @end@ is clamped to the length of the block; if @start@ is not strictly
-- before the clamped end, the result is 'mempty'.
sub :: PrimType ty => Block ty -> Offset ty -> Offset ty -> Block ty
sub blk start end =
    if start >= stop
        then mempty
        else runST $ do
            out <- new nbElems
            M.unsafeCopyElementsRO out 0 blk start nbElems
            unsafeFreeze out
  where
    !total  = length blk
    stop    = min (sizeAsOffset total) end
    nbElems = stop - start
251
-- | Split a block into its first element and a new block holding the
-- remaining elements, or 'Nothing' if the block is empty.
uncons :: PrimType ty => Block ty -> Maybe (ty, Block ty)
uncons vec =
    if nbElems == 0
        then Nothing
        else Just (firstElem, remaining)
  where
    !nbElems  = length vec
    firstElem = unsafeIndex vec 0
    remaining = sub vec 1 (0 `offsetPlusE` nbElems)
258
-- | Split a block into a new block holding all but the last element, and
-- the last element, or 'Nothing' if the block is empty.
unsnoc :: PrimType ty => Block ty -> Maybe (Block ty, ty)
unsnoc vec =
    case length vec - 1 of
        Just nbInit ->
            -- offset of the last element == number of elements before it
            let !lastOfs = 0 `offsetPlusE` nbInit
             in Just (sub vec 0 lastOfs, unsafeIndex vec lastOfs)
        Nothing -> Nothing
264
-- | Split a block after the first @nbElems@ elements.
--
-- * a count of 0 (or less) returns @(mempty, blk)@ without copying
-- * a count reaching or exceeding the length returns @(blk, mempty)@
--   without copying
-- * otherwise both halves are copied into two newly allocated blocks
splitAt :: PrimType ty => CountOf ty -> Block ty -> (Block ty, Block ty)
splitAt nbElems blk
    | nbElems <= 0 = (mempty, blk)
    | otherwise    =
        case length blk - nbElems of
            Just nbTails | nbTails > 0 -> runST $ do
                front <- new nbElems
                back  <- new nbTails
                M.unsafeCopyElementsRO front 0 blk 0                      nbElems
                M.unsafeCopyElementsRO back  0 blk (sizeAsOffset nbElems) nbTails
                (,) <$> unsafeFreeze front <*> unsafeFreeze back
            _ -> (blk, mempty)
{-# SPECIALIZE [2] splitAt :: CountOf Word8 -> Block Word8 -> (Block Word8, Block Word8) #-}
276
-- | Split a block counting @n@ elements from the end: the first component
-- holds the last @n@ elements, the second component everything before.
--
-- A count of 0 (or less) returns @(mempty, blk)@; a count exceeding the
-- length returns @(blk, mempty)@.
revSplitAt :: PrimType ty => CountOf ty -> Block ty -> (Block ty, Block ty)
revSplitAt n blk
    | n <= 0    = (mempty, blk)
    | otherwise =
        case length blk - n of
            Just nbElems ->
                let (front, back) = splitAt nbElems blk
                 in (back, front)
            Nothing -> (blk, mempty)
282
-- | Split a block at the first element satisfying the predicate.
--
-- The matching element is the first element of the second block. If no
-- element matches, the result is @(blk, mempty)@.
break :: PrimType ty => (ty -> Bool) -> Block ty -> (Block ty, Block ty)
break predicate blk = scan 0
  where
    !nbElems = length blk
    scan !ofs =
        if ofs .==# nbElems
            then (blk, mempty)
            else if predicate (unsafeIndex blk ofs)
                then splitAt (offsetAsSize ofs) blk
                else scan (ofs + 1)
    {-# INLINE scan #-}
{-# SPECIALIZE [2] break :: (Word8 -> Bool) -> Block Word8 -> (Block Word8, Block Word8) #-}
293
-- | Split a block just after the element located by
-- 'Alg.revFindIndexPredicate' for the given predicate.
--
-- If that search returns 'sentinel' (no element found), the result is
-- @(blk, mempty)@.
breakEnd :: PrimType ty => (ty -> Bool) -> Block ty -> (Block ty, Block ty)
breakEnd predicate blk =
    if found == sentinel
        then (blk, mempty)
        else splitAt (offsetAsSize (found + 1)) blk
  where
    !stop  = sizeAsOffset (length blk)
    !found = Alg.revFindIndexPredicate predicate blk 0 stop
{-# SPECIALIZE [2] breakEnd :: (Word8 -> Bool) -> Block Word8 -> (Block Word8, Block Word8) #-}
302
-- | Split a block at the first element that does NOT satisfy the
-- predicate; this is 'break' with the predicate negated.
span :: PrimType ty => (ty -> Bool) -> Block ty -> (Block ty, Block ty)
span p blk = break (\e -> not (p e)) blk
305
-- | Check whether a value is equal to at least one element of the block.
-- The scan stops at the first match.
elem :: PrimType ty => ty -> Block ty -> Bool
elem v blk = go 0
  where
    !nbElems = length blk
    go !ofs =
        if ofs .==# nbElems
            then False
            else unsafeIndex blk ofs == v || go (ofs + 1)
{-# SPECIALIZE [2] elem :: Word8 -> Block Word8 -> Bool #-}
315
-- | Check whether every element of the block satisfies the predicate
-- ('True' for an empty block). The scan stops at the first failure.
all :: PrimType ty => (ty -> Bool) -> Block ty -> Bool
all p blk = go 0
  where
    !nbElems = length blk
    go !ofs =
        if ofs .==# nbElems
            then True
            else if p (unsafeIndex blk ofs)
                then go (ofs + 1)
                else False
{-# SPECIALIZE [2] all :: (Word8 -> Bool) -> Block Word8 -> Bool #-}
325
-- | Check whether at least one element of the block satisfies the
-- predicate ('False' for an empty block). The scan stops at the first
-- match.
any :: PrimType ty => (ty -> Bool) -> Block ty -> Bool
any p blk = go 0
  where
    !nbElems = length blk
    go !ofs =
        if ofs .==# nbElems
            then False
            else p (unsafeIndex blk ofs) || go (ofs + 1)
{-# SPECIALIZE [2] any :: (Word8 -> Bool) -> Block Word8 -> Bool #-}
335
-- | Split a block into pieces separated by the elements satisfying the
-- predicate. The separators themselves are dropped.
--
-- An empty block yields @[mempty]@; a separator in last position yields a
-- trailing empty piece.
splitOn :: PrimType ty => (ty -> Bool) -> Block ty -> [Block ty]
splitOn predicate blk
    | nbElems == 0 = [mempty]
    | otherwise    = walk 0 0
  where
    !nbElems = length blk
    -- @from@ is the start of the current piece, @cur@ the scanning position
    walk !from !cur
        | cur .==# nbElems                = [sub blk from cur]
        | predicate (unsafeIndex blk cur) = sub blk from cur : walk next next
        | otherwise                       = walk from next
      where
        next = cur + 1
350
-- | Return the first element of the block satisfying the predicate, or
-- 'Nothing' if there is none.
find :: PrimType ty => (ty -> Bool) -> Block ty -> Maybe ty
find predicate vec = search 0
  where
    !nbElems = length vec
    search ofs
        | ofs .==# nbElems = Nothing
        | predicate e      = Just e
        | otherwise        = search (ofs + 1)
      where
        e = unsafeIndex vec ofs
360
-- | Keep only the elements of the block satisfying the predicate.
--
-- The block is converted to a list, filtered with 'Data.List.filter', and
-- converted back to a block.
filter :: PrimType ty => (ty -> Bool) -> Block ty -> Block ty
filter predicate vec = fromList kept
  where
    kept = Data.List.filter predicate (toList vec)
363
-- | Return a new block with the elements in reverse order.
reverse :: forall ty . PrimType ty => Block ty -> Block ty
reverse blk =
    if len == 0
        then mempty
        else runST $ do
            out <- new len
            fill out
            unsafeFreeze out
  where
    !len = length blk
    !endOfs = 0 `offsetPlusE` len

    -- Read the source front to back while writing the destination back to
    -- front.
    fill :: MutableBlock ty s -> ST s ()
    fill out = step endOfs 0
      where
        -- @dstAfter@ is one past the next destination cell to write
        step dstAfter srcOfs
            | srcOfs .==# len = pure ()
            | otherwise       = do
                unsafeWrite out dst (unsafeIndex blk srcOfs)
                step dst (srcOfs + 1)
          where dst = pred dstAfter
382
-- | Return a sorted copy of the block, ordered by the given comparison
-- function.
--
-- The block is copied with 'thaw', sorted in place with
-- 'MutAlg.inplaceSortBy', then frozen; the input block is not modified.
sortBy :: PrimType ty => (ty -> ty -> Ordering) -> Block ty -> Block ty
sortBy ford vec =
    if nbElems == 0
        then mempty
        else runST $ do
            scratch <- thaw vec
            MutAlg.inplaceSortBy ford 0 nbElems scratch
            unsafeFreeze scratch
  where
    nbElems = length vec
{-# SPECIALIZE [2] sortBy :: (Word8 -> Word8 -> Ordering) -> Block Word8 -> Block Word8 #-}
392
-- | Insert a separator between every pair of consecutive elements.
--
-- Blocks of zero or one element are returned unchanged. Otherwise the
-- result holds @len + (len - 1)@ elements.
intersperse :: forall ty . PrimType ty => ty -> Block ty -> Block ty
intersperse sep blk =
    case len - 1 of
        Nothing -> blk
        Just nbSeps
            | nbSeps == 0 -> blk
            | otherwise   -> runST $ do
                out <- new (len + nbSeps)
                fill out
                unsafeFreeze out
  where
    !len = length blk

    fill :: MutableBlock ty s -> ST s ()
    fill out = emit 0 0
      where
        -- @dst@ walks the destination, @src@ the source
        emit !dst !src = do
            unsafeWrite out dst (unsafeIndex blk src)
            if (src + 1) .==# len
                -- last source element: no separator after it
                then pure ()
                else unsafeWrite out (dst + 1) sep >> emit (dst + 2) (src + 1)
413
-- | Unsafely recast a 'Block' containing @a@ to a 'Block' containing @b@.
--
-- The underlying byte array is reused as-is, only the element type tag
-- changes. No check is performed to make sure the content is compatible
-- with @b@.
--
-- use 'cast' if unsure.
unsafeCast :: PrimType b => Block a -> Block b
unsafeCast blk = case blk of
    Block ba -> Block ba
422
-- | Cast a 'Block' of @a@ to a 'Block' of @b@.
--
-- The cast is accepted when any of the following holds:
--
-- * @a@ and @b@ have the same size in bytes
-- * @b@ has a size of 1 byte
-- * the length of the block in bytes is a multiple of the size of @b@
--
-- If none of these holds, the 'InvalidRecast' exception is thrown.
cast :: forall a b . (PrimType a, PrimType b) => Block a -> Block b
cast blk@(Block _)
    | compatible = unsafeCast blk
    | otherwise  =
        throw $ InvalidRecast (RecastSourceSize srcBytes) (RecastDestinationSize $ srcBytes + missing)
  where
    (CountOf srcBytes) = lengthBytes blk

    aTypeSize = primSizeInBytes (Proxy :: Proxy a)
    bTypeSize@(CountOf bs) = primSizeInBytes (Proxy :: Proxy b)

    -- bytes left over once the content is cut in cells of size @b@
    missing = srcBytes `mod` bs

    -- checked left to right; @missing@ is only computed when the cheaper
    -- size comparisons fail
    compatible = aTypeSize == bTypeSize || bTypeSize == 1 || missing == 0
442