1-----------------------------------------------------------------------------
2-- |
3-- Module    : Utils.SBVTestFramework
4-- Copyright : (c) Levent Erkok
5-- License   : BSD3
6-- Maintainer: erkokl@gmail.com
7-- Stability : experimental
8--
9-- Various goodies for testing SBV
10-----------------------------------------------------------------------------
11
12{-# LANGUAGE FlexibleContexts    #-}
13{-# LANGUAGE RankNTypes          #-}
14{-# LANGUAGE ScopedTypeVariables #-}
15
16{-# OPTIONS_GHC -Wall -Werror #-}
17
18module Utils.SBVTestFramework (
19          showsAs
20        , runSAT, numberOfModels
21        , assert, assertIsThm, assertIsntThm, assertIsSat, assertIsntSat
22        , goldenString
23        , goldenVsStringShow
24        , goldenCapturedIO
25        , CIOS(..), TestEnvironment(..), getTestEnvironment
26        , qc1, qc2
27        , pickTests
28        -- module exports to simplify life
29        , module Test.Tasty
30        , module Test.Tasty.HUnit
31        , module Data.SBV
32        ) where
33
34import qualified Control.Exception as C
35
36import Control.Monad.Trans (liftIO)
37
38import qualified Data.ByteString.Lazy.Char8 as LBC
39
40import System.Directory   (removeFile)
41import System.Environment (lookupEnv)
42
43import Test.Tasty            (testGroup, TestTree, TestName)
44import Test.Tasty.HUnit      ((@?), Assertion, testCase, AssertionPredicable)
45
46import Test.Tasty.Golden     (goldenVsString, goldenVsFileDiff)
47
48import qualified Test.Tasty.QuickCheck   as QC
49import qualified Test.QuickCheck.Monadic as QC
50
51import Test.Tasty.Runners hiding (Result)
52import System.Random (randomRIO)
53
54import Data.SBV
55import Data.SBV.Control
56
57import Data.Char  (isDigit)
58import Data.Maybe (fromMaybe, catMaybes)
59
60import System.FilePath ((</>), (<.>))
61
62import Data.SBV.Internals (runSymbolic, Result, SBVRunMode(..), IStage(..), SBV(..), SVal(..), showModel, SMTModel(..), QueryContext(..))
63
64---------------------------------------------------------------------------------------
65-- Test environment; continuous integration
-- | Operating systems on which continuous-integration runs are distinguished.
data CIOS = CILinux
          | CIOSX
          | CIWindows
          deriving (Show, Eq)

-- | Where the test suite is running, as reported by @SBV_TEST_ENVIRONMENT@.
data TestEnvironment = TestEnvLocal     -- ^ Value @local@
                     | TestEnvCI CIOS   -- ^ Values @linux@, @osx@ or @win@
                     | TestEnvUnknown   -- ^ Variable unset, or set to an unrecognized value
                     deriving Show

-- | Read the test environment and the percentage of heavy tests to run.
--
-- * @SBV_TEST_ENVIRONMENT@: one of @local@, @linux@, @osx@, @win@. When unset the
--   result is 'TestEnvUnknown'; any other value prints a warning and also yields
--   'TestEnvUnknown'.
--
-- * @SBV_HEAVYTEST_PERCENTAGE@: a non-empty string of decimal digits. When unset
--   the result is 100; anything else (including the empty string) prints a
--   warning and also yields 100.
getTestEnvironment :: IO (TestEnvironment, Int)
getTestEnvironment = do
    mbTestEnv  <- lookupEnv "SBV_TEST_ENVIRONMENT"
    mbTestPerc <- lookupEnv "SBV_HEAVYTEST_PERCENTAGE"

    env <- case mbTestEnv of
             Just "local" -> return   TestEnvLocal
             Just "linux" -> return $ TestEnvCI CILinux
             Just "osx"   -> return $ TestEnvCI CIOSX
             Just "win"   -> return $ TestEnvCI CIWindows
             Just other   -> do
               putStrLn $ "Ignoring unexpected test env value: " ++ show other
               return TestEnvUnknown
             Nothing      -> return TestEnvUnknown

    perc <- case mbTestPerc of
              -- Insist on at least one digit: `all isDigit ""` holds, and `read ""`
              -- would otherwise crash as soon as the percentage is actually used.
              Just n | not (null n), all isDigit n -> return (read n)
              Just n -> do
                putStrLn $ "Ignoring unexpected test percentage value: " ++ show n
                return 100
              Nothing -> return 100

    return (env, perc)
96
-- | Generic assertion: passes when the predicate holds, otherwise fails with
-- the fixed message @assertion-failure@. Less informative than a tailored
-- message, but convenient.
assert :: AssertionPredicable t => t -> Assertion
assert = (@? "assertion-failure")
100
-- | Assert that rendering @r@ with 'show' yields exactly the string @s@.
showsAs :: Show a => a -> String -> Assertion
showsAs r s = assert (s == show r)
104
-- | Location of the gold file for a test: @SBVTestSuite/GoldFiles/<name>.gold@.
goldFile :: FilePath -> FilePath
goldFile = ("SBVTestSuite" </>) . ("GoldFiles" </>) . (<.> "gold")
107
-- | Golden test comparing the string produced by an action against the gold
-- file named after the test (see 'goldFile').
goldenString :: TestName -> IO String -> TestTree
goldenString n = goldenVsString n (goldFile n) . fmap LBC.pack
110
-- | Like 'goldenString', but the action's result is rendered with 'show' first.
goldenVsStringShow :: Show a => TestName -> IO a -> TestTree
goldenVsStringShow n = goldenString n . fmap show
113
-- | Golden test for an action that writes its output to a file.
--
-- The action receives a temporary path (the gold file's path with @_temp@
-- appended). Any stale copy of that file is removed before the action runs.
-- The produced file is then checked against the gold file, with @diff -u@
-- used to show differences.
goldenCapturedIO :: TestName -> (FilePath -> IO ()) -> TestTree
goldenCapturedIO n res = goldenVsFileDiff n diff gf gfTmp (rm gfTmp >> res gfTmp)
  where gf    = goldFile n
        gfTmp = gf ++ "_temp"

        -- Best-effort removal: the temp file usually does not exist, so I/O
        -- errors are ignored. Only IOException is caught, so asynchronous
        -- exceptions (timeouts, thread cancellation) still propagate.
        rm f  = removeFile f `C.catch` (\(_ :: C.IOException) -> return ())

        diff ref new = ["diff", "-u", ref, new]
121
-- | Count the models of a predicate by enumerating all of them with 'allSat'.
--
-- The count is only meaningful if enumeration ran to completion. Instead of
-- returning a number, this calls 'error' whenever the search stopped early:
-- the solver answered unknown, the solver answered delta-satisfiable, or a
-- user-specified max-model count was reached.
numberOfModels :: Provable a => a -> IO Int
numberOfModels p = allSat p >>= countOrDie
  where countOrDie r
          | allSatSolverReturnedUnknown r = error $ "Data.SBV.numberOfModels: Search was stopped because solver said 'Unknown'. At this point, we saw: " ++ show l ++ " model(s)."
          | allSatSolverReturnedDSat    r = error $ "Data.SBV.numberOfModels: Search was stopped because solver returned 'delta satisfiable'. At this point, we saw: " ++ show l ++ " model(s)."
          | allSatMaxModelCountReached  r = error $ "Data.SBV.numberOfModels: Search was stopped because the user-specified max-model count was hit at " ++ show l ++ " model(s)."
          | otherwise                     = return l
          where l = length (allSatResults r)
138
-- | Run a symbolic computation in SMT mode with 'defaultSMTCfg', keeping only
-- the generated 'Result' and discarding the computation's own value.
runSAT :: Symbolic a -> IO Result
runSAT = fmap snd . runSymbolic (SMTMode QueryInternal ISetup True defaultSMTCfg)
142
-- | Assertion that passes exactly when 'isTheorem' reports the property proven.
assertIsThm :: Provable a => a -> Assertion
assertIsThm = assert . isTheorem
146
-- | Assertion that passes exactly when 'isTheorem' reports the property is not proven.
assertIsntThm :: Provable a => a -> Assertion
assertIsntThm t = assert (not <$> isTheorem t)
150
-- | Assertion that passes exactly when 'isSatisfiable' reports the predicate satisfiable.
assertIsSat :: Provable a => a -> Assertion
assertIsSat = assert . isSatisfiable
154
-- | Assertion that passes exactly when 'isSatisfiable' reports the predicate not satisfiable.
assertIsntSat :: Provable a => a -> Assertion
assertIsntSat = assert . fmap not . isSatisfiable
158
-- | Quick-check a unary function, creating one version for constant folding, and another for solver.
--
-- Given a test name @nm@, a concrete operation @opC@ and its symbolic
-- counterpart @opS@, two properties are produced:
--
--   * @nm ++ ".constantFold"@ obtains an input via 'free', applies @opS@ to it
--     and compares the outcome with @opC@ on the extracted concrete value. It
--     passes only if both the expected value and the symbolic result are
--     literals that are equal under '.=='; otherwise it is 'sFalse'.
--     NOTE(review): this assumes 'free' yields a concrete value when run under
--     QuickCheck, which the local @grab@ helper relies on.
--
--   * @nm ++ ".symbolic"@ generates a random input, constrains a fresh symbolic
--     variable to equal it, lets the solver compute @opS@ on that variable, and
--     compares the solver's value with @opC@ on the input. Any 'checkSat'
--     answer other than 'Sat' fails the property. The input, the expected value
--     and (when available) the solver's result are reported as a counterexample.
qc1 :: (Eq a, SymVal a, SymVal b, Show a, QC.Arbitrary a, Eq b) => String -> (a -> b) -> (SBV a -> SBV b) -> [TestTree]
qc1 nm opC opS = [cf, sm]
   where cf = QC.testProperty (nm ++ ".constantFold") $ do
                        i <- free "i"

                        let grab n = fromMaybe (error $ "qc1." ++ nm ++ ": Cannot extract value for: " ++ n) . unliteral

                            v = grab "i" i

                            expected = literal $ opC v
                            result   = opS i

                        -- Only a literal-vs-literal comparison counts as successful constant folding.
                        case (unliteral expected, unliteral result) of
                           (Just _, Just _) -> return $ expected .== result
                           _                -> return sFalse

         sm = QC.testProperty (nm ++ ".symbolic") $ QC.monadicIO $ do
                        ((i, expected), result) <- QC.run $ runSMT $ do
                            v   <- liftIO $ QC.generate QC.arbitrary
                            i   <- free_
                            res <- free_

                            -- Pin the symbolic input to the generated value, so the
                            -- solver has to compute opS on it.
                            constrain $ i   .== literal v
                            constrain $ res .== opS i

                            let pre = (v, opC v)

                            query $ do
                                cs <- checkSat
                                case cs of
                                  Unk    -> return (pre, Left "Unexpected: Solver responded Unknown!")
                                  Unsat  -> return (pre, Left "Unexpected: Solver responded Unsatisfiable!")
                                  DSat{} -> return (pre, Left "Unexpected: Solver responded Delta-satisfiable!")
                                  Sat    -> do
                                      r <- getValue res
                                      return (pre, Right r)

                        -- Pull the concrete value out of a literal, for counterexample rendering.
                        -- The message names qc1 (it previously said qc2 by mistake).
                        let getCV vnm (SBV (SVal _ (Left c))) = (vnm, c)
                            getCV vnm (SBV (SVal k _       )) = error $ "qc1.getCV: Impossible happened, non-CV value while extracting: " ++ show (vnm, k)

                            vals = [ getCV "i"        (literal i)
                                   , getCV "Expected" (literal expected)
                                   ]

                            model = case result of
                                      Right v -> showModel defaultSMTCfg (SMTModel [] Nothing (vals ++ [getCV "Result" (literal v)]) [])
                                      Left  e -> showModel defaultSMTCfg (SMTModel [] Nothing vals []) ++ "\n" ++ e

                        QC.monitor (QC.counterexample model)

                        case result of
                           Right a -> QC.assert $ expected == a
                           _       -> QC.assert False
210
211
-- | Quick-check a binary function, creating one version for constant folding, and another for solver.
--
-- For a test name @nm@, a concrete operation @opC@ and its symbolic
-- counterpart @opS@, the result holds two properties:
--
--   * @nm ++ ".constantFold"@ takes two inputs from 'free', applies @opS@ to
--     them and checks the outcome against @opC@ on the extracted concrete
--     values. Both sides must be literals that are equal under '.==';
--     otherwise the property is 'sFalse'. NOTE(review): relies on 'free'
--     producing concrete values under QuickCheck, as the extraction helper assumes.
--
--   * @nm ++ ".symbolic"@ draws two random inputs, constrains fresh symbolic
--     variables to them, has the solver compute @opS@, and compares the model
--     value with @opC@ on the inputs. A 'checkSat' answer other than 'Sat'
--     fails the property. The inputs, the expected value and (if any) the
--     solver's result are reported as a counterexample.
qc2 :: (Eq a, Eq b, SymVal a, SymVal b, SymVal c, Show a, Show b, QC.Arbitrary a, QC.Arbitrary b, Eq c) => String -> (a -> b -> c) -> (SBV a -> SBV b -> SBV c) -> [TestTree]
qc2 nm opC opS = [foldedProp, solverProp]
  where
    foldedProp = QC.testProperty (nm ++ ".constantFold") $ do
      x1 <- free "i1"
      x2 <- free "i2"

      let concreteOf what sv = fromMaybe (error $ "qc2." ++ nm ++ ": Cannot extract value for: " ++ what) (unliteral sv)
          want = literal (concreteOf "i1" x1 `opC` concreteOf "i2" x2)
          got  = x1 `opS` x2

      -- Constant folding succeeded only if both sides reduced to literals.
      case (unliteral want, unliteral got) of
        (Just _, Just _) -> pure (want .== got)
        _                -> pure sFalse

    solverProp = QC.testProperty (nm ++ ".symbolic") $ QC.monadicIO $ do
      ((x1, x2, want), solverRes) <- QC.run $ runSMT $ do
        a1     <- liftIO $ QC.generate QC.arbitrary
        a2     <- liftIO $ QC.generate QC.arbitrary
        s1     <- free_
        s2     <- free_
        resVar <- free_

        constrain $ s1     .== literal a1
        constrain $ s2     .== literal a2
        constrain $ resVar .== opS s1 s2

        query $ do
          answer <- checkSat
          reply  <- case answer of
                      Sat    -> Right <$> getValue resVar
                      Unsat  -> pure (Left "Unexpected: Solver responded Unsatisfiable!")
                      Unk    -> pure (Left "Unexpected: Solver responded Unknown!")
                      DSat{} -> pure (Left "Unexpected: Solver responded Delta-satisfiable!")
          pure ((a1, a2, opC a1 a2), reply)

      -- Concrete value behind a literal, paired with its display name.
      let cvOf vnm (SBV (SVal _ (Left c))) = (vnm, c)
          cvOf vnm (SBV (SVal k _       )) = error $ "qc2.getCV: Impossible happened, non-CV value while extracting: " ++ show (vnm, k)

          inputs      = [cvOf "i1" (literal x1), cvOf "i2" (literal x2), cvOf "Expected" (literal want)]
          asModel cvs = showModel defaultSMTCfg (SMTModel [] Nothing cvs [])

          model = either (\e -> asModel inputs ++ "\n" ++ e)
                         (\v -> asModel (inputs ++ [cvOf "Result" (literal v)]))
                         solverRes

      QC.monitor (QC.counterexample model)
      QC.assert (either (const False) (want ==) solverRes)
269
-- | Randomly keep roughly @d@ percent of the individual tests in a tree.
--
-- Each 'SingleTest' is kept independently when a number drawn uniformly from
-- @[0, 99]@ is below @d@. So @d <= 0@ keeps nothing and @d >= 100@ keeps
-- everything. Groups left empty are dropped. Option modifiers
-- ('PlusTestOptions') are descended into and re-attached around whatever
-- survives. If nothing survives at all, an empty group named
-- @pickTests.NoTestsSelected@ is returned. Any other 'TestTree' constructor
-- is rejected with an error.
pickTests :: Int -> TestTree -> IO TestTree
pickTests d origTests = fromMaybe noTestsSelected <$> walk origTests
   where noTestsSelected = TestGroup "pickTests.NoTestsSelected" []

         walk t@SingleTest{} = do
             c <- randomRIO (0, 99)
             return $ if c < d then Just t else Nothing

         walk (TestGroup tn ts) = do
             cs <- catMaybes <$> mapM walk ts
             return $ case cs of
                        [] -> Nothing
                        _  -> Just (TestGroup tn cs)

         -- Options attached to a subtree (e.g. by tasty's localOption/adjustOption)
         -- must not abort selection: walk the subtree and keep the same modifier.
         walk (PlusTestOptions f t) = fmap (PlusTestOptions f) <$> walk t

         walk _ = error "pickTests: Unexpected test group!"
284
285{-# ANN module ("HLint: ignore Reduce duplication" :: String) #-}
286