1#
2# This file is part of pySMT.
3#
4#   Copyright 2014 Andrea Micheli and Marco Gario
5#
6#   Licensed under the Apache License, Version 2.0 (the "License");
7#   you may not use this file except in compliance with the License.
8#   You may obtain a copy of the License at
9#
10#       http://www.apache.org/licenses/LICENSE-2.0
11#
12#   Unless required by applicable law or agreed to in writing, software
13#   distributed under the License is distributed on an "AS IS" BASIS,
14#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15#   See the License for the specific language governing permissions and
16#   limitations under the License.
17#
18import os
19from functools import wraps
20
21try:
22    import unittest2 as unittest
23except ImportError:
24    import unittest
25
26from pysmt.environment import get_env, reset_env
# Re-export ``skipIf`` from whichever unittest module was imported above
# (unittest2 when it is installed, otherwise the standard library's).
skipIf = unittest.skipIf
28
29
class TestCase(unittest.TestCase):
    """Wrapper on the unittest TestCase class.

    This class provides setUp and tearDown methods for pySMT in which
    a fresh environment is provided for each test, together with
    assertions that check whether a formula is valid, satisfiable or
    unsatisfiable using the solver factory of that environment.
    """

    def setUp(self):
        # Chain to the parent implementation so that mixins combined with
        # this class still get their own setUp called.
        super(TestCase, self).setUp()
        # Every test gets its own fresh pySMT environment.
        self.env = reset_env()

    def tearDown(self):
        super(TestCase, self).tearDown()

    # Older unittest implementations (e.g. Python 2.7's) only provide the
    # ``assertRaisesRegexp`` spelling; alias it so tests can always use
    # the modern name.
    if not hasattr(unittest.TestCase, "assertRaisesRegex"):
        assertRaisesRegex = unittest.TestCase.assertRaisesRegexp

    def _assert_factory_check(self, check, status, formula, msg,
                              solver_name, logic):
        """Fail the test unless ``check`` accepts ``formula``.

        ``check`` is one of the factory's is_valid / is_sat / is_unsat
        methods. When ``msg`` is None and the check fails, a message
        naming the formula and the expected ``status`` is used instead of
        unittest's bare "False is not true".
        """
        result = check(formula=formula, solver_name=solver_name, logic=logic)
        if not result and msg is None:
            # Built only on failure, so passing tests never pay for
            # printing (possibly large) formulas.
            msg = "Formula is not %s: %s" % (status, formula)
        self.assertTrue(result, msg=msg)

    def assertValid(self, formula, msg=None, solver_name=None, logic=None):
        """Assert that formula is VALID.

        solver_name and logic are forwarded to ``env.factory.is_valid``;
        msg, if given, is used as the failure message.
        """
        self._assert_factory_check(self.env.factory.is_valid, "VALID",
                                   formula, msg, solver_name, logic)

    def assertSat(self, formula, msg=None, solver_name=None, logic=None):
        """Assert that formula is SAT.

        solver_name and logic are forwarded to ``env.factory.is_sat``;
        msg, if given, is used as the failure message.
        """
        self._assert_factory_check(self.env.factory.is_sat, "SAT",
                                   formula, msg, solver_name, logic)

    def assertUnsat(self, formula, msg=None, solver_name=None, logic=None):
        """Assert that formula is UNSAT.

        solver_name and logic are forwarded to ``env.factory.is_unsat``;
        msg, if given, is used as the failure message.
        """
        self._assert_factory_check(self.env.factory.is_unsat, "UNSAT",
                                   formula, msg, solver_name, logic)
67
68
class skipIfSolverNotAvailable(object):
    """Skip a test if the given solver is not available.

    Can decorate a single test method or a whole ``TestCase`` class.
    Availability is looked up in ``get_env().factory.all_solvers()`` when
    the decorator is applied, not when the test runs.
    """

    def __init__(self, solver):
        # Solver name, compared against factory.all_solvers().
        self.solver = solver

    def __call__(self, test_fun):
        msg = "%s not available" % self.solver
        cond = self.solver not in get_env().factory.all_solvers()
        # unittest.skipIf handles both functions and classes; wrapping
        # test_fun in an extra functools.wraps function first would turn a
        # decorated TestCase class into a plain function.
        return unittest.skipIf(cond, msg)(test_fun)
83
class skipIfQENotAvailable(object):
    """Skip a test if the given quantifier eliminator is not available.

    Can decorate a single test method or a whole ``TestCase`` class.
    Availability is looked up in
    ``get_env().factory.all_quantifier_eliminators()`` when the decorator
    is applied, not when the test runs.
    """

    def __init__(self, qe):
        # Quantifier-eliminator name, compared against
        # factory.all_quantifier_eliminators().
        self.qe = qe

    def __call__(self, test_fun):
        msg = "Quantifier Eliminator %s not available" % self.qe
        cond = self.qe not in get_env().factory.all_quantifier_eliminators()
        # unittest.skipIf handles both functions and classes; wrapping
        # test_fun in an extra functools.wraps function first would turn a
        # decorated TestCase class into a plain function.
        return unittest.skipIf(cond, msg)(test_fun)
98
99
class skipIfNoSolverForLogic(object):
    """Skip a test if there is no solver for the given logic.

    Can decorate a single test method or a whole ``TestCase`` class.
    The check uses ``get_env().factory.has_solvers(logic=...)`` when the
    decorator is applied, not when the test runs.
    """

    def __init__(self, logic):
        # Logic passed to factory.has_solvers().
        self.logic = logic

    def __call__(self, test_fun):
        msg = "Solver for %s not available" % self.logic
        cond = not get_env().factory.has_solvers(logic=self.logic)
        # unittest.skipIf handles both functions and classes; wrapping
        # test_fun in an extra functools.wraps function first would turn a
        # decorated TestCase class into a plain function.
        return unittest.skipIf(cond, msg)(test_fun)
114
115
class skipIfNoUnsatCoreSolverForLogic(object):
    """Skip a test if there is no unsat-core solver for the given logic.

    Can decorate a single test method or a whole ``TestCase`` class.
    The check uses ``get_env().factory.all_unsat_core_solvers(logic=...)``
    when the decorator is applied, not when the test runs.
    """

    def __init__(self, logic):
        # Logic passed to factory.all_unsat_core_solvers().
        self.logic = logic

    def __call__(self, test_fun):
        msg = "Unsat Core Solver for %s not available" % self.logic
        factory = get_env().factory
        # Skip when the factory reports an empty collection of solvers.
        cond = not factory.all_unsat_core_solvers(logic=self.logic)
        # unittest.skipIf handles both functions and classes; wrapping
        # test_fun in an extra functools.wraps function first would turn a
        # decorated TestCase class into a plain function.
        return unittest.skipIf(cond, msg)(test_fun)
130
131
class skipIfNoQEForLogic(object):
    """Skip a test if there is no quantifier eliminator for the given logic.

    Can decorate a single test method or a whole ``TestCase`` class.
    The check uses
    ``get_env().factory.all_quantifier_eliminators(logic=...)`` when the
    decorator is applied, not when the test runs.
    """

    def __init__(self, logic):
        # Logic passed to factory.all_quantifier_eliminators().
        self.logic = logic

    def __call__(self, test_fun):
        msg = "Quantifier Eliminator for %s not available" % self.logic
        factory = get_env().factory
        # Skip when the factory reports an empty collection of eliminators.
        cond = not factory.all_quantifier_eliminators(logic=self.logic)
        # unittest.skipIf handles both functions and classes; wrapping
        # test_fun in an extra functools.wraps function first would turn a
        # decorated TestCase class into a plain function.
        return unittest.skipIf(cond, msg)(test_fun)
146
147
# Re-export SkipTest and main from whichever unittest module was imported
# at the top of this file (unittest2 when it is installed).
SkipTest = unittest.SkipTest  # raise inside a test to mark it as skipped
main = unittest.main  # command-line test-runner entry point
153