#
# This file is part of pySMT.
#
# Copyright 2014 Andrea Micheli and Marco Gario
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from functools import wraps

try:
    import unittest2 as unittest
except ImportError:
    import unittest

from pysmt.environment import get_env, reset_env
skipIf = unittest.skipIf


class TestCase(unittest.TestCase):
    """Wrapper on the unittest TestCase class.

    This class provides setUp and tearDown methods for pySMT in which
    a fresh environment is provided for each test.
    """

    def setUp(self):
        """Reset the global pySMT environment and keep it as ``self.env``."""
        self.env = reset_env()

    def tearDown(self):
        """No explicit cleanup: the next setUp resets the environment."""
        pass

    # Older unittest versions only expose the deprecated
    # ``assertRaisesRegexp`` spelling; alias it so tests can always use
    # ``assertRaisesRegex``.
    if "assertRaisesRegex" not in dir(unittest.TestCase):
        assertRaisesRegex = unittest.TestCase.assertRaisesRegexp

    def assertValid(self, formula, msg=None, solver_name=None, logic=None):
        """Assert that formula is VALID.

        :param formula: The formula to check.
        :param msg: Optional failure message.
        :param solver_name: Forwarded to the factory's ``is_valid``.
        :param logic: Forwarded to the factory's ``is_valid``.
        """
        self.assertTrue(self.env.factory.is_valid(formula=formula,
                                                  solver_name=solver_name,
                                                  logic=logic),
                        msg=msg)

    def assertSat(self, formula, msg=None, solver_name=None, logic=None):
        """Assert that formula is SAT.

        :param formula: The formula to check.
        :param msg: Optional failure message.
        :param solver_name: Forwarded to the factory's ``is_sat``.
        :param logic: Forwarded to the factory's ``is_sat``.
        """
        self.assertTrue(self.env.factory.is_sat(formula=formula,
                                                solver_name=solver_name,
                                                logic=logic),
                        msg=msg)

    def assertUnsat(self, formula, msg=None, solver_name=None, logic=None):
        """Assert that formula is UNSAT.

        :param formula: The formula to check.
        :param msg: Optional failure message.
        :param solver_name: Forwarded to the factory's ``is_unsat``.
        :param logic: Forwarded to the factory's ``is_unsat``.
        """
        self.assertTrue(self.env.factory.is_unsat(formula=formula,
                                                  solver_name=solver_name,
                                                  logic=logic),
                        msg=msg)


def _skip_if(cond, msg, test_fun):
    """Return test_fun wrapped so that unittest skips it with msg if cond.

    Shared implementation of the skipIf* decorator classes below.
    """
    @unittest.skipIf(cond, msg)
    @wraps(test_fun)
    def wrapper(*args, **kwargs):
        return test_fun(*args, **kwargs)
    return wrapper


# NOTE: in all decorators below the skip condition is computed when the
# decorator is applied (i.e. when the test function is defined), using the
# environment returned by get_env() at that moment, not when the test runs.

class skipIfSolverNotAvailable(object):
    """Skip a test if the given solver is not available."""

    def __init__(self, solver):
        self.solver = solver

    def __call__(self, test_fun):
        msg = "%s not available" % self.solver
        cond = self.solver not in get_env().factory.all_solvers()
        return _skip_if(cond, msg, test_fun)


class skipIfQENotAvailable(object):
    """Skip a test if the given quantifier eliminator is not available."""

    def __init__(self, qe):
        self.qe = qe

    def __call__(self, test_fun):
        msg = "Quantifier Eliminator %s not available" % self.qe
        cond = self.qe not in get_env().factory.all_quantifier_eliminators()
        return _skip_if(cond, msg, test_fun)


class skipIfNoSolverForLogic(object):
    """Skip a test if there is no solver for the given logic."""

    def __init__(self, logic):
        self.logic = logic

    def __call__(self, test_fun):
        msg = "Solver for %s not available" % self.logic
        cond = not get_env().factory.has_solvers(logic=self.logic)
        return _skip_if(cond, msg, test_fun)


class skipIfNoUnsatCoreSolverForLogic(object):
    """Skip a test if there is no unsat-core solver for the given logic."""

    def __init__(self, logic):
        self.logic = logic

    def __call__(self, test_fun):
        msg = "Unsat Core Solver for %s not available" % self.logic
        solvers = get_env().factory.all_unsat_core_solvers(logic=self.logic)
        return _skip_if(not solvers, msg, test_fun)


class skipIfNoQEForLogic(object):
    """Skip a test if there is no quantifier eliminator for the given logic."""

    def __init__(self, logic):
        self.logic = logic

    def __call__(self, test_fun):
        msg = "Quantifier Eliminator for %s not available" % self.logic
        qes = get_env().factory.all_quantifier_eliminators(logic=self.logic)
        return _skip_if(not qes, msg, test_fun)


# Export a main function
main = unittest.main

# Export SkipTest
SkipTest = unittest.SkipTest