1{-# LANGUAGE BangPatterns #-} 2{-# LANGUAGE CPP #-} 3{-# LANGUAGE MagicHash #-} 4{-# LANGUAGE ScopedTypeVariables #-} 5{-# LANGUAGE TypeFamilies #-} 6{-# LANGUAGE UnboxedTuples #-} 7 8{-# OPTIONS_GHC -Wall #-} 9 10module PrimLawsWIP 11 ( primLaws 12 ) where 13 14import Control.Applicative 15import Control.Monad.Primitive (PrimMonad, PrimState,primitive,primitive_) 16import Control.Monad.ST 17import Data.Proxy (Proxy) 18import Data.Primitive.ByteArray 19import Data.Primitive.Types 20import Data.Primitive.Ptr 21import Foreign.Marshal.Alloc 22import GHC.Exts 23 (State#,Int#,Addr#,Int(I#),(*#),(+#),(<#),newByteArray#,unsafeFreezeByteArray#, 24 copyMutableByteArray#,copyByteArray#,quotInt#,sizeofByteArray#) 25 26#if MIN_VERSION_base(4,7,0) 27import GHC.Exts (IsList(fromList,toList,fromListN),Item, 28 copyByteArrayToAddr#,copyAddrToByteArray#) 29#endif 30 31import GHC.Ptr (Ptr(..)) 32import System.IO.Unsafe 33import Test.QuickCheck hiding ((.&.)) 34import Test.QuickCheck.Property (Property) 35 36import qualified Data.List as L 37import qualified Data.Primitive as P 38 39import Test.QuickCheck.Classes.Common (Laws(..)) 40import Test.QuickCheck.Classes.Compat (isTrue#) 41 42-- | Test that a 'Prim' instance obey the several laws. 
primLaws :: (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Laws
primLaws p = Laws "Prim"
  [ ("ByteArray Put-Get (you get back what you put in)", primPutGetByteArray p)
  , ("ByteArray Get-Put (putting back what you got out has no effect)", primGetPutByteArray p)
  , ("ByteArray Put-Put (putting twice is same as putting once)", primPutPutByteArray p)
  , ("ByteArray Set Range", primSetByteArray p)
#if MIN_VERSION_base(4,7,0)
  , ("ByteArray List Conversion Roundtrips", primListByteArray p)
#endif
  , ("Addr Put-Get (you get back what you put in)", primPutGetAddr p)
  , ("Addr Get-Put (putting back what you got out has no effect)", primGetPutAddr p)
  , ("Addr Set Range", primSetOffAddr p)
  , ("Addr List Conversion Roundtrips", primListAddr p)
  ]

-- | Allocate an unmanaged buffer with room for @n@ elements of type @a@.
-- The caller must release it with 'free'. At least one byte is always
-- requested: C allows @malloc(0)@ to return @NULL@, and 'mallocBytes'
-- reports a @NULL@ result as an allocation failure, which would make the
-- laws below crash on empty inputs on such platforms.
mallocElems :: forall a. Prim a => Int -> IO (Ptr a)
mallocElems n = mallocBytes (max 1 (n * P.sizeOf (undefined :: a)))

-- | Write the list to consecutive element slots of the buffer, starting at
-- element index 0. The buffer must have room for the whole list.
pokeElems :: Prim a => Ptr a -> [a] -> IO ()
pokeElems ptr xs = sequence_ (zipWith (writeOffPtr ptr) [0 ..] xs)

-- | Read @n@ consecutive elements of the buffer, starting at element index 0.
peekElems :: Prim a => Ptr a -> Int -> IO [a]
peekElems ptr n = mapM (readOffPtr ptr) [0 .. n - 1]

-- | Writing a list element-by-element to a raw buffer and reading it back
-- yields the original list.
primListAddr :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primListAddr _ = property $ \(as :: [a]) -> unsafePerformIO $ do
  let len = L.length as
  ptr :: Ptr a <- mallocElems len
  pokeElems ptr as
  asNew <- peekElems ptr len
  free ptr
  return (as == asNew)

-- | Reading an index of a 'PrimArray' right after writing it returns the
-- written value.
primPutGetByteArray :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primPutGetByteArray _ = property $ \(a :: a) len -> (len > 0) ==> do
  ix <- choose (0, len - 1)
  return $ runST $ do
    arr <- newPrimArray len
    writePrimArray arr ix a
    a' <- readPrimArray arr ix
    return (a == a')

-- | Writing back the value just read from an index leaves the array unchanged.
primGetPutByteArray :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primGetPutByteArray _ = property $ \(as :: [a]) -> not (L.null as) ==> do
  let arr1 = primArrayFromList as :: PrimArray a
      len = L.length as
  ix <- choose (0, len - 1)
  let arr2 = runST $ do
        marr <- newPrimArray len
        copyPrimArray marr 0 arr1 0 len
        a <- readPrimArray marr ix
        writePrimArray marr ix a
        unsafeFreezePrimArray marr
  return (arr1 == arr2)

-- | Writing the same value to an index twice has the same effect as writing
-- it once.
primPutPutByteArray :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primPutPutByteArray _ = property $ \(a :: a) (as :: [a]) -> not (L.null as) ==> do
  let arr1 = primArrayFromList as :: PrimArray a
      len = L.length as
  ix <- choose (0, len - 1)
  let (arr2, arr3) = runST $ do
        marr2 <- newPrimArray len
        copyPrimArray marr2 0 arr1 0 len
        writePrimArray marr2 ix a
        -- marr3 starts as a copy of the already-written marr2 ...
        marr3 <- newPrimArray len
        copyMutablePrimArray marr3 0 marr2 0 len
        once <- unsafeFreezePrimArray marr2
        -- ... and receives the same write a second time.
        writePrimArray marr3 ix a
        twice <- unsafeFreezePrimArray marr3
        return (once, twice)
  return (arr2 == arr3)

-- | Reading an element offset of a raw buffer right after writing it returns
-- the written value.
primPutGetAddr :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primPutGetAddr _ = property $ \(a :: a) len -> (len > 0) ==> do
  ix <- choose (0, len - 1)
  return $ unsafePerformIO $ do
    ptr :: Ptr a <- mallocElems len
    writeOffPtr ptr ix a
    a' <- readOffPtr ptr ix
    free ptr
    return (a == a')

-- | Writing back the value just read from an element offset of a raw buffer
-- leaves the buffer's contents unchanged.
--
-- (Previously a placeholder that always passed.)
primGetPutAddr :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primGetPutAddr _ = property $ \(as :: [a]) -> not (L.null as) ==> do
  let len = L.length as
  ix <- choose (0, len - 1)
  return $ unsafePerformIO $ do
    ptr :: Ptr a <- mallocElems len
    pokeElems ptr as
    a <- readOffPtr ptr ix
    writeOffPtr ptr ix a
    asNew <- peekElems ptr len
    free ptr
    return (as == asNew)

-- | 'setPrimArray' (the instance's 'setByteArray#') agrees with a naive
-- element-by-element fill over a random sub-range @[lo, hi)@.
primSetByteArray :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primSetByteArray _ = property $ \(as :: [a]) (z :: a) -> do
  let arr1 = primArrayFromList as :: PrimArray a
      len = L.length as
  x <- choose (0, len)
  y <- choose (0, len)
  let lo = min x y
      hi = max x y
  return $ runST $ do
    marr2 <- newPrimArray len
    copyPrimArray marr2 0 arr1 0 len
    marr3 <- newPrimArray len
    copyPrimArray marr3 0 arr1 0 len
    setPrimArray marr2 lo (hi - lo) z
    internalDefaultSetPrimArray marr3 lo (hi - lo) z
    arr2 <- unsafeFreezePrimArray marr2
    arr3 <- unsafeFreezePrimArray marr3
    return (arr2 == arr3)

-- | 'setPtr' (the instance's 'setOffAddr#') agrees with a naive
-- element-by-element fill over a random sub-range @[lo, hi)@ of a raw buffer.
--
-- (Previously a placeholder that always passed.)
primSetOffAddr :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primSetOffAddr _ = property $ \(as :: [a]) (z :: a) -> do
  let len = L.length as
  x <- choose (0, len)
  y <- choose (0, len)
  let lo = min x y
      hi = max x y
  return $ unsafePerformIO $ do
    ptrA :: Ptr a <- mallocElems len
    ptrB :: Ptr a <- mallocElems len
    pokeElems ptrA as
    pokeElems ptrB as
    -- 'setPtr' has no offset parameter, so move the pointer to element lo.
    setPtr (advancePtr ptrA lo) (hi - lo) z
    internalDefaultSetOffAddr ptrB lo (hi - lo) z
    asA <- peekElems ptrA len
    asB <- peekElems ptrB len
    free ptrA
    free ptrB
    return (asA == asB)

-- The commented-out lines below are an earlier draft of 'primSetOffAddr' that
-- did not type check: it passed an offset to 'setPtr', which takes none. The
-- version above advances the pointer with 'advancePtr' instead.
--primSetOffAddr :: forall a.
--   (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
--primSetOffAddr _ = property $ \(as :: [a]) (z :: a) -> do
--  let arr1 = primArrayFromList as :: PrimArray a
--      len = L.length as
--  x <- choose (0,len)
--  y <- choose (0,len)
--  let lo = min x y
--      hi = max x y
--  return $ unsafePerformIO $ do
--    ptrA@(Ptr addrA#) :: Ptr a <- mallocBytes (len * P.sizeOf (undefined :: a))

--    copyPrimArrayToPtr ptrA arr1 0 len
--    ptrB@(Ptr addrB#) :: Ptr a <- mallocBytes (len * P.sizeOf (undefined :: a))

--    copyPrimArrayToPtr ptrB arr1 0 len
--    setPtr ptrA lo (hi - lo) z
--    internalDefaultSetOffAddr ptrB lo (hi - lo) z
--    marrA <- newPrimArray len
--    copyPtrToMutablePrimArray marrA 0 ptrA len
--    free ptrA
--    marrB <- newPrimArray len
--    copyPtrToMutablePrimArray marrB 0 ptrB len
--    free ptrB
--    arrA <- unsafeFreezePrimArray marrA
--    arrB <- unsafeFreezePrimArray marrB
--    return (arrA == arrB)

-- | An immutable 'ByteArray#' whose phantom parameter @a@ records the
-- element type it is read as.
data PrimArray a = PrimArray ByteArray#

-- | The mutable counterpart of 'PrimArray', tied to state thread @s@.
data MutablePrimArray s a = MutablePrimArray (MutableByteArray# s)

-- | Arrays are equal when they hold the same number of elements and agree
-- at every index (elements are compared from the last index downwards).
instance (Eq a, Prim a) => Eq (PrimArray a) where
  xs == ys = n == sizeofPrimArray ys && matchFrom (n - 1)
    where
      n = sizeofPrimArray xs
      matchFrom !k
        | k < 0 = True
        | otherwise = indexPrimArray xs k == indexPrimArray ys k && matchFrom (k - 1)

#if MIN_VERSION_base(4,7,0)
-- | List conversions, delegating to the helpers defined in this module.
instance Prim a => IsList (PrimArray a) where
  type Item (PrimArray a) = a
  fromListN = primArrayFromListN
  fromList = primArrayFromList
  toList = primArrayToList
#endif

-- | Read the element at an element (not byte) index. The index is not
-- bounds-checked.
indexPrimArray :: forall a. Prim a => PrimArray a -> Int -> a
indexPrimArray (PrimArray ba#) (I# ix#) = indexByteArray# ba# ix#

-- | Number of elements: the byte size divided by the element size.
sizeofPrimArray :: forall a. Prim a => PrimArray a -> Int
sizeofPrimArray (PrimArray ba#) =
  I# (sizeofByteArray# ba# `quotInt#` P.sizeOf# (undefined :: a))

-- | Allocate a mutable array with room for the given number of elements.
-- This function does not initialise the contents.
newPrimArray :: forall m a. (PrimMonad m, Prim a) => Int -> m (MutablePrimArray (PrimState m) a)
newPrimArray (I# n#) = primitive $ \s0 ->
  case newByteArray# (n# *# sizeOf# (undefined :: a)) s0 of
    (# s1, mba# #) -> (# s1, MutablePrimArray mba# #)

-- | Read the element at an element index (not bounds-checked).
readPrimArray :: (Prim a, PrimMonad m) => MutablePrimArray (PrimState m) a -> Int -> m a
readPrimArray (MutablePrimArray mba#) (I# ix#) = primitive (readByteArray# mba# ix#)

-- | Write a value at an element index (not bounds-checked).
writePrimArray ::
     (Prim a, PrimMonad m)
  => MutablePrimArray (PrimState m) a
  -> Int
  -> a
  -> m ()
writePrimArray (MutablePrimArray mba#) (I# ix#) v = primitive_ (writeByteArray# mba# ix# v)

-- | Reinterpret a mutable array as immutable without copying, via
-- 'unsafeFreezeByteArray#'.
unsafeFreezePrimArray
  :: PrimMonad m => MutablePrimArray (PrimState m) a -> m (PrimArray a)
unsafeFreezePrimArray (MutablePrimArray mba#) = primitive $ \s0 ->
  case unsafeFreezeByteArray# mba# s0 of
    (# s1, ba# #) -> (# s1, PrimArray ba# #)

-- | Run the action on every index from 0 up to (but excluding) @n@, in
-- ascending order, discarding the results.
generateM_ :: Monad m => Int -> (Int -> m a) -> m ()
generateM_ n f = loop 0
  where
    loop !i
      | i < n = f i >> loop (i + 1)
      | otherwise = return ()

-- | Copy a slice of a 'PrimArray' into memory at the given pointer. The
-- offset and count are in elements.
copyPrimArrayToPtr :: forall m a. (PrimMonad m, Prim a)
  => Ptr a       -- ^ destination pointer
  -> PrimArray a -- ^ source array
  -> Int         -- ^ offset into source array
  -> Int         -- ^ number of prims to copy
  -> m ()
#if MIN_VERSION_base(4,7,0)
copyPrimArrayToPtr (Ptr addr#) (PrimArray ba#) (I# soff#) (I# n#) =
  primitive_ (copyByteArrayToAddr# ba# (soff# *# elemSize#) addr# (n# *# elemSize#))
  where
    elemSize# = sizeOf# (undefined :: a)
#else
copyPrimArrayToPtr ptr ba soff n =
  generateM_ n (\ix -> writeOffPtr ptr ix (indexPrimArray ba (soff + ix)))
#endif
{-
copyPtrToMutablePrimArray :: forall m a.
(PrimMonad m, Prim a)
  => MutablePrimArray (PrimState m) a
  -> Int
  -> Ptr a
  -> Int
  -> m ()
#if MIN_VERSION_base(4,7,0)
copyPtrToMutablePrimArray (MutablePrimArray ba#) (I# doff#) (Ptr addr#) (I# n#) =
  primitive (\ s# ->
    let s'# = copyAddrToByteArray# addr# ba# (doff# *# siz#) (n# *# siz#) s#
    in (# s'#, () #))
  where siz# = sizeOf# (undefined :: a)
#else
copyPtrToMutablePrimArray ba doff addr n =
  generateM_ n $ \ix -> do
    x <- readOffAddr (ptrToAddr addr) ix
    writePrimArray ba (doff + ix) x
#endif
-}
-- NOTE(review): the draft above is disabled. Its fallback branch refers to
-- 'readOffAddr' and 'ptrToAddr', which do not appear to be in scope
-- unqualified in this module -- confirm before re-enabling it.

-- | Copy elements between two mutable arrays. Both offsets and the count are
-- measured in elements; they are scaled by the element size internally.
copyMutablePrimArray :: forall m s a.
     (PrimMonad m, s ~ PrimState m, Prim a)
  => MutablePrimArray s a -- ^ destination array
  -> Int                  -- ^ offset into destination array (in elements)
  -> MutablePrimArray s a -- ^ source array
  -> Int                  -- ^ offset into source array (in elements)
  -> Int                  -- ^ number of elements to copy
  -> m ()
copyMutablePrimArray (MutablePrimArray dst#) (I# doff#) (MutablePrimArray src#) (I# soff#) (I# n#)
  = primitive_ (copyMutableByteArray#
      src#
      (soff# *# elemSize#)
      dst#
      (doff# *# elemSize#)
      (n# *# elemSize#))
  where
    elemSize# = sizeOf# (undefined :: a)

-- | Copy elements from an immutable array into a mutable one. Both offsets and
-- the count are measured in elements; they are scaled by the element size
-- internally.
copyPrimArray :: forall m a.
     (PrimMonad m, Prim a)
  => MutablePrimArray (PrimState m) a -- ^ destination array
  -> Int                              -- ^ offset into destination array (in elements)
  -> PrimArray a                      -- ^ source array
  -> Int                              -- ^ offset into source array (in elements)
  -> Int                              -- ^ number of elements to copy
  -> m ()
copyPrimArray (MutablePrimArray dst#) (I# doff#) (PrimArray src#) (I# soff#) (I# n#)
  = primitive_ (copyByteArray#
      src#
      (soff# *# elemSize#)
      dst#
      (doff# *# elemSize#)
      (n# *# elemSize#))
  where
    elemSize# = sizeOf# (undefined :: a)

-- | Fill a range of a mutable array using the instance's own
-- 'P.setByteArray#' (the implementation under test in 'primSetByteArray').
setPrimArray
  :: (Prim a, PrimMonad m)
  => MutablePrimArray (PrimState m) a -- ^ array to fill
  -> Int                              -- ^ offset into array
  -> Int                              -- ^ number of values to fill
  -> a                                -- ^ value to fill with
  -> m ()
setPrimArray (MutablePrimArray dst#) (I# doff#) (I# sz#) x
  = primitive_ (P.setByteArray# dst# doff# sz# x)

-- | Build an array holding exactly the elements of the list, in order.
primArrayFromList :: Prim a => [a] -> PrimArray a
primArrayFromList xs = primArrayFromListN (L.length xs) xs

-- | Build an array of @len@ elements from a list that must have exactly
-- @len@ elements.
--
-- A list that is longer than @len@ used to be written past the end of the
-- allocation, and a shorter one left the tail uninitialised. Both cases now
-- raise an error instead.
primArrayFromListN :: forall a. Prim a => Int -> [a] -> PrimArray a
primArrayFromListN len vs = runST run
  where
    run :: forall s. ST s (PrimArray a)
    run = do
      arr <- newPrimArray len
      let go :: [a] -> Int -> ST s ()
          go [] !ix
            | ix == len = return ()
            | otherwise = error "primArrayFromListN: list is shorter than the requested length"
          go (x : xs) !ix
            | ix < len = writePrimArray arr ix x >> go xs (ix + 1)
            | otherwise = error "primArrayFromListN: list is longer than the requested length"
      go vs 0
      unsafeFreezePrimArray arr

-- | Lazily list the elements of an array from index 0 upwards.
primArrayToList :: forall a. Prim a => PrimArray a -> [a]
primArrayToList arr = go 0 where
  !len = sizeofPrimArray arr
  go :: Int -> [a]
  go !ix = if ix < len
    then indexPrimArray arr ix : go (ix + 1)
    else []

#if MIN_VERSION_base(4,7,0)
-- | Converting a list to a 'PrimArray' (via 'IsList') and back is the identity.
primListByteArray :: forall a. (Prim a, Eq a, Arbitrary a, Show a) => Proxy a -> Property
primListByteArray _ = property $ \(as :: [a]) ->
  as == toList (fromList as :: PrimArray a)
#endif

-- | Reference fill for a 'PrimArray' range: writes the value one element at
-- a time with 'writeByteArray#'. Used as the oracle in 'primSetByteArray'.
internalDefaultSetPrimArray :: Prim a
  => MutablePrimArray s a -> Int -> Int -> a -> ST s ()
internalDefaultSetPrimArray (MutablePrimArray arr) (I# i) (I# len) ident =
  primitive_ (internalDefaultSetByteArray# arr i len ident)

-- | Write @ident@ to element indices @i .. i + len - 1@ of the byte array,
-- one 'writeByteArray#' per element.
internalDefaultSetByteArray# :: Prim a
  => MutableByteArray# s -> Int# -> Int# -> a -> State# s -> State# s
internalDefaultSetByteArray# arr# i# len# ident = go 0#
  where
    go ix# s0 = if isTrue# (ix# <# len#)
      then case writeByteArray# arr# (i# +# ix#) ident s0 of
        s1 -> go (ix# +# 1#) s1
      else s0

-- | Reference fill for a raw-pointer range: writes the value one element at
-- a time with 'writeOffAddr#'. Used as the oracle in 'primSetOffAddr'.
internalDefaultSetOffAddr :: Prim a => Ptr a -> Int -> Int -> a -> IO ()
internalDefaultSetOffAddr (Ptr addr) (I# ix) (I# len) a = primitive_
  (internalDefaultSetOffAddr# addr ix len a)

-- | Write @ident@ to element offsets @i .. i + len - 1@ from the address,
-- one 'writeOffAddr#' per element.
internalDefaultSetOffAddr# :: Prim a => Addr# -> Int# -> Int# -> a -> State# s -> State# s
internalDefaultSetOffAddr# addr# i# len# ident = go 0#
  where
    go ix# s0 = if isTrue# (ix# <# len#)
      then case writeOffAddr# addr# (i# +# ix#) ident s0 of
        s1 -> go (ix# +# 1#) s1
      else s0