# -*- coding: utf-8 -*-
#
# CryptoMiniSat
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from __future__ import unicode_literals
from __future__ import print_function
from array import array as _array
import sys
import unittest
import time


import pycryptosat
from pycryptosat import Solver


def array(typecode, initializer=()):
    """Build an ``array.array``, coercing *typecode* to a native ``str``.

    With ``unicode_literals`` active, a literal such as ``'i'`` is unicode on
    Python 2, which ``array.array`` does not accept as a typecode; ``str()``
    converts it to the native string type.
    """
    return _array(str(typecode), initializer)


def check_clause(clause, solution):
    """Return True if *solution* satisfies at least one literal of *clause*.

    *clause* is a list of non-zero DIMACS-style literals (a negative literal
    means the variable is negated).  *solution* is indexed by variable
    number, like the tuples returned by ``Solver.solve()``.  Returns False
    when no literal of the clause is satisfied.
    """
    for lit in clause:
        var = abs(lit)
        # A positive literal is satisfied when its variable is True, a
        # negative literal when its variable is False.
        if solution[var] != (lit < 0):
            return True
    # Explicit False: returning None here would make callers that test the
    # result for falsity by identity silently accept unsatisfied clauses.
    return False


def check_solution(clauses, solution):
    """Return True if *solution* satisfies every clause in *clauses*."""
    return all(check_clause(clause, solution) for clause in clauses)

# -------------------------- test clauses --------------------------------

# p cnf 5 3
# 1 -5 4 0
# -1 5 3 4 0
# -3 -4 0
clauses1 = [[1, -5, 4], [-1, 5, 3, 4], [-3, -4]]

# p cnf 2 2
# -1 0
# 1 0
clauses2 = [[-1], [1]]

# p cnf 2 3
# -1 2 0
# -1 -2 0
# 1 -2 0
clauses3 = [[-1, 2], [-1, -2], [1, -2]]

# -------------------------- actual unit tests ---------------------------


class TestXor(unittest.TestCase):
    """Tests for ``Solver.add_xor_clause(vars, rhs)``."""

    def setUp(self):
        self.solver = Solver(threads=2)

    def test_wrong_args(self):
        self.assertRaises(TypeError, self.solver.add_xor_clause, [1, 2])
        self.assertRaises(ValueError, self.solver.add_xor_clause, [1, 0], True)
        self.assertRaises(
            ValueError, self.solver.add_xor_clause, [-1, 2], True)

    def test_binary(self):
        self.solver.add_xor_clause([1, 2], False)
        res, solution = self.solver.solve([1])
        self.assertEqual(res, True)
        self.assertEqual(solution, (None, True, True))

    def test_unit(self):
        self.solver.add_xor_clause([1], False)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertEqual(solution, (None, False))

    def test_unit2(self):
        self.solver.add_xor_clause([1], True)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertEqual(solution, (None, True))

    def test_3_long(self):
        self.solver.add_xor_clause([1, 2, 3], False)
        res, solution = self.solver.solve([1, 2])
        self.assertEqual(res, True)
        # NOTE(review): 1 XOR 2 XOR 3 = False with 1 and 2 assumed True
        # forces 3 to False, so the check below looks valid; it is disabled
        # in the original -- confirm why before re-enabling.
        # self.assertEqual(solution, (None, True, True, False))

    def test_3_long2(self):
        self.solver.add_xor_clause([1, 2, 3], True)
        res, solution = self.solver.solve([1, -2])
        self.assertEqual(res, True)
        self.assertEqual(solution, (None, True, False, False))

    def test_long(self):
        for length in range(10, 30):
            self.setUp()
            # XOR over variables 1..length-1 with rhs False.  Assuming every
            # variable but the last is False forces the last one False too.
            toadd = list(range(1, length))
            toassume = [-i for i in range(1, length - 1)]
            solution_expected = (None,) + (False,) * (length - 1)

            self.solver.add_xor_clause(toadd, False)
            res, solution = self.solver.solve(toassume)
            self.assertEqual(res, True)
            self.assertEqual(solution, solution_expected)


class InitTester(unittest.TestCase):
    """Argument validation for the ``Solver`` constructor."""

    def test_wrong_args_to_solver(self):
        self.assertRaises(ValueError, Solver, threads=-1)
        self.assertRaises(ValueError, Solver, threads=0)
        self.assertRaises(ValueError, Solver, verbose=-1)
        self.assertRaises(ValueError, Solver, time_limit=-1)
        self.assertRaises(ValueError, Solver, confl_limit=-1)
        self.assertRaises(TypeError, Solver, threads="fail")
        self.assertRaises(TypeError, Solver, verbose="fail")
        self.assertRaises(TypeError, Solver, time_limit="fail")
        self.assertRaises(TypeError, Solver, confl_limit="fail")


class TestDump(unittest.TestCase):
    """Tests for the start/get_next/end_getting_small_clauses API."""

    def setUp(self):
        self.solver = Solver()

    def test_max_glue_missing(self):
        self.assertRaises(TypeError,
                          self.solver.start_getting_small_clauses, 4)

    def test_one_dump(self):
        with open("tests/test.cnf", "r") as cnf:
            for line in cnf:
                line = line.strip()
                # Skip blank lines plus the "p cnf" header and "c" comments;
                # a blank line would otherwise be added as an empty clause.
                if not line or "p" in line or "c" in line:
                    continue

                # Drop the trailing DIMACS terminator 0.
                clause = [int(tok) for tok in line.split()[:-1]]
                self.solver.add_clause(clause)

        res, _ = self.solver.solve()
        self.assertEqual(res, True)

        self.solver.start_getting_small_clauses(4, max_glue=10)
        small_clause = self.solver.get_next_small_clause()
        # assertNotEquals was removed in Python 3.12.
        self.assertIsNotNone(small_clause)
        self.solver.end_getting_small_clauses()


class TestSolve(unittest.TestCase):
    """Clause addition, solving, and conflict extraction."""

    def setUp(self):
        self.solver = Solver(threads=2)

    def test_wrong_args(self):
        self.assertRaises(TypeError, self.solver.add_clause, 'A')
        self.assertRaises(TypeError, self.solver.add_clause, 1)
        self.assertRaises(TypeError, self.solver.add_clause, 1.0)
        self.assertRaises(TypeError, self.solver.add_clause, object())
        self.assertRaises(TypeError, self.solver.add_clause, ['a'])
        self.assertRaises(
            TypeError, self.solver.add_clause, [[1, 2], [3, None]])
        self.assertRaises(ValueError, self.solver.add_clause, [1, 0])

    def test_no_clauses(self):
        for _ in range(7):
            self.assertEqual(self.solver.solve([]), (True, (None,)))

    def test_cnf1(self):
        for cl in clauses1:
            self.solver.add_clause(cl)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertTrue(check_solution(clauses1, solution))

    def test_add_clauses(self):
        self.solver.add_clauses([[1], [-1]])
        res, _ = self.solver.solve()
        self.assertEqual(res, False)

    def test_add_clauses_wrong_zero(self):
        # NOTE(review): despite its name this calls add_clause (not
        # add_clauses) with a nested list, so the TypeError likely comes from
        # the non-int elements rather than the embedded 0 -- confirm intent.
        self.assertRaises(TypeError, self.solver.add_clause, [[1, 0], [-1]])

    def test_add_clauses_array_SAT(self):
        cls = array('i', [1, 2, 0, 1, 2, 0])
        self.solver.add_clauses(cls)
        res, _ = self.solver.solve()
        self.assertEqual(res, True)

    def test_add_clauses_array_UNSAT(self):
        cls = array('i', [-1, 0, 1, 0])
        self.solver.add_clauses(cls)
        res, _ = self.solver.solve()
        self.assertEqual(res, False)

    def test_add_clauses_array_unterminated(self):
        # NOTE(review): this calls add_clause, not add_clauses, so the
        # ValueError may come from the embedded 0 rather than the missing
        # terminator -- confirm which path is meant to be tested.
        cls = array('i', [1, 2, 0, 1, 2])
        self.assertRaises(ValueError, self.solver.add_clause, cls)

    def test_bad_iter(self):
        class Liar:

            def __iter__(self):
                return None
        self.assertRaises(TypeError, self.solver.add_clause, Liar())

    def test_get_conflict(self):
        self.solver.add_clauses([[-1], [2], [3], [-4]])
        assume = [-2, 3, 4]

        res, _ = self.solver.solve(assumptions=assume)
        self.assertEqual(res, False)

        confl = self.solver.get_conflict()
        self.assertIsInstance(confl, list)
        self.assertNotIn(3, confl)
        # Assumption -2 clashes with unit 2 and assumption 4 with unit -4;
        # at least one of them must be reported, negated, in the conflict.
        self.assertTrue(2 in confl or -4 in confl,
                        msg="Either -2 or 4 should be conflicting!")

        assume = [2, 4]
        res, _ = self.solver.solve(assumptions=assume)
        self.assertEqual(res, False)

        confl = self.solver.get_conflict()
        self.assertIsInstance(confl, list)
        self.assertNotIn(2, confl)
        self.assertIn(-4, confl)

    def test_cnf2(self):
        for cl in clauses2:
            self.solver.add_clause(cl)
        self.assertEqual(self.solver.solve(), (False, None))

    def test_cnf3(self):
        for cl in clauses3:
            self.solver.add_clause(cl)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertTrue(check_solution(clauses3, solution))

    def test_cnf1_confl_limit(self):
        # NOTE(review): despite its name, no confl_limit is passed to the
        # Solver here; consider Solver(threads=2, confl_limit=...) per
        # iteration if the limit is meant to be exercised.
        for _ in range(1, 20):
            self.setUp()
            for cl in clauses1:
                self.solver.add_clause(cl)

            res, solution = self.solver.solve()
            self.assertTrue(res is None or check_solution(clauses1, solution))

    def test_by_re_curse(self):
        self.solver.add_clause([-1, -2, 3])
        res, _ = self.solver.solve()
        self.assertEqual(res, True)

        self.solver.add_clause([-5, 1])
        self.solver.add_clause([4, -3])
        self.solver.add_clause([2, 3, 5])
        res, _ = self.solver.solve()
        self.assertEqual(res, True)


class TestSolveTimeLimit(unittest.TestCase):
    """Check that ``time_limit`` stops a long-running solve."""

    def get_clauses(self):
        """Parse tests/f400-r425-x000.cnf into a list of clauses.

        Blank, ``p`` header and ``c`` comment lines are skipped; every clause
        line must end with the DIMACS terminator 0, which is stripped.
        """
        clauses = []
        with open("tests/f400-r425-x000.cnf", "r") as f:
            for line in f:
                line = line.strip()
                if not line or line[0] in ("p", "c"):
                    continue
                lits = [int(tok) for tok in line.split()]
                assert lits[-1] == 0
                clauses.append(lits[:-1])

        return clauses

    def test_time(self):
        SAT_TIME_LIMIT = 1
        clauses = self.get_clauses()  # returns a few hundred short clauses
        t0 = time.time()
        solver = Solver(threads=4, time_limit=SAT_TIME_LIMIT)
        solver.add_clauses(clauses)
        solver.solve()
        took_time = time.time() - t0

        # NOTE: without a time limit this CNF takes about an hour to solve,
        # so finishing within a few seconds shows the limit is honored.  A
        # 2s bound would work on most systems, but not on overloaded CI
        # servers, hence 4s.
        self.assertLess(took_time, 4)

# ------------------------------------------------------------------------


def run():
    """Run every test case in this module, then exit the process.

    Prints interpreter and pycryptosat version information, runs the suite
    verbosely, and exits with status 0 when all tests pass and 1 otherwise.
    """
    print("sys.prefix: %s" % sys.prefix)
    print("sys.version: %s" % sys.version)
    try:
        print("pycryptosat version: %r" % pycryptosat.__version__)
    except AttributeError:
        pass

    # unittest.makeSuite was removed in Python 3.13; the loader method is
    # its documented replacement and is also available on Python 2.7.
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite()
    for case in (TestXor, InitTester, TestSolve, TestDump, TestSolveTimeLimit):
        suite.addTest(loader.loadTestsFromTestCase(case))

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    n_errors = len(result.errors)
    n_failures = len(result.failures)

    if n_errors or n_failures:
        print('\n\nSummary: %d errors and %d failures reported\n' %
              (n_errors, n_failures))

    print()

    # Process exit statuses wrap modulo 256, so exiting with the raw count
    # would report success for exactly 256 problems; collapse to 0/1.
    sys.exit(1 if n_errors or n_failures else 0)


if __name__ == '__main__':
    run()