1# -*- coding: utf-8 -*-
2#
3# CryptoMiniSat
4#
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21# THE SOFTWARE.
22
23from __future__ import unicode_literals
24from __future__ import print_function
25from array import array as _array
26import sys
27import unittest
28import time
29
30
31import pycryptosat
32from pycryptosat import Solver
33
34
def array(typecode, initializer=()):
    """Return an ``array.array`` built from *typecode* and *initializer*.

    The typecode is coerced with ``str`` first.  This module enables
    ``unicode_literals``, so a literal such as ``'i'`` is a unicode object on
    Python 2; presumably the coercion exists so ``array.array`` receives a
    native string there.
    """
    native_code = str(typecode)
    return _array(native_code, initializer)
37
38
def check_clause(clause, solution):
    """Return True if *solution* satisfies at least one literal of *clause*.

    *clause* is a sequence of non-zero DIMACS literals.  *solution* is indexed
    by variable number with index 0 unused, as in the ``(None, ...)`` tuples
    the tests compare ``Solver.solve`` models against.

    A positive literal counts as satisfied when its variable's value is not
    ``False``; a negative literal when the value is not ``True``.  An empty
    clause is never satisfied.

    NOTE(review): because the comparison is ``!=``, a variable whose value is
    ``None`` satisfies either polarity.  Confirm that leniency is intended.

    Always returns an explicit bool.  Falling off the end would yield
    ``None``, which a caller testing ``is False`` would not recognise as
    "unsatisfied".
    """
    # A negative literal needs its variable False, i.e. "not True".
    return any(solution[abs(lit)] != (lit < 0) for lit in clause)
49
50
def check_solution(clauses, solution):
    """Return True if *solution* satisfies every clause in *clauses*.

    Each clause is checked with ``check_clause``.  Any falsy result, including
    ``None``, counts as a violated clause.  Testing ``is False`` would let a
    ``None`` result slip through and wrongly report success.  An empty list of
    clauses is trivially satisfied.
    """
    return all(check_clause(clause, solution) for clause in clauses)
57
# -------------------------- test clauses --------------------------------
# Each formula below is a list of clauses.  Each clause is a list of non-zero
# DIMACS literals: a positive number is the variable, a negative number its
# negation.  The equivalent DIMACS text is shown above each constant.

# Expected satisfiable.
# p cnf 5 3
# 1 -5 4 0
# -1 5 3 4 0
# -3 -4 0
clauses1 = [
    [1, -5, 4],
    [-1, 5, 3, 4],
    [-3, -4],
]

# Unsatisfiable: variable 1 must be both false and true.
# p cnf 2 2
# -1 0
# 1 0
clauses2 = [
    [-1],
    [1],
]

# Satisfiable only by 1 = False, 2 = False: if 1 were true, the first two
# clauses would demand 2 and not 2.
# p cnf 2 3
# -1 2 0
# -1 -2 0
# 1 -2 0
clauses3 = [
    [-1, 2],
    [-1, -2],
    [1, -2],
]
76
77# -------------------------- actual unit tests ---------------------------
78
79
class TestXor(unittest.TestCase):
    """Tests for XOR constraints added with ``Solver.add_xor_clause``.

    ``setUp`` gives every test a fresh two-thread solver.
    """

    def setUp(self):
        self.solver = Solver(threads=2)

    def test_wrong_args(self):
        add_xor = self.solver.add_xor_clause
        # The right-hand side argument cannot be omitted.
        self.assertRaises(TypeError, add_xor, [1, 2])
        # A 0 literal and a negative literal are both rejected.
        self.assertRaises(ValueError, add_xor, [1, 0], True)
        self.assertRaises(ValueError, add_xor, [-1, 2], True)

    def test_binary(self):
        # 1 XOR 2 == False makes 2 equal to 1; assuming 1 forces 2 = True.
        self.solver.add_xor_clause([1, 2], False)
        sat, model = self.solver.solve([1])
        self.assertEqual(sat, True)
        self.assertEqual(model, (None, True, True))

    def test_unit(self):
        # A one-variable XOR with right-hand side False fixes it to False.
        self.solver.add_xor_clause([1], False)
        sat, model = self.solver.solve()
        self.assertEqual(sat, True)
        self.assertEqual(model, (None, False))

    def test_unit2(self):
        # ... and with right-hand side True fixes it to True.
        self.solver.add_xor_clause([1], True)
        sat, model = self.solver.solve()
        self.assertEqual(sat, True)
        self.assertEqual(model, (None, True))

    def test_3_long(self):
        self.solver.add_xor_clause([1, 2, 3], False)
        sat, model = self.solver.solve([1, 2])
        self.assertEqual(sat, True)
        # NOTE(review): an assertion that the model equals
        # (None, True, True, False) was left disabled here; find out why
        # before adding it back.

    def test_3_long2(self):
        # True XOR False XOR x == True forces x = False.
        self.solver.add_xor_clause([1, 2, 3], True)
        sat, model = self.solver.solve([1, -2])
        self.assertEqual(sat, True)
        self.assertEqual(model, (None, True, False, False))

    def test_long(self):
        for length in range(10, 30):
            # A fresh solver for every XOR length.
            self.setUp()
            variables = list(range(1, length))
            # Assume all but the last variable False; an XOR with right-hand
            # side False then forces the last one False as well.
            assumptions = [-var for var in variables[:-1]]
            expected = (None,) + (False,) * len(variables)

            self.solver.add_xor_clause(variables, False)
            sat, model = self.solver.solve(assumptions)
            self.assertEqual(sat, True)
            self.assertEqual(model, expected)
137
138
class InitTester(unittest.TestCase):
    """Checks that ``Solver`` validates its constructor keyword arguments."""

    def test_wrong_args_to_solver(self):
        # Each case pairs the expected exception with the constructor
        # kwargs.  ``dict(name=...)`` keeps the keys as native ``str`` even
        # with ``unicode_literals`` active.
        cases = [
            # Out-of-range numbers -> ValueError.
            (ValueError, dict(threads=-1)),
            (ValueError, dict(threads=0)),
            (ValueError, dict(verbose=-1)),
            (ValueError, dict(time_limit=-1)),
            (ValueError, dict(confl_limit=-1)),
            # Non-numeric values -> TypeError.
            (TypeError, dict(threads="fail")),
            (TypeError, dict(verbose="fail")),
            (TypeError, dict(time_limit="fail")),
            (TypeError, dict(confl_limit="fail")),
        ]
        for expected_exc, kwargs in cases:
            self.assertRaises(expected_exc, Solver, **kwargs)
151
152
class TestDump(unittest.TestCase):
    """Tests for streaming learnt small clauses out of a solver."""

    def setUp(self):
        self.solver = Solver()

    def test_max_glue_missing(self):
        # Calling start_getting_small_clauses without max_glue must fail.
        self.assertRaises(TypeError,
                          self.solver.start_getting_small_clauses, 4)

    def test_one_dump(self):
        with open("tests/test.cnf", "r") as cnf_file:
            for line in cnf_file:
                line = line.strip()
                # Skip blank lines, the "p cnf" header and "c" comments.
                # Matching only the first character (not "p"/"c" anywhere)
                # is the precise DIMACS rule.  A blank line must be skipped
                # because it would otherwise become an empty clause and make
                # the formula UNSAT.
                if not line or line.startswith(("p", "c")):
                    continue

                # Drop the trailing DIMACS clause terminator (0).
                lits = [int(tok) for tok in line.split()[:-1]]
                self.solver.add_clause(lits)

        res, _ = self.solver.solve()
        self.assertEqual(res, True)

        self.solver.start_getting_small_clauses(4, max_glue=10)
        try:
            small_clause = self.solver.get_next_small_clause()
            # assertIsNotNone replaces the assertNotEquals alias, which was
            # removed in Python 3.12.
            self.assertIsNotNone(small_clause)
        finally:
            # Close the retrieval session even if the assertion fails.
            self.solver.end_getting_small_clauses()
179
180
class TestSolve(unittest.TestCase):
    """Clause-level API tests, each run on a fresh two-thread solver."""

    def setUp(self):
        self.solver = Solver(threads=2)

    def test_wrong_args(self):
        # Non-iterables, and iterables whose items are not integers, are
        # rejected with TypeError ...
        self.assertRaises(TypeError, self.solver.add_clause, 'A')
        self.assertRaises(TypeError, self.solver.add_clause, 1)
        self.assertRaises(TypeError, self.solver.add_clause, 1.0)
        self.assertRaises(TypeError, self.solver.add_clause, object())
        self.assertRaises(TypeError, self.solver.add_clause, ['a'])
        self.assertRaises(
            TypeError, self.solver.add_clause, [[1, 2], [3, None]])
        # ... while the literal 0 is rejected with ValueError.
        self.assertRaises(ValueError, self.solver.add_clause, [1, 0])

    def test_no_clauses(self):
        # An empty formula is satisfiable with an empty model on every call.
        for _ in range(7):
            self.assertEqual(self.solver.solve([]), (True, (None,)))

    def test_cnf1(self):
        for cl in clauses1:
            self.solver.add_clause(cl)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertTrue(check_solution(clauses1, solution))

    def test_add_clauses(self):
        self.solver.add_clauses([[1], [-1]])
        res, _ = self.solver.solve()
        self.assertEqual(res, False)

    def test_add_clauses_wrong_zero(self):
        # Exercise add_clauses itself.  Calling add_clause here passed
        # trivially, because a nested list is not an int.  A 0 literal inside
        # one of the clauses should be rejected with ValueError, the same
        # error add_clause gives for a 0 literal in test_wrong_args.
        self.assertRaises(
            ValueError, self.solver.add_clauses, [[1, 0], [-1]])

    def test_add_clauses_array_SAT(self):
        # Flat array form: each clause is a run of literals ended by 0.
        cls = array('i', [1, 2, 0, 1, 2, 0])
        self.solver.add_clauses(cls)
        res, _ = self.solver.solve()
        self.assertEqual(res, True)

    def test_add_clauses_array_UNSAT(self):
        cls = array('i', [-1, 0, 1, 0])
        self.solver.add_clauses(cls)
        res, _ = self.solver.solve()
        self.assertEqual(res, False)

    def test_add_clauses_array_unterminated(self):
        # The final clause lacks its terminating 0.  This goes through
        # add_clauses; add_clause would fail on the embedded 0 instead and
        # never reach the missing terminator.
        cls = array('i', [1, 2, 0, 1, 2])
        self.assertRaises(ValueError, self.solver.add_clauses, cls)

    def test_bad_iter(self):
        class Liar:

            def __iter__(self):
                # Not an iterator, so iter() on this object fails.
                return None
        self.assertRaises(TypeError, self.solver.add_clause, Liar())

    def test_get_conflict(self):
        # Unit clauses fix 1=False, 2=True, 3=True, 4=False.
        self.solver.add_clauses([[-1], [2], [3], [-4]])

        # Assumptions -2 and 4 each contradict a unit clause; 3 does not.
        res, _ = self.solver.solve(assumptions=[-2, 3, 4])
        self.assertEqual(res, False)

        confl = self.solver.get_conflict()
        self.assertIsInstance(confl, list)
        self.assertNotIn(3, confl)
        # The conflict lists failed assumptions negated (this test expects
        # assumption -2 as 2 and assumption 4 as -4).  At least one of them
        # must be reported.
        self.assertTrue(2 in confl or -4 in confl,
                        msg="Either -2 or 4 should be conflicting!")

        res, _ = self.solver.solve(assumptions=[2, 4])
        self.assertEqual(res, False)

        confl = self.solver.get_conflict()
        self.assertIsInstance(confl, list)
        self.assertNotIn(2, confl)
        self.assertIn(-4, confl)

    def test_cnf2(self):
        for cl in clauses2:
            self.solver.add_clause(cl)
        self.assertEqual(self.solver.solve(), (False, None))

    def test_cnf3(self):
        for cl in clauses3:
            self.solver.add_clause(cl)
        res, solution = self.solver.solve()
        self.assertEqual(res, True)
        self.assertTrue(check_solution(clauses3, solution))

    def test_cnf1_confl_limit(self):
        # NOTE(review): despite the name, setUp configures no confl_limit, so
        # the "res is None" (gave up) branch is never expected to trigger.
        for _ in range(1, 20):
            self.setUp()
            for cl in clauses1:
                self.solver.add_clause(cl)

            res, solution = self.solver.solve()
            self.assertTrue(res is None or check_solution(clauses1, solution))

    def test_by_re_curse(self):
        # Solve, add more clauses, then solve again on the same solver.
        self.solver.add_clause([-1, -2, 3])
        res, _ = self.solver.solve()
        self.assertEqual(res, True)

        self.solver.add_clause([-5, 1])
        self.solver.add_clause([4, -3])
        self.solver.add_clause([2, 3, 5])
        res, _ = self.solver.solve()
        self.assertEqual(res, True)
296
297
class TestSolveTimeLimit(unittest.TestCase):
    """Checks that ``time_limit`` cuts a hard instance off promptly."""

    def get_clauses(self):
        """Parse ``tests/f400-r425-x000.cnf`` into a list of clauses.

        Blank lines, the ``p`` header and ``c`` comment lines are skipped.
        Every remaining line must end with the DIMACS terminator 0, which is
        asserted and then dropped.
        """
        clauses = []
        with open("tests/f400-r425-x000.cnf", "r") as cnf_file:
            for raw_line in cnf_file:
                text = raw_line.strip()
                if not text or text[0] in ("p", "c"):
                    continue
                lits = [int(token) for token in text.split()]
                assert lits[-1] == 0
                clauses.append(lits[:-1])
        return clauses

    def test_time(self):
        SAT_TIME_LIMIT = 1
        clauses = self.get_clauses()  # a few hundred short clauses

        started = time.time()
        solver = Solver(threads=4, time_limit=SAT_TIME_LIMIT)
        solver.add_clauses(clauses)
        sat, _model = solver.solve()
        elapsed = time.time() - started

        # NOTE: this CNF takes about an hour to solve, so finishing in a few
        # seconds means the time limit stopped the search.  A 2 s bound would
        # pass on most systems; 4 s leaves headroom for overloaded CI servers.
        self.assertLess(elapsed, 4)
332
333# ------------------------------------------------------------------------
334
335
def run():
    """Run every test case in this module and exit with the failure count.

    Prints interpreter and pycryptosat version information, runs the suites
    with a verbose text runner, prints a summary if anything went wrong, and
    calls ``sys.exit`` with the number of errors plus failures.  The count is
    capped at 255 so it cannot wrap around to a "success" status.
    """
    print("sys.prefix: %s" % sys.prefix)
    print("sys.version: %s" % sys.version)
    try:
        print("pycryptosat version: %r" % pycryptosat.__version__)
    except AttributeError:
        # __version__ is informational only; skip it if the module lacks it.
        pass

    # loadTestsFromTestCase replaces unittest.makeSuite, which was deprecated
    # in Python 3.11 and removed in 3.13.
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite()
    for case in (TestXor, InitTester, TestSolve, TestDump,
                 TestSolveTimeLimit):
        suite.addTest(loader.loadTestsFromTestCase(case))

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    n_errors = len(result.errors)
    n_failures = len(result.failures)

    if n_errors or n_failures:
        print('\n\nSummary: %d errors and %d failures reported\n' %
              (n_errors, n_failures))

    print()

    # POSIX keeps only the low 8 bits of the exit status, so an uncapped
    # count of 256 would be reported as success (0).
    sys.exit(min(n_errors + n_failures, 255))


if __name__ == '__main__':
    run()
367