1-- |
2-- Module      : Basement.Sized.Block
3-- License     : BSD-style
4-- Maintainer  : Haskell Foundation
5--
6-- A Nat-sized version of Block
7{-# LANGUAGE AllowAmbiguousTypes        #-}
8{-# LANGUAGE CPP                        #-}
9{-# LANGUAGE ConstraintKinds            #-}
10{-# LANGUAGE DataKinds                  #-}
11{-# LANGUAGE GeneralizedNewtypeDeriving #-}
12{-# LANGUAGE ScopedTypeVariables        #-}
13{-# LANGUAGE TypeApplications           #-}
14{-# LANGUAGE TypeOperators              #-}
15#if __GLASGOW_HASKELL__ >= 806
16{-# LANGUAGE NoStarIsType               #-}
17#endif
18
19module Basement.Sized.Block
20    ( BlockN
21    , MutableBlockN
22    , length
23    , lengthBytes
24    , toBlockN
25    , toBlock
26    , new
27    , newPinned
28    , singleton
29    , replicate
30    , thaw
31    , freeze
32    , index
33    , indexStatic
34    , map
35    , foldl'
36    , foldr
37    , cons
38    , snoc
39    , elem
40    , sub
41    , uncons
42    , unsnoc
43    , splitAt
44    , all
45    , any
46    , find
47    , reverse
48    , sortBy
49    , intersperse
50    , withPtr
51    , withMutablePtr
52    , withMutablePtrHint
53    , cast
54    , mutableCast
55    ) where
56
57import           Data.Proxy (Proxy(..))
58import           Basement.Compat.Base
59import           Basement.Numerical.Additive (scale)
60import           Basement.Block (Block, MutableBlock(..), unsafeIndex)
61import qualified Basement.Block as B
62import qualified Basement.Block.Base as B
63import           Basement.Monad (PrimMonad, PrimState)
64import           Basement.Nat
65import           Basement.Types.OffsetSize
66import           Basement.NormalForm
67import           Basement.PrimType (PrimType, PrimSize, primSizeInBytes)
68
-- | A 'Block' whose element count is tracked at the type level by @n@.
--
-- The wrapped 'Block' is expected to hold exactly @n@ elements; values are
-- built through the functions of this module (e.g. 'toBlockN') rather than
-- the raw constructor.
newtype BlockN (n :: Nat) ty = BlockN
    { unBlock :: Block ty
    }
    deriving (NormalForm, Eq, Show, Data, Ord)
73
-- | A 'MutableBlock' whose element count is tracked at the type level by @n@.
--
-- @st@ is the state token of the underlying 'MutableBlock' (instantiated as
-- @PrimState prim@ by the functions in this module).
newtype MutableBlockN (n :: Nat) ty st = MutableBlockN
    { unMBlock :: MutableBlock ty st
    }
75
-- | View an unsized 'Block' as a 'BlockN' of exactly @n@ elements.
--
-- Returns 'Nothing' when the runtime length of the block differs from the
-- type-level size @n@.
toBlockN :: forall n ty . (PrimType ty, KnownNat n, Countable ty n) => Block ty -> Maybe (BlockN n ty)
toBlockN blk
    | sizeMatches = Just (BlockN blk)
    | otherwise   = Nothing
  where
    sizeMatches = B.length blk == toCount @n
82
-- | Number of elements of a sized block.
--
-- The answer comes from the type-level size @n@; the block's contents are
-- never inspected.
length :: forall n ty
        . (KnownNat n, Countable ty n)
       => BlockN n ty
       -> CountOf ty
length (BlockN _) = toCount @n
88
-- | Size in bytes of the underlying storage, as reported by 'B.lengthBytes'.
lengthBytes :: forall n ty
             . PrimType ty
            => BlockN n ty
            -> CountOf Word8
lengthBytes (BlockN blk) = B.lengthBytes blk
94
-- | Forget the type-level size and return the plain 'Block'.
toBlock :: BlockN n ty -> Block ty
toBlock (BlockN blk) = blk
97
-- | Reinterpret a sized block of @a@ as a sized block of @b@.
--
-- The equality constraint requires both sides to span the same number of
-- bytes (@PrimSize b * m == PrimSize a * n@). The conversion itself is
-- delegated to 'B.unsafeCast'.
cast :: forall n m a b
      . ( PrimType a, PrimType b
        , KnownNat n, KnownNat m
        , ((PrimSize b) * m) ~ ((PrimSize a) * n)
        )
      => BlockN n a
      -> BlockN m b
cast = BlockN . B.unsafeCast . unBlock
106
-- | Reinterpret a sized mutable block of @a@ as a sized mutable block of @b@.
--
-- As with 'cast', both sides must span the same number of bytes. The
-- conversion is delegated to 'B.unsafeRecast'.
mutableCast :: forall n m a b st
             . ( PrimType a, PrimType b
             , KnownNat n, KnownNat m
             , ((PrimSize b) * m) ~ ((PrimSize a) * n)
             )
            => MutableBlockN n a st
            -> MutableBlockN m b st
mutableCast mb = MutableBlockN (B.unsafeRecast (unMBlock mb))
115
-- | Allocate a new unpinned mutable block holding @n@ elements of type @ty@.
--
-- If the size exceeds a GHC-defined threshold, the memory will be pinned
-- anyway. To be certain of the pinning status for small sizes, use
-- 'newPinned'.
new :: forall n ty prim
     . (PrimType ty, KnownNat n, Countable ty n, PrimMonad prim)
    => prim (MutableBlockN n ty (PrimState prim))
new = MutableBlockN <$> B.new count
  where
    count = toCount @n
124
-- | Allocate a new pinned mutable block holding @n@ elements of type @ty@.
newPinned :: forall n ty prim
           . (PrimType ty, KnownNat n, Countable ty n, PrimMonad prim)
          => prim (MutableBlockN n ty (PrimState prim))
newPinned = MutableBlockN <$> B.newPinned elems
  where
    elems = toCount @n
130
-- | A sized block containing exactly one element.
singleton :: PrimType ty => ty -> BlockN 1 ty
singleton = BlockN . B.singleton
133
-- | A sized block of @n@ copies of the given element.
replicate :: forall n ty . (KnownNat n, Countable ty n, PrimType ty) => ty -> BlockN n ty
replicate x = BlockN $ B.replicate (toCount @n) x
136
-- | Turn an immutable sized block into a 'MutableBlockN' of the same size.
--
-- The work is delegated to 'B.thaw' on the underlying 'Block'; the
-- type-level size is carried over unchanged.
--
-- No @KnownNat n@ constraint is required, because the size is never
-- reflected to a runtime value here. The explicit @forall@ fixes the
-- type-application order as @n@, @prim@, @ty@.
thaw :: forall n prim ty . (PrimMonad prim, PrimType ty)
     => BlockN n ty
     -> prim (MutableBlockN n ty (PrimState prim))
thaw (BlockN blk) = MutableBlockN <$> B.thaw blk
139
-- | Turn a sized mutable block into an immutable 'BlockN' of the same size.
--
-- The work is delegated to 'B.freeze' on the underlying 'MutableBlock'.
-- No @Countable ty n@ constraint is required, because no element count is
-- computed here. The type-application order stays @prim@, @ty@, @n@.
freeze :: forall prim ty n . (PrimMonad prim, PrimType ty)
       => MutableBlockN n ty (PrimState prim)
       -> prim (BlockN n ty)
freeze (MutableBlockN mb) = BlockN <$> B.freeze mb
142
-- | Read the element at the type-level offset @i@.
--
-- @CmpNat i n ~ 'LT@ guarantees statically that @i < n@, which is why the
-- unchecked 'unsafeIndex' is used.
indexStatic :: forall i n ty . (KnownNat i, CmpNat i n ~ 'LT, PrimType ty, Offsetable ty i) => BlockN n ty -> ty
indexStatic (BlockN blk) = unsafeIndex blk (toOffset @i)
145
-- | Read the element at a runtime offset by delegating to 'B.index'.
--
-- NOTE(review): the type variable @i@ appears only in the @forall@ and is
-- otherwise unused. It is kept so that existing type applications
-- (@index \@i \@n \@ty@) keep compiling.
-- Presumably 'B.index' bounds-checks the offset, unlike 'unsafeIndex';
-- confirm against "Basement.Block".
index :: forall i n ty . PrimType ty => BlockN n ty -> Offset ty -> ty
index (BlockN blk) = B.index blk
148
-- | Apply a function to every element; the size @n@ is preserved.
map :: (PrimType a, PrimType b) => (a -> b) -> BlockN n a -> BlockN n b
map f = BlockN . B.map f . unBlock
151
-- | Left fold over the elements, delegating to 'B.foldl''.
foldl' :: PrimType ty => (a -> ty -> a) -> a -> BlockN n ty -> a
foldl' step z (BlockN blk) = B.foldl' step z blk
154
-- | Right fold over the elements, delegating to 'B.foldr'.
foldr :: PrimType ty => (ty -> a -> a) -> a -> BlockN n ty -> a
foldr step z = B.foldr step z . unBlock
157
-- | Prepend an element, growing the type-level size by one.
cons :: PrimType ty => ty -> BlockN n ty -> BlockN (n+1) ty
cons x b = BlockN (B.cons x (unBlock b))
160
-- | Append an element, growing the type-level size by one.
snoc :: PrimType ty => BlockN n ty -> ty -> BlockN (n+1) ty
snoc (BlockN blk) x = BlockN (B.snoc blk x)
163
-- | Slice the block from type-level offset @i@ to type-level offset @j@.
--
-- The constraints ensure @i <= j <= n@ statically; the result has
-- @j - i@ elements.
sub :: forall i j n ty
     . ( (i <=? n) ~ 'True
       , (j <=? n) ~ 'True
       , (i <=? j) ~ 'True
       , PrimType ty
       , KnownNat i
       , KnownNat j
       , Offsetable ty i
       , Offsetable ty j )
    => BlockN n ty
    -> BlockN (j-i) ty
sub (BlockN blk) = BlockN (B.sub blk from to)
  where
    from = toOffset @i
    to   = toOffset @j
176
-- | Split a non-empty sized block into its first element and the remaining
-- @n - 1@ elements.
uncons :: forall n ty . (CmpNat 0 n ~ 'LT, PrimType ty, KnownNat n, Offsetable ty n)
       => BlockN n ty
       -> (ty, BlockN (n-1) ty)
uncons b = (hd, BlockN rest)
  where
    hd   = indexStatic @0 b
    rest = B.sub (unBlock b) 1 (toOffset @n)
181
-- | Split a non-empty sized block into its first @n - 1@ elements and its
-- last element.
unsnoc :: forall n ty . (CmpNat 0 n ~ 'LT, KnownNat n, PrimType ty, Offsetable ty n)
       => BlockN n ty
       -> (BlockN (n-1) ty, ty)
unsnoc b = (BlockN (B.sub blk 0 lastOfs), unsafeIndex blk lastOfs)
  where
    blk     = unBlock b
    -- offset of the final element; in range because 0 < n
    lastOfs = toOffset @n `offsetSub` 1
188
-- | Split a sized block after the first @i@ elements, giving blocks of
-- @i@ and @n - i@ elements.
splitAt :: forall i n ty . (CmpNat i n ~ 'LT, PrimType ty, KnownNat i, Countable ty i) => BlockN n ty -> (BlockN i ty, BlockN (n-i) ty)
splitAt (BlockN blk) = (BlockN front, BlockN back)
  where
    -- lazy pattern binding, so the split is only performed on demand
    (front, back) = B.splitAt (toCount @i) blk
193
-- | Whether the given element occurs in the block.
elem :: PrimType ty => ty -> BlockN n ty -> Bool
elem x = B.elem x . unBlock
196
-- | Whether every element satisfies the predicate.
all :: PrimType ty => (ty -> Bool) -> BlockN n ty -> Bool
all predicate (BlockN blk) = B.all predicate blk
199
-- | Whether at least one element satisfies the predicate.
any :: PrimType ty => (ty -> Bool) -> BlockN n ty -> Bool
any predicate = B.any predicate . unBlock
202
-- | An element satisfying the predicate, if any, as found by 'B.find'.
find :: PrimType ty => (ty -> Bool) -> BlockN n ty -> Maybe ty
find predicate (BlockN blk) = B.find predicate blk
205
-- | Reverse the order of the elements; the size @n@ is preserved.
reverse :: PrimType ty => BlockN n ty -> BlockN n ty
reverse (BlockN blk) = BlockN (B.reverse blk)
208
-- | Sort the elements with the given comparison; the size @n@ is preserved.
sortBy :: PrimType ty => (ty -> ty -> Ordering) -> BlockN n ty -> BlockN n ty
sortBy cmp = BlockN . B.sortBy cmp . unBlock
211
-- | Insert a separator between consecutive elements.
--
-- Requires @n > 1@; the result has @n + n - 1@ elements.
intersperse :: (CmpNat n 1 ~ 'GT, PrimType ty) => ty -> BlockN n ty -> BlockN (n+n-1) ty
intersperse sep = BlockN . B.intersperse sep . unBlock
214
-- | Reflect the type-level natural @n@ as an element count.
toCount :: forall n ty . (KnownNat n, Countable ty n) => CountOf ty
toCount = natValCountOf (Proxy :: Proxy n)
217
-- | Reflect the type-level natural @n@ as an element offset.
toOffset :: forall n ty . (KnownNat n, Offsetable ty n) => Offset ty
toOffset = natValOffset (Proxy :: Proxy n)
220
-- | Get a 'Ptr' pointing to the data in the block and pass it to the given
-- action.
--
-- Since a 'BlockN' is immutable, this 'Ptr' must not be used to modify the
-- contents.
--
-- If the block is pinned, its address is used as is; if it is unpinned,
-- a pinned copy of the block is made before getting the address.
--
-- No @KnownNat n@ constraint is required, because the work is delegated to
-- 'B.withPtr' on the underlying 'Block'.
withPtr :: forall prim n ty a . PrimMonad prim
        => BlockN n ty
        -> (Ptr ty -> prim a)
        -> prim a
withPtr (BlockN blk) = B.withPtr blk
234
-- | Create a pointer to the beginning of the mutable block and call the
-- function @f@ with it.
--
-- The mutable block can be mutated by @f@, and the changes will be
-- reflected in the mutable block.
--
-- If the mutable block is unpinned, a trampoline buffer is created and the
-- data is only copied back when @f@ returns. All in all this is highly
-- inefficient in that case, as it causes two copies.
--
-- No @KnownNat n@ constraint is required, because the work is delegated to
-- 'B.withMutablePtr' on the underlying 'MutableBlock'.
withMutablePtr :: forall prim n ty a . PrimMonad prim
               => MutableBlockN n ty (PrimState prim)
               -> (Ptr ty -> prim a)
               -> prim a
withMutablePtr (MutableBlockN mb) = B.withMutablePtr mb
250
-- | Same as 'withMutablePtr', but allows two optimisations to be requested.
-- They only matter when the mutable block is unpinned and a pinned
-- trampoline is needed to call @f@ safely.
--
-- If @skipCopy@ is 'True', the first copy, which happens before the call to
-- @f@, is skipped. The 'Ptr' then effectively points to uninitialized data
-- in a new mutable block.
--
-- If @skipCopyBack@ is 'True', the second copy, which happens after the call
-- to @f@, is skipped. When a trampoline is used, memory changed by @f@ will
-- then not be reflected in the original mutable block.
--
-- Passing the wrong hints leads to corrupted buffers that are hard to debug,
-- because they only show up for mutable blocks that happened to be
-- allocated unpinned.
--
-- If unsure, use 'withMutablePtr', which by default does *not* skip any
-- copy.
--
-- No @KnownNat n@ constraint is required, because the work is delegated to
-- 'B.withMutablePtrHint' on the underlying 'MutableBlock'. The explicit
-- @forall@ keeps the type-application order @n@, @ty@, @prim@, @a@.
withMutablePtrHint :: forall n ty prim a . PrimMonad prim
                   => Bool -- ^ hint that the buffer doesn't need to have the same value as the mutable block when calling f
                   -> Bool -- ^ hint that the buffer is not supposed to be modified by call of f
                   -> MutableBlockN n ty (PrimState prim)
                   -> (Ptr ty -> prim a)
                   -> prim a
withMutablePtrHint skipCopy skipCopyBack (MutableBlockN mb) f =
    B.withMutablePtrHint skipCopy skipCopyBack mb f
279