1module Fetch (tests) where
2
3-- tests for our fetch-and-* family of functions.
import Control.Concurrent
import Control.Exception (throwIO)
import Control.Monad
import Data.Bits
import Data.List

import Control.Monad.Primitive
import Data.Atomics
import Data.Primitive
import System.Random
import Test.Framework (Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit (assertEqual,assertBool)
15
-- | Every test in this module, as (name, assertion) pairs.  The atomicity
-- tests take an iteration count; note that fetchNandTest runs ten races of
-- 2 * iters operations per thread, so it is given a smaller count.
tests :: [Test]
tests = map (uncurry testCase)
    [ ("Fetch-and-* operations return previous value", case_return_previous)
    , ("Fetch-and-* operations behave like their corresponding bitwise operators", case_like_bitwise)
    , ("fetchAndIntArray and fetchOrIntArray are atomic", fetchAndOrTest 10000000)
    , ("fetchNandIntArray atomic", fetchNandTest 1000000)
    , ("fetchAddIntArray and fetchSubIntArray are atomic", fetchAddSubTest 10000000)
    , ("fetchXorIntArray is atomic", fetchXorTest 10000000)
    ]
25
-- | Bitwise NAND: the complement of the bitwise AND of the two arguments.
-- Used as the pure reference for fetchNandIntArray.
nand :: Bits a => a -> a -> a
nand x = complement . (x .&.)
28
-- | Every fetch-and-* primitive under test, tagged with a short name (used to
-- build failure messages) and the pure operator whose result it must match.
fetchOps :: [ ( String
              , MutableByteArray RealWorld -> Int -> Int -> IO Int
              , Int -> Int -> Int
              ) ]
fetchOps =
    [ ("Add",  fetchAddIntArray,  (+))
    , ("Sub",  fetchSubIntArray,  (-))
    , ("And",  fetchAndIntArray,  (.&.))
    , ("Nand", fetchNandIntArray, nand)
    , ("Or",   fetchOrIntArray,   (.|.))
    , ("Xor",  fetchXorIntArray,  xor)
    ]
40
41
-- Test all operations at once, somewhat randomly, ensuring they behave like
-- their corresponding bitwise operator; we compose a few operations before
-- inspecting the intermediate result, and spread them randomly around a small
-- array.
--
-- Each round picks a random slot, applies a group of opGroupSize atomic ops
-- with random operands to it, and compares the slot's new contents with the
-- same operators folded purely over its previous contents.
-- TODO use quickcheck if we want
case_like_bitwise :: IO ()
case_like_bitwise = do
    let opGroupSize = 5
    -- Split a list into consecutive groups of n elements. Each group comes
    -- out reversed, which does not matter here; only infinite lists are
    -- passed in below.
    let grp n = go n []
          where go _ stck [] = [stck]
                go 0 stck xs = stck : go n [] xs
                go i stck (x:xs) = go (i-1) (x:stck) xs
    -- Inf list of different short sequences of bitwise operations:
    let opGroups = grp opGroupSize $ cycle $ concat $ permutations fetchOps

    let size = 4
    randIxs <- randomRs (0, size-1) <$> newStdGen
    randArgs <- grp opGroupSize . randoms <$> newStdGen

    a <- newByteArray (sizeOf (undefined::Int) * size)
    forM_ [0.. size-1] $ \ix-> writeByteArray a ix (0::Int)

    forM_ (take 1000000 $ zip randIxs $ zipWith zip opGroups randArgs) $
        \ (ix, opsArgs)-> do
            -- HUnit's assertEqual takes the expected value first, then the
            -- actual one; getting this backwards garbles failure messages.
            assertEqual "test not b0rken" opGroupSize (length opsArgs)

            -- Apply each atomic op to slot ix, while folding the matching
            -- pure operator over a local copy of the starting value.
            let doOpGroups pureLHS [] = return pureLHS
                doOpGroups pureLHS (((_,atomicOp,op), v) : rest) = do
                    atomicOp a ix v >> doOpGroups (pureLHS `op` v) rest

            vInitial <- readByteArray a ix
            vFinalPure <- doOpGroups vInitial opsArgs
            vFinal <- readByteArray a ix

            let nmsArgs = map (\ ((nm,_,_),v) -> (nm,v)) opsArgs
            -- The pure fold is the reference (expected); the array contents
            -- are what the atomic ops actually produced.
            assertEqual ("sequence on initial value "++(show vInitial)
                          ++" of ops with RHS args: "++(show nmsArgs)
                          ++" gives same result in both pure and atomic op"
                        ) vFinalPure vFinal
81
82
83
-- Basic smoke test, one op per array slot: each fetch-and-* op must return
-- the value its slot held beforehand, and a subsequent read must see
-- (previous value `op` operand).
case_return_previous :: IO ()
case_return_previous = do
    let nOps = length fetchOps
        randomInts :: IO [Int]
        randomInts = take nOps . randoms <$> newStdGen
    arr <- newByteArray (nOps * sizeOf (undefined::Int))
    starts <- randomInts
    sequence_ [ writeByteArray arr i s | (i, s) <- zip [0..] starts ]

    operands <- randomInts
    let check (i, before, arg, (name, fetchOp, pureOp)) = do
            returned <- fetchOp arr i arg
            assertEqual (fetchStr name "returned previous value") before returned
            stored <- readByteArray arr i
            assertEqual
                (fetchStrArgVal name arg before "operation was seen correctly on read")
                (before `pureOp` arg)
                stored
    mapM_ check (zip4 [0..] starts operands fetchOps)
101
-- | Prefix a failure message with the primitive's full name, e.g.
-- @fetchStr "Add" "msg" == "fetchAddIntArray: msg"@.
fetchStr :: String -> String -> String
fetchStr nm msg = concat ["fetch", nm, "IntArray: ", msg]
-- | Prefix a failure message with the primitive's full name, the operand it
-- was given and the value it was applied to.
fetchStrArgVal :: (Show a, Show a1) => String -> a -> a1 -> String -> String
fetchStrArgVal nm v initial msg =
    "fetch" ++ nm ++ "IntArray, with arg " ++ show v
        ++ " on value " ++ show initial ++ ": " ++ msg
106
107-- ----------------------------------------------------------------------------
108-- Tests of atomicity:
109
110
-- Concurrently run a sequence of AND and OR simultaneously on separate parts
-- of the bit range of an Int.
--
-- Thread b (b = 0 or 1) repeatedly sets bit b with fetchOr and then clears it
-- with fetchAnd; no other thread touches bit b. So every value returned by
-- its fetchOr must have bit b clear, and every value returned by its fetchAnd
-- must have bit b set. A non-atomic read-modify-write by the other thread
-- could resurrect or drop bit b and break that.
fetchAndOrTest :: Int -> IO ()
fetchAndOrTest iters = do
    out0 <- newEmptyMVar
    out1 <- newEmptyMVar
    mba <- newByteArray (sizeOf (undefined :: Int))
    let andLowersBit , orRaisesBit :: Int -> Int
        andLowersBit = clearBit (complement 0)
        orRaisesBit = setBit 0
    writeByteArray mba 0 (0 :: Int)
    -- thread 1 toggles bit 0, thread 2 toggles bit 1. Each thread checks every
    -- round as it goes and reports (rounds run, rounds that saw a wrong bit),
    -- so memory use stays constant rather than growing with iters.
    let go v b = loop 0 0
          where
            loop :: Int -> Int -> IO ()
            loop done bad
              | done >= iters = putMVar v (done, bad)
              | otherwise = do
                  low <- fetchOrIntArray mba 0 (orRaisesBit b)
                  high <- fetchAndIntArray mba 0 (andLowersBit b)
                  let ok = not (testBit low b) && testBit high b
                      bad' = if ok then bad else bad + 1
                  -- force the counter so no chain of (+1) thunks builds up
                  bad' `seq` loop (done + 1) bad'
    void $ forkIO $ go out0 0
    void $ forkIO $ go out1 1
    (done0, bad0) <- takeMVar out0
    (done1, bad1) <- takeMVar out1

    assertBool "fetchAndOrTest not broken" $ done0 + done1 == iters*2
    assertBool "fetchAndOrTest thread1" $ bad0 == 0
    assertBool "fetchAndOrTest thread2" $ bad1 == 0
141
-- NAND with an all-ones word complements every bit. Two threads each apply
-- that complement 2 * iters times to the same word, an even total, so if no
-- update is lost the word ends up back at its starting value. This is
-- repeated for 10 random starting values.
-- TODO think of a more clever test
fetchNandTest :: Int -> IO ()
fetchNandTest iters = do
    seeds <- take 10 . randoms <$> newStdGen :: IO [Int]
    mapM_ checkRoundTrip seeds
  where
    allOnes :: Int
    allOnes = complement 0
    flipEvenly mba = replicateM_ (2 * iters) (fetchNandIntArray mba 0 allOnes)
    checkRoundTrip initial =
        race initial flipEvenly flipEvenly >>= assertEqual "fetchNandTest" initial
154
155
156-- ----------------------------------------------------------------------------
157-- Code below copied with minor modifications from GHC
158-- testsuite/tests/concurrent/should_run/AtomicPrimops.hs @ f293931
159-- ----------------------------------------------------------------------------
160
161
-- | Test fetchAddIntArray# and fetchSubIntArray# together. Two threads
-- concurrently update one counter that starts at 0: one adds 2 and the other
-- subtracts 1, each iters times. If no update is lost, the final value is
-- 2 * iters - iters = iters.
fetchAddSubTest :: Int -> IO ()
fetchAddSubTest iters = do
    final <- race 0 (hammer fetchAddIntArray 2) (hammer fetchSubIntArray 1)
    assertEqual "fetchAddSubTest" iters final
  where
    -- Apply op with the given operand to element 0, iters times.
    hammer :: (MutableByteArray RealWorld -> Int -> Int -> IO Int)
           -> Int -> MutableByteArray RealWorld -> IO ()
    hammer op delta mba = loop iters
      where
        loop :: Int -> IO ()
        loop 0 = return ()
        loop k = op mba 0 delta >> loop (k - 1)
175
-- | Test fetchXorIntArray# by having two threads concurrently XOR a shared
-- word, each with its own fixed pattern, and then checking the result at the
-- end.
--
-- XOR is commutative and associative, so if no update is lost the final value
-- depends only on how many times each pattern was applied: an even count
-- cancels out, and an odd count leaves the pattern in.
--
-- (The upstream GHC version of this test notes that it also covers the code
-- paths for AND, NAND and OR; that is not established for this library here.)
fetchXorTest :: Int -> IO ()
fetchXorTest iters = do
    res <- race n0
        (\ mba -> work mba iters t1pat)
        (\ mba -> work mba iters t2pat)
    assertEqual "fetchXorTest" expected res
  where
    work :: MutableByteArray RealWorld -> Int -> Int -> IO ()
    work _   0 _ = return ()
    work mba n val = fetchXorIntArray mba 0 val >> work mba (n-1) val

    -- The initial value has the low half of the word set; the patterns are
    -- 0101... (0x5...) and 1001... (0x9...).
    (n0, t1pat, t2pat)
        | sizeOf (undefined :: Int) == 8 =
            narrow (0x00000000ffffffff, 0x5555555555555555, 0x9999999999999999)
        | otherwise = narrow (0x0000ffff, 0x55555555, 0x99999999)

    -- The constants are written at type Integer and wrapped into Int with
    -- fromInteger. This is exactly what an Int literal desugars to, but it
    -- avoids overflowed-literal warnings for values above maxBound :: Int
    -- (the 0x9... patterns, and the 64-bit branch on a 32-bit target).
    narrow :: (Integer, Integer, Integer) -> (Int, Int, Int)
    narrow (a, b, c) = (fromInteger a, fromInteger b, fromInteger c)

    -- Both threads apply their pattern iters times, so the patterns cancel
    -- exactly when iters is even.
    expected
        | even iters = n0
        | otherwise  = n0 `xor` t1pat `xor` t2pat
203
-- | Create two threads that mutate the byte array passed to them
-- concurrently. The array is one word large.
--
-- Both threads are always waited for before the element is read back. If
-- either thread throws, its exception is re-thrown in the calling thread once
-- both have finished. Otherwise the caller would block on an MVar that is
-- never filled, and the original error would be lost.
race :: Int                    -- ^ Initial value of array element
     -> (MutableByteArray RealWorld -> IO ())  -- ^ Thread 1 action
     -> (MutableByteArray RealWorld -> IO ())  -- ^ Thread 2 action
     -> IO Int                 -- ^ Final value of array element
race n0 thread1 thread2 = do
    done1 <- newEmptyMVar
    done2 <- newEmptyMVar
    mba <- newByteArray (sizeOf (undefined :: Int))
    writeByteArray mba 0 n0
    -- forkFinally hands each thread's outcome (Left exception / Right ())
    -- to its MVar, whether it finished normally or not.
    void $ forkFinally (thread1 mba) (putMVar done1)
    void $ forkFinally (thread2 mba) (putMVar done2)
    results <- mapM takeMVar [done1, done2]
    mapM_ (either throwIO return) results
    readByteArray mba 0
219