1{-# LANGUAGE ForeignFunctionInterface, BangPatterns #-}
2module Throughput.Memory (memBench) where
3
4import Foreign
5import Foreign.C
6
7import Control.Exception
8import System.CPUTime
9import Numeric
10
-- | Allocate a buffer of @mb@ mebibytes and run every C and Haskell memory
-- throughput benchmark over it, printing one timing line per run.
--
-- A first run of 'c_wordwrite', labelled "setup", is timed and printed before
-- the two labelled sections. Presumably it exists to touch the buffer before
-- the measured runs; TODO confirm against the C implementation.
--
-- Each line reports @mb@ divided by the measured CPU seconds. A zero timing
-- reading therefore prints the throughput as Infinity.
memBench :: Int -> IO ()
memBench mb = allocaBytes size $ \buf -> do
    let run label bench = do
          secs <- time $ bench (castPtr buf) (fromIntegral size)
          putStrLn $ concat
            [ show mb, "MB of ", label
            , " in ", showFFloat (Just 3) secs "s, at: "
            , showFFloat (Just 1) (fromIntegral mb / secs) "MB/s"
            ]
        section title = putStrLn "" >> putStrLn title
    run "setup        " c_wordwrite
    section "C memory throughput benchmarks:"
    run "bytes written                     " c_bytewrite
    run "bytes read                        " c_byteread
    run "words written                     " c_wordwrite
    run "words read                        " c_wordread
    section "Haskell memory throughput benchmarks:"
    run "bytes written                     " hs_bytewrite
    run "bytes written (loop unrolled once)" hs_bytewrite2
    run "bytes read                        " hs_byteread
    run "words written                     " hs_wordwrite
    run "words read                        " hs_wordread
  where
    -- buffer size in bytes (mb MiB)
    size = mb * 2 ^ (20 :: Int)
35
-- | Write the @bytes@-byte buffer at @ptr@ one 'CUChar' at a time.
--
-- Offset @k@ receives the running counter, which is @k@ modulo 256 because
-- the 'CUChar' value wraps. The loop stops when the offset equals @bytes@
-- exactly, so a negative length never satisfies the stop condition.
hs_bytewrite :: Ptr CUChar -> Int -> IO ()
hs_bytewrite !ptr bytes = go 0 0
  where
    go :: Int -> CUChar -> IO ()
    go !offset !val
      | offset == bytes = return ()
      | otherwise = pokeByteOff ptr offset val >> go (offset + 1) (val + 1)
43
-- | Like 'hs_bytewrite', but with the loop unrolled once.
--
-- Each iteration stores two consecutive bytes and advances a pointer rather
-- than an index. Every byte in @[start, start + bytes)@ receives the running
-- counter (offset modulo 256, since 'CUChar' wraps). A non-positive @bytes@
-- writes nothing.
hs_bytewrite2 :: Ptr CUChar -> Int -> IO ()
hs_bytewrite2 !start bytes = loop start 0
  where
    -- one past the last byte to write
    end :: Ptr CUChar
    end = start `plusPtr` bytes

    loop :: Ptr CUChar -> CUChar -> IO ()
    loop !ptr !n
      -- At least two bytes remain: write the pair [ptr, ptr + 2) and advance.
      -- The bound is inclusive, because a pair ending exactly at 'end' fits.
      | ptr `plusPtr` 2 <= end = do
          poke ptr n
          poke (ptr `plusPtr` 1) (n + 1)
          loop (ptr `plusPtr` 2) (n + 2)
      -- Exactly one byte remains (odd length): write it and stop.
      | ptr < end = poke ptr n
      | otherwise = return ()
55
-- | Read the @bytes@-byte buffer at @ptr@ one 'CUChar' at a time and return
-- the sum of the bytes read, modulo 256 (the 'CUChar' accumulator wraps).
--
-- The loop stops when the offset equals @bytes@ exactly, so a negative length
-- never satisfies the stop condition.
hs_byteread :: Ptr CUChar -> Int -> IO CUChar
hs_byteread !ptr bytes = go 0 0
  where
    go :: Int -> CUChar -> IO CUChar
    go !offset !acc
      | offset == bytes = return acc
      | otherwise = peekByteOff ptr offset >>= \b -> go (offset + 1) (acc + b)
63
-- | Write every whole 'CULong' word of the @bytes@-byte buffer at @ptr@.
--
-- Word @k@ receives the running counter @k@, wrapping at the 'CULong' range.
-- If @bytes@ is not a multiple of the word size, the trailing partial word is
-- left untouched.
hs_wordwrite :: Ptr CULong -> Int -> IO ()
hs_wordwrite !ptr bytes = loop 0 0
  where
    -- number of whole words that fit in the buffer
    iterations = bytes `div` sizeOf (undefined :: CULong)

    loop :: Int -> CULong -> IO ()
    loop !i !n
      | i == iterations = return ()
      -- 'i' counts words, not bytes, so it must be used as an element offset.
      -- A byte offset here would write overlapping, misaligned words over only
      -- the first few bytes of the buffer.
      | otherwise = do
          pokeElemOff ptr i n
          loop (i + 1) (n + 1)
71
-- | Read every whole 'CULong' word of the @bytes@-byte buffer at @ptr@ and
-- return their sum, wrapping on overflow.
--
-- If @bytes@ is not a multiple of the word size, the trailing partial word is
-- ignored.
hs_wordread :: Ptr CULong -> Int -> IO CULong
hs_wordread !ptr bytes = loop 0 0
  where
    -- number of whole words that fit in the buffer
    iterations = bytes `div` sizeOf (undefined :: CULong)

    loop :: Int -> CULong -> IO CULong
    loop !i !n
      | i == iterations = return n
      -- 'i' counts words, not bytes, so read by element offset. A byte offset
      -- would read overlapping, misaligned words from only the buffer's start.
      | otherwise = do
          x <- peekElemOff ptr i
          loop (i + 1) (n + x)
79
80
-- | C benchmark @byteread@, declared in CBenchmark.h. Going by its name, it
-- is a byte-at-a-time read pass. 'memBench' passes the buffer and its size in
-- bytes.
--
-- NOTE(review): the length is a 'CInt', which on platforms with a 32-bit
-- 'CInt' cannot represent buffers of 2 GiB or more.
foreign import ccall unsafe "CBenchmark.h byteread" c_byteread
  :: Ptr CUChar -> CInt -> IO ()
83
-- | C benchmark @bytewrite@, declared in CBenchmark.h. Going by its name, it
-- is a byte-at-a-time write pass. 'memBench' passes the buffer and its size in
-- bytes.
--
-- NOTE(review): the length is a 'CInt', which on platforms with a 32-bit
-- 'CInt' cannot represent buffers of 2 GiB or more.
foreign import ccall unsafe "CBenchmark.h bytewrite" c_bytewrite
  :: Ptr CUChar -> CInt -> IO ()
86
-- | C benchmark @wordread@, declared in CBenchmark.h. Going by its name, it
-- is a word-at-a-time read pass. 'memBench' passes the buffer and its size in
-- bytes.
--
-- NOTE(review): the pointer is typed as 'CUInt' here, while the Haskell word
-- benchmarks use 'CULong'. Confirm the word width the C side actually uses.
foreign import ccall unsafe "CBenchmark.h wordread" c_wordread
  :: Ptr CUInt -> CInt -> IO ()
89
-- | C benchmark @wordwrite@, declared in CBenchmark.h. Going by its name, it
-- is a word-at-a-time write pass. 'memBench' also runs it once as its "setup"
-- step, passing the buffer and its size in bytes.
--
-- NOTE(review): the pointer is typed as 'CUInt' here, while the Haskell word
-- benchmarks use 'CULong'. Confirm the word width the C side actually uses.
foreign import ccall unsafe "CBenchmark.h wordwrite" c_wordwrite
  :: Ptr CUInt -> CInt -> IO ()
92
-- | Run an action and return the CPU time it consumed, in seconds.
--
-- The action's result is forced to weak head normal form inside the timed
-- window, so work the action returns lazily is still counted; the result
-- itself is then discarded. 'getCPUTime' reports picoseconds, hence the
-- division by 1e12.
time :: IO a -> IO Double
time action = do
  start <- getCPUTime
  _ <- action >>= evaluate
  end <- getCPUTime
  return $! fromIntegral (end - start) / 1e12
99