1{-# LANGUAGE BangPatterns #-}
2{-# LANGUAGE CPP #-}
3{-# LANGUAGE MagicHash #-}
4{-# LANGUAGE ScopedTypeVariables #-}
5{-# LANGUAGE TypeFamilies #-}
6{-# LANGUAGE UnboxedTuples #-}
7
8{-# OPTIONS_GHC -Wall #-}
9
10module PrimLawsWIP
11  ( primLaws
12  ) where
13
14import Control.Applicative
15import Control.Monad.Primitive (PrimMonad, PrimState,primitive,primitive_)
16import Control.Monad.ST
17import Data.Proxy (Proxy)
18import Data.Primitive.ByteArray
19import Data.Primitive.Types
20import Data.Primitive.Ptr
21import Foreign.Marshal.Alloc
22import GHC.Exts
23  (State#,Int#,Addr#,Int(I#),(*#),(+#),(<#),newByteArray#,unsafeFreezeByteArray#,
24   copyMutableByteArray#,copyByteArray#,quotInt#,sizeofByteArray#)
25
26#if MIN_VERSION_base(4,7,0)
27import GHC.Exts (IsList(fromList,toList,fromListN),Item,
28  copyByteArrayToAddr#,copyAddrToByteArray#)
29#endif
30
31import GHC.Ptr (Ptr(..))
32import System.IO.Unsafe
33import Test.QuickCheck hiding ((.&.))
34import Test.QuickCheck.Property (Property)
35
36import qualified Data.List as L
37import qualified Data.Primitive as P
38
39import Test.QuickCheck.Classes.Common (Laws(..))
40import Test.QuickCheck.Classes.Compat (isTrue#)
41
-- | Build the 'Laws' bundle named @"Prim"@ for the type selected by the proxy.
--
-- The properties are listed in two groups, ByteArray-backed checks first and
-- address-backed checks second. The ByteArray list round-trip check is only
-- included when @base >= 4.7@.
primLaws :: (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Laws
primLaws p = Laws "Prim" (byteArrayLaws ++ addrLaws)
  where
  -- Properties exercising the ByteArray-backed 'Prim' methods.
  byteArrayLaws =
    [ ("ByteArray Put-Get (you get back what you put in)", primPutGetByteArray p)
    , ("ByteArray Get-Put (putting back what you got out has no effect)", primGetPutByteArray p)
    , ("ByteArray Put-Put (putting twice is same as putting once)", primPutPutByteArray p)
    , ("ByteArray Set Range", primSetByteArray p)
#if MIN_VERSION_base(4,7,0)
    , ("ByteArray List Conversion Roundtrips", primListByteArray p)
#endif
    ]
  -- Properties exercising the address-backed 'Prim' methods.
  addrLaws =
    [ ("Addr Put-Get (you get back what you put in)", primPutGetAddr p)
    , ("Addr Get-Put (putting back what you got out has no effect)", primGetPutAddr p)
    , ("Addr Set Range", primSetOffAddr p)
    , ("Addr List Conversion Roundtrips", primListAddr p)
    ]
57
-- | Writing a list element-by-element through a 'Ptr' and reading the same
-- number of elements back must reproduce the original list.
--
-- The scratch buffer is obtained with 'allocaBytesAligned', so it is released
-- even if one of the 'Prim' methods under test throws, and an empty input list
-- (a zero-byte buffer) does not depend on what the platform's @malloc(0)@
-- returns.
primListAddr :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primListAddr _ = property $ \(as :: [a]) -> unsafePerformIO $ do
  let len = L.length as
      nbytes = len * P.sizeOf (undefined :: a)
      align = P.alignment (undefined :: a)
  allocaBytesAligned nbytes align $ \(ptr :: Ptr a) -> do
    -- Store every element at its list position.
    let go :: Int -> [a] -> IO ()
        go !ix xs = case xs of
          [] -> return ()
          (x : xsNext) -> do
            writeOffPtr ptr ix x
            go (ix + 1) xsNext
    go 0 as
    -- Read the elements back in index order.
    let rebuild :: Int -> IO [a]
        rebuild !ix = if ix < len
          then (:) <$> readOffPtr ptr ix <*> rebuild (ix + 1)
          else return []
    asNew <- rebuild 0
    -- Force the comparison while the buffer is still alive.
    return $! as == asNew
76
-- | Put-Get for byte arrays: after writing a value at a random in-bounds
-- index of a freshly allocated array, reading that index yields the value.
primPutGetByteArray :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primPutGetByteArray _ = property $ \(val :: a) size -> (size > 0) ==> do
  pos <- choose (0, size - 1)
  return $ runST $ do
    scratch <- newPrimArray size
    writePrimArray scratch pos val
    -- Compare the stored value (left) with what comes back (right).
    fmap (val ==) (readPrimArray scratch pos)
85
-- | Get-Put for byte arrays: reading an element and writing it straight back
-- to the same index leaves the array equal to the original.
primGetPutByteArray :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primGetPutByteArray _ = property $ \(as :: [a]) -> (not (L.null as)) ==> do
  let n = L.length as
      before = primArrayFromList as :: PrimArray a
  pos <- choose (0, n - 1)
  let after = runST $ do
        -- Work on a private copy so 'before' is never mutated.
        scratch <- newPrimArray n
        copyPrimArray scratch 0 before 0 n
        readPrimArray scratch pos >>= writePrimArray scratch pos
        unsafeFreezePrimArray scratch
  return (before == after)
98
-- | Put-Put for byte arrays: writing the same value to the same index twice
-- gives the same array as writing it once.
primPutPutByteArray :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primPutPutByteArray _ = property $ \(a :: a) (as :: [a]) -> (not (L.null as)) ==> do
  let n = L.length as
      original = primArrayFromList as :: PrimArray a
  pos <- choose (0, n - 1)
  let (once, twice) = runST $ do
        -- First array: copy of the input with a single write.
        first <- newPrimArray n
        copyPrimArray first 0 original 0 n
        writePrimArray first pos a
        -- Second array: copy of the first, then the same write again.
        second <- newPrimArray n
        copyMutablePrimArray second 0 first 0 n
        writePrimArray second pos a
        frozenOnce <- unsafeFreezePrimArray first
        frozenTwice <- unsafeFreezePrimArray second
        return (frozenOnce, frozenTwice)
  return (once == twice)
115
-- | Put-Get for addresses: after writing a value at a random in-bounds offset
-- from a pointer, reading that offset yields the value.
--
-- The buffer comes from 'allocaBytesAligned' rather than a manual
-- @mallocBytes@/@free@ pair, so it is released even when a 'Prim' method
-- under test throws.
primPutGetAddr :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primPutGetAddr _ = property $ \(a :: a) len -> (len > 0) ==> do
  ix <- choose (0,len - 1)
  let nbytes = len * P.sizeOf (undefined :: a)
      align = P.alignment (undefined :: a)
  return $ unsafePerformIO $
    allocaBytesAligned nbytes align $ \(ptr :: Ptr a) -> do
      writeOffPtr ptr ix a
      a' <- readOffPtr ptr ix
      -- Force the comparison while the buffer is still alive.
      return $! a == a'
125
-- | Get-Put for addresses: reading an element through a pointer and writing
-- it straight back to the same offset leaves the buffer contents unchanged.
--
-- The input list is written into a scratch buffer, one randomly chosen
-- element is read and re-written, and the whole buffer is then read back and
-- compared with the input list. The buffer comes from 'allocaBytesAligned',
-- so it is released even if a 'Prim' method under test throws.
primGetPutAddr :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primGetPutAddr _ = property $ \(as :: [a]) -> (not (L.null as)) ==> do
  let len = L.length as
      nbytes = len * P.sizeOf (undefined :: a)
      align = P.alignment (undefined :: a)
  ix <- choose (0,len - 1)
  return $ unsafePerformIO $
    allocaBytesAligned nbytes align $ \(ptr :: Ptr a) -> do
      -- Populate the buffer with the generated elements.
      mapM_ (uncurry (writeOffPtr ptr)) (zip [0 ..] as)
      -- The operation under test: get, then put back what was got.
      a <- readOffPtr ptr ix
      writeOffPtr ptr ix a
      asNew <- mapM (readOffPtr ptr) [0 .. len - 1]
      -- Force the comparison while the buffer is still alive.
      return $! as == asNew
142
-- | Set Range for byte arrays: filling a random sub-range through
-- 'setPrimArray' must give the same array as filling it with the
-- element-by-element loop in 'internalDefaultSetPrimArray'.
primSetByteArray :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primSetByteArray _ = property $ \(as :: [a]) (z :: a) -> do
  let source = primArrayFromList as :: PrimArray a
      n = L.length as
  p <- choose (0, n)
  q <- choose (0, n)
  -- Order the two random bounds into a start offset and an element count.
  let start = min p q
      count = max p q - start
  return $ runST $ do
    viaClass <- newPrimArray n
    copyPrimArray viaClass 0 source 0 n
    viaLoop <- newPrimArray n
    copyPrimArray viaLoop 0 source 0 n
    setPrimArray viaClass start count z
    internalDefaultSetPrimArray viaLoop start count z
    (==) <$> unsafeFreezePrimArray viaClass <*> unsafeFreezePrimArray viaLoop
161
-- | Set Range for addresses: filling a random sub-range through the 'Prim'
-- class method @setOffAddr#@ must leave the buffer with the same contents as
-- filling it with the element-by-element loop in 'internalDefaultSetOffAddr'.
--
-- Two scratch buffers are populated with the same input list, one is filled
-- by each method, and both are read back as lists and compared. The buffers
-- come from 'allocaBytesAligned', so they are released even if a 'Prim'
-- method under test throws.
primSetOffAddr :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primSetOffAddr _ = property $ \(as :: [a]) (z :: a) -> do
  let len = L.length as
      nbytes = len * P.sizeOf (undefined :: a)
      align = P.alignment (undefined :: a)
  x <- choose (0,len)
  y <- choose (0,len)
  let lo = min x y
      hi = max x y
  return $ unsafePerformIO $
    allocaBytesAligned nbytes align $ \(ptrA :: Ptr a) ->
      allocaBytesAligned nbytes align $ \(ptrB :: Ptr a) -> do
        let fill :: Ptr a -> IO ()
            fill ptr = mapM_ (uncurry (writeOffPtr ptr)) (zip [0 ..] as)
            dump :: Ptr a -> IO [a]
            dump ptr = mapM (readOffPtr ptr) [0 .. len - 1]
        fill ptrA
        fill ptrB
        setViaClass ptrA lo (hi - lo) z
        internalDefaultSetOffAddr ptrB lo (hi - lo) z
        asA <- dump ptrA
        asB <- dump ptrB
        -- Force the comparison while both buffers are still alive.
        return $! asA == asB
  where
  -- Call the class method directly so that its offset argument is exercised
  -- with the random start position.
  setViaClass :: Ptr a -> Int -> Int -> a -> IO ()
  setViaClass (Ptr addr#) (I# off#) (I# n#) v =
    primitive_ (P.setOffAddr# addr# off# n# v)
191
-- | Immutable byte array wrapper; the phantom parameter @a@ specifies the
-- element type.
data PrimArray a
  = PrimArray ByteArray#

-- | Mutable byte array wrapper for state thread @s@; the phantom parameter
-- @a@ specifies the element type.
data MutablePrimArray s a
  = MutablePrimArray (MutableByteArray# s)
195
-- | Two arrays are equal when they hold the same number of elements and every
-- pair of elements at the same index is equal. Elements are compared from
-- the last index down to index 0, stopping at the first mismatch.
instance (Eq a, Prim a) => Eq (PrimArray a) where
  lhs == rhs
    | n /= sizeofPrimArray rhs = False
    | otherwise = L.all sameAt [n - 1, n - 2 .. 0]
    where
    n = sizeofPrimArray lhs
    sameAt i = indexPrimArray lhs i == indexPrimArray rhs i
201
#if MIN_VERSION_base(4,7,0)
-- | List conversions for 'PrimArray', delegating to the local helpers
-- 'primArrayFromList', 'primArrayFromListN' and 'primArrayToList'.
instance Prim a => IsList (PrimArray a) where
  type Item (PrimArray a) = a
  fromList xs = primArrayFromList xs
  fromListN n xs = primArrayFromListN n xs
  toList arr = primArrayToList arr
#endif
209
-- | Read the element at the given element index of an immutable array.
-- No bounds check is performed here.
indexPrimArray :: forall a. Prim a => PrimArray a -> Int -> a
indexPrimArray arr ix =
  case arr of
    PrimArray ba# -> case ix of
      I# i# -> indexByteArray# ba# i#
212
-- | Number of elements in the array: the byte size of the underlying byte
-- array divided (truncating) by the byte size of one element.
sizeofPrimArray :: forall a. Prim a => PrimArray a -> Int
sizeofPrimArray (PrimArray ba#) = I# (quotInt# totalBytes# elemBytes#)
  where
  totalBytes# = sizeofByteArray# ba#
  elemBytes# = P.sizeOf# (undefined :: a)
215
-- | Allocate a mutable array with room for the given number of elements.
-- The byte size requested is the element count times the element size.
newPrimArray :: forall m a. (PrimMonad m, Prim a) => Int -> m (MutablePrimArray (PrimState m) a)
newPrimArray (I# n#) = primitive allocate
  where
  allocate s0# =
    case newByteArray# (n# *# sizeOf# (undefined :: a)) s0# of
      (# s1#, marr# #) -> (# s1#, MutablePrimArray marr# #)
222
-- | Read the element at the given element index of a mutable array.
-- No bounds check is performed here.
readPrimArray :: (Prim a, PrimMonad m) => MutablePrimArray (PrimState m) a -> Int -> m a
readPrimArray (MutablePrimArray marr#) (I# ix#) =
  primitive (\st# -> readByteArray# marr# ix# st#)
226
-- | Write an element at the given element index of a mutable array.
-- No bounds check is performed here.
writePrimArray ::
     (Prim a, PrimMonad m)
  => MutablePrimArray (PrimState m) a
  -> Int
  -> a
  -> m ()
writePrimArray (MutablePrimArray marr#) (I# ix#) val =
  primitive_ (\st# -> writeByteArray# marr# ix# val st#)
235
-- | Turn a mutable array into an immutable one via 'unsafeFreezeByteArray#',
-- rewrapping the result as a 'PrimArray'.
unsafeFreezePrimArray
  :: PrimMonad m => MutablePrimArray (PrimState m) a -> m (PrimArray a)
unsafeFreezePrimArray (MutablePrimArray marr#) = primitive freeze
  where
  freeze s0# =
    case unsafeFreezeByteArray# marr# s0# of
      (# s1#, frozen# #) -> (# s1#, PrimArray frozen# #)
241
242
243
-- | Run the action for every index from @0@ up to @n - 1@ in ascending
-- order, discarding the results. Does nothing when @n <= 0@.
generateM_ :: Monad m => Int -> (Int -> m a) -> m ()
generateM_ n f = mapM_ f [0 .. n - 1]
249
250
-- | Copy a run of elements out of an immutable array to the memory at a
-- pointer. Offsets and counts are given in elements. With @base >= 4.7@ the
-- copy is a single 'copyByteArrayToAddr#' call; otherwise elements are
-- written one at a time.
copyPrimArrayToPtr :: forall m a. (PrimMonad m, Prim a)
  => Ptr a       -- ^ destination pointer
  -> PrimArray a -- ^ source array
  -> Int         -- ^ offset into source array
  -> Int         -- ^ number of prims to copy
  -> m ()
#if MIN_VERSION_base(4,7,0)
copyPrimArrayToPtr (Ptr dst#) (PrimArray src#) (I# soff#) (I# n#) =
  primitive_ (copyByteArrayToAddr# src# (soff# *# elemBytes#) dst# (n# *# elemBytes#))
  where
  -- Element size in bytes, used to scale the element offset and count.
  elemBytes# = sizeOf# (undefined :: a)
#else
copyPrimArrayToPtr dst src soff n =
  generateM_ n copyOne
  where
  copyOne k = writeOffPtr dst k (indexPrimArray src (k + soff))
#endif
267{-
268copyPtrToMutablePrimArray :: forall m a. (PrimMonad m, Prim a)
269  => MutablePrimArray (PrimState m) a
270  -> Int
271  -> Ptr a
272  -> Int
273  -> m ()
274#if MIN_VERSION_base(4,7,0)
275copyPtrToMutablePrimArray (MutablePrimArray ba#) (I# doff#) (Ptr addr#) (I# n#) =
276  primitive (\ s# ->
277      let s'# = copyAddrToByteArray# addr# ba# (doff# *# siz#) (n# *# siz#) s#
278      in (# s'#, () #))
279  where siz# = sizeOf# (undefined :: a)
280#else
281copyPtrToMutablePrimArray ba doff addr n =
282  generateM_ n $ \ix -> do
283    x <- readOffAddr (ptrToAddr addr) ix
284    writePrimArray ba (doff + ix) x
285#endif
286-}
-- | Copy a run of elements from one mutable array into another. Offsets and
-- the count are given in elements; they are scaled by the element size before
-- being handed to 'copyMutableByteArray#'.
copyMutablePrimArray :: forall m s a.
     (PrimMonad m, s ~ PrimState m , Prim a)
  => MutablePrimArray s a -- ^ destination array
  -> Int -- ^ offset into destination array
  -> MutablePrimArray s  a -- ^ source array
  -> Int -- ^ offset into source array
  -> Int -- ^ number of elements to copy
  -> m ()
copyMutablePrimArray (MutablePrimArray dst#) (I# doff#) (MutablePrimArray src#) (I# soff#) (I# n#)
  = primitive_
      (copyMutableByteArray# src# (soff# *# elemBytes#) dst# (doff# *# elemBytes#) (n# *# elemBytes#))
  where
  -- Element size in bytes.
  elemBytes# = sizeOf# (undefined :: a)
303
-- | Copy a run of elements from an immutable array into a mutable array.
-- Offsets and the count are given in elements; they are scaled by the element
-- size before being handed to 'copyByteArray#'.
copyPrimArray :: forall m a.
     (PrimMonad m, Prim a)
  => MutablePrimArray (PrimState m) a -- ^ destination array
  -> Int -- ^ offset into destination array
  -> PrimArray a -- ^ source array
  -> Int -- ^ offset into source array
  -> Int -- ^ number of elements to copy
  -> m ()
copyPrimArray (MutablePrimArray dst#) (I# doff#) (PrimArray src#) (I# soff#) (I# n#)
  = primitive_
      (copyByteArray# src# (soff# *# elemBytes#) dst# (doff# *# elemBytes#) (n# *# elemBytes#))
  where
  -- Element size in bytes.
  elemBytes# = sizeOf# (undefined :: a)
320
-- | Fill a range of a mutable array with one value by calling the 'Prim'
-- class method @setByteArray#@. Offset and count are passed through
-- unscaled.
setPrimArray
  :: (Prim a, PrimMonad m)
  => MutablePrimArray (PrimState m) a -- ^ array to fill
  -> Int -- ^ offset into array
  -> Int -- ^ number of values to fill
  -> a -- ^ value to fill with
  -> m ()
setPrimArray (MutablePrimArray marr#) (I# start#) (I# count#) val =
  primitive_ (\st# -> P.setByteArray# marr# start# count# val st#)
330
-- | Build an array from a list, sizing the array from the list's length.
primArrayFromList :: Prim a => [a] -> PrimArray a
primArrayFromList elems = primArrayFromListN count elems
  where
  count = L.length elems
333
-- | Build an array of exactly @len@ elements from a list.
--
-- The list must contain exactly @len@ elements. A longer list would write
-- past the end of the freshly allocated array and a shorter one would leave
-- trailing elements uninitialised, so both cases raise an 'error' instead.
primArrayFromListN :: forall a. Prim a => Int -> [a] -> PrimArray a
primArrayFromListN len vs = runST run where
  run :: forall s. ST s (PrimArray a)
  run = do
    arr <- newPrimArray len
    let go :: [a] -> Int -> ST s ()
        go [] !ix
          | ix == len = return ()
          | otherwise = error "primArrayFromListN: list length less than specified size"
        go (a : as) !ix
          -- Only write while the index is still inside the allocation.
          | ix < len = writePrimArray arr ix a >> go as (ix + 1)
          | otherwise = error "primArrayFromListN: list length greater than specified size"
    go vs 0
    unsafeFreezePrimArray arr
347
-- | List the elements of an array in ascending index order.
primArrayToList :: forall a. Prim a => PrimArray a -> [a]
primArrayToList arr = map (indexPrimArray arr) [0 .. sizeofPrimArray arr - 1]
355
#if MIN_VERSION_base(4,7,0)
-- | Converting a list to a 'PrimArray' with 'fromList' and back with 'toList'
-- must give the original list.
primListByteArray :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primListByteArray _ = property roundTrips
  where
  roundTrips :: [a] -> Bool
  roundTrips as = as == toList (fromList as :: PrimArray a)
#endif
361
362
-- | Fill a range of a mutable array with one value using the
-- element-by-element loop in 'internalDefaultSetByteArray#'. Offset and
-- count are given in elements.
internalDefaultSetPrimArray :: Prim a
  => MutablePrimArray s a -> Int -> Int -> a -> ST s ()
internalDefaultSetPrimArray (MutablePrimArray marr#) (I# start#) (I# count#) val =
  primitive_ (\st# -> internalDefaultSetByteArray# marr# start# count# val st#)
367
-- | Write the same value to @len#@ consecutive element slots of a mutable
-- byte array, starting at element index @i#@, one 'writeByteArray#' call per
-- slot in ascending order.
internalDefaultSetByteArray# :: Prim a
  => MutableByteArray# s -> Int# -> Int# -> a -> State# s -> State# s
internalDefaultSetByteArray# arr# i# len# ident = step 0#
  where
  -- @k#@ counts how many slots have been written so far.
  step k# st
    | isTrue# (k# <# len#) = step (k# +# 1#) (writeByteArray# arr# (i# +# k#) ident st)
    | otherwise = st
376
-- | Fill a range of memory at a pointer with one value using the
-- element-by-element loop in 'internalDefaultSetOffAddr#'. Offset and count
-- are given in elements.
internalDefaultSetOffAddr :: Prim a => Ptr a -> Int -> Int -> a -> IO ()
internalDefaultSetOffAddr (Ptr base#) (I# start#) (I# count#) val =
  primitive_ (\st# -> internalDefaultSetOffAddr# base# start# count# val st#)
380
-- | Write the same value to @len#@ consecutive element slots relative to an
-- address, starting at element offset @i#@, one 'writeOffAddr#' call per
-- slot in ascending order.
internalDefaultSetOffAddr# :: Prim a => Addr# -> Int# -> Int# -> a -> State# s -> State# s
internalDefaultSetOffAddr# addr# i# len# ident = step 0#
  where
  -- @k#@ counts how many slots have been written so far.
  step k# st
    | isTrue# (k# <# len#) = step (k# +# 1#) (writeOffAddr# addr# (i# +# k#) ident st)
    | otherwise = st
388