1-- |
2-- Module      : Foundation.Timing
3-- License     : BSD-style
4-- Maintainer  : Foundation maintainers
5--
6-- An implementation of a timing framework
7--
8{-# LANGUAGE CPP #-}
9module Foundation.Timing
10    ( Timing(..)
11    , Measure(..)
12    , stopWatch
13    , measure
14    ) where
15
16import           Basement.Imports hiding (from)
17import           Basement.From (from)
18#if __GLASGOW_HASKELL__ < 802
19import           Basement.Cast (cast)
20#endif
21import           Basement.Monad
22-- import           Basement.UArray hiding (unsafeFreeze)
23import           Basement.UArray.Mutable (MUArray)
24import           Foundation.Collection
25import           Foundation.Time.Types
26import           Foundation.Numerical
27import           Foundation.Time.Bindings
28import           Control.Exception (evaluate)
29import           System.Mem (performGC)
30import           Data.Function (on)
31import qualified GHC.Stats as GHC
32
33
-- | Result of a single 'stopWatch' run.
data Timing = Timing
    { timeDiff           :: !NanoSeconds
      -- ^ Duration reported by 'measuringNanoSeconds' for one evaluation.
    , timeBytesAllocated :: !(Maybe Word64)
      -- ^ Difference in allocated bytes between the GC statistics snapshots
      -- taken before and after the timed evaluation; 'Nothing' if either
      -- snapshot was 'Nothing'.
    }

-- | Result of 'measure': one timing sample per iteration.
--
-- Fields are strict, matching 'Timing', so a 'Measure' does not keep
-- unevaluated thunks for its sample array or its iteration count.
data Measure = Measure
    { measurements :: !(UArray NanoSeconds)
      -- ^ Per-iteration durations; element @i@ is the sample of iteration @i@.
    , iters        :: !Word
      -- ^ Number of iterations that were requested.
    }
43
#if __GLASGOW_HASKELL__ >= 802
-- | RTS statistics snapshot type (GHC >= 8.2).
type GCStats = GHC.RTSStats

-- | Take a snapshot of the RTS statistics if the RTS is collecting them.
--
-- 'GHC.getRTSStats' throws when statistics are not enabled (they must be
-- requested with e.g. @+RTS -T@), so it is only called when
-- 'GHC.getRTSStatsEnabled' reports 'True'; otherwise 'Nothing' is returned.
getGCStats :: IO (Maybe GCStats)
getGCStats = do
    enabled <- GHC.getRTSStatsEnabled
    if enabled then Just <$> GHC.getRTSStats else pure Nothing

-- | Bytes allocated between two snapshots: @later - earlier@.
-- 'Nothing' if either snapshot is missing.
diffGC :: Maybe GHC.RTSStats -> Maybe GHC.RTSStats -> Maybe Word64
diffGC gc2 gc1 = ((-) `on` GHC.allocated_bytes) <$> gc2 <*> gc1
#else
-- | GC statistics snapshot type (GHC < 8.2).
type GCStats = GHC.GCStats

-- | Take a snapshot of the GC statistics if the RTS is collecting them.
--
-- 'GHC.getGCStats' throws when statistics are not enabled, so it is only
-- called when 'GHC.getGCStatsEnabled' reports 'True'; otherwise 'Nothing'
-- is returned.
getGCStats :: IO (Maybe GCStats)
getGCStats = do
    enabled <- GHC.getGCStatsEnabled
    if enabled then Just <$> GHC.getGCStats else pure Nothing

-- | Bytes allocated between two snapshots: @later - earlier@, converted to
-- 'Word64'. 'Nothing' if either snapshot is missing.
diffGC :: Maybe GHC.GCStats -> Maybe GHC.GCStats -> Maybe Word64
diffGC gc2 gc1 = cast <$> (((-) `on` GHC.bytesAllocated) <$> gc2 <*> gc1)
#endif
65
-- | Time one evaluation of @f a@ and report how many bytes it allocated.
--
-- @a@ is forced to WHNF by the bang pattern, so its own evaluation is not
-- part of the timed region. A major collection is run with 'performGC'
-- before the first statistics snapshot (presumably so pending garbage is
-- less likely to be collected inside the timed region). @f a@ is evaluated
-- only to WHNF, via 'evaluate'. The allocation figure is 'Nothing' unless
-- both 'getGCStats' snapshots are 'Just'.
stopWatch :: (a -> b) -> a -> IO Timing
stopWatch f !a = do
    performGC
    before       <- getGCStats
    (_, elapsed) <- measuringNanoSeconds (evaluate (f a))
    after        <- getGCStats
    pure (Timing elapsed (after `diffGC` before))
74
-- | Time @nbIters@ separate evaluations of @f a@, one sample per iteration.
--
-- Samples are written into a mutable array of length @nbIters@, which is
-- then frozen with 'unsafeFreeze'. Each evaluation is only to WHNF (via
-- 'evaluate'). Unlike 'stopWatch', this does not force a GC and gathers
-- no allocation statistics.
--
-- NOTE(review): @f a@ does not depend on the iteration index. If GHC's
-- full-laziness pass floats it out of the loop, the thunk would be shared,
-- and only the first sample would include the real work. Worth confirming
-- with @-ddump-simpl@.
measure :: Word -> (a -> b) -> a -> IO Measure
measure nbIters f a = do
    samples <- mutNew (from nbIters) :: IO (MUArray NanoSeconds (PrimState IO))
    let fill !ix
          | ix == nbIters = pure ()
          | otherwise     = do
              (_, elapsed) <- measuringNanoSeconds (evaluate (f a))
              mutUnsafeWrite samples (from ix) elapsed
              fill (ix + 1)
    fill 0
    frozen <- unsafeFreeze samples
    pure (Measure frozen nbIters)
89