1'''
2Tests for solving.
3'''
4from unittest import TestCase
5from typing import cast
6from clingo import Control, Function, Model, ModelType, SolveHandle, SolveResult, SymbolicAtom
7
8from . util import _MCB, _check_sat, _p
9
class TestSolving(TestCase):
    '''
    Tests basic solving and related functions.

    Every test starts with a fresh `Control` created with the argument `'0'`
    and two model collectors: `mcb` is typically passed as the `on_model`
    callback, while `mit` is fed manually with models taken from a solve
    handle.
    '''

    def setUp(self):
        self.mcb = _MCB()
        self.mit = _MCB()
        self.ctl = Control(['0'])

    def tearDown(self):
        # Drop fixture references once the test is done.
        self.mcb = None
        self.mit = None
        self.ctl = None

    def test_solve_result_str(self):
        '''
        Test string representation of solve results.
        '''
        ret = self.ctl.solve()
        self.assertEqual(str(ret), 'SAT')
        # The parentheses are escaped: unescaped they would form a regex group
        # and the pattern would accept any text merely containing the name.
        self.assertRegex(repr(ret), r'SolveResult\(.*\)')

    def test_model_str(self):
        '''
        Test string representation of models.
        '''
        self.ctl.add('base', [], 'a.')
        self.ctl.ground([('base', [])])
        num_models = 0
        with cast(SolveHandle, self.ctl.solve(yield_=True)) as hnd:
            for mdl in hnd:
                num_models += 1
                self.assertEqual(str(mdl), "a")
                self.assertRegex(repr(mdl), r"Model\(.*\)")
        # Without this check the test would pass silently if no model were
        # yielded and the assertions above never ran.
        self.assertEqual(num_models, 1)

    def test_solve_cb(self):
        '''
        Test solving using callback.
        '''
        self.ctl.add("base", [], "1 {a; b} 1. c.")
        self.ctl.ground([("base", [])])
        _check_sat(self, cast(SolveResult, self.ctl.solve(on_model=self.mcb.on_model, yield_=False, async_=False)))
        self.assertEqual(self.mcb.models, _p(['a', 'c'], ['b', 'c']))
        self.assertEqual(self.mcb.last[0], ModelType.StableModel)

    def test_solve_async(self):
        '''
        Test asynchronous solving.
        '''
        self.ctl.add("base", [], "1 {a; b} 1. c.")
        self.ctl.ground([("base", [])])
        with cast(SolveHandle, self.ctl.solve(on_model=self.mcb.on_model, yield_=False, async_=True)) as hnd:
            _check_sat(self, hnd.get())
            self.assertEqual(self.mcb.models, _p(['a', 'c'], ['b', 'c']))

    def test_solve_yield(self):
        '''
        Test solving yielding models.
        '''
        self.ctl.add("base", [], "1 {a; b} 1. c.")
        self.ctl.ground([("base", [])])
        with cast(SolveHandle, self.ctl.solve(on_model=self.mcb.on_model, yield_=True, async_=False)) as hnd:
            for m in hnd:
                self.mit.on_model(m)
            _check_sat(self, hnd.get())
            # Both the callback and the iteration must observe the same models.
            self.assertEqual(self.mcb.models, _p(['a', 'c'], ['b', 'c']))
            self.assertEqual(self.mit.models, _p(['a', 'c'], ['b', 'c']))

    def test_solve_async_yield(self):
        '''
        Test solving yielding models asynchronously.
        '''
        self.ctl.add("base", [], "1 {a; b} 1. c.")
        self.ctl.ground([("base", [])])
        with cast(SolveHandle, self.ctl.solve(on_model=self.mcb.on_model, yield_=True, async_=True)) as hnd:
            # Drive the search manually: resume, wait for the next result and
            # stop once the handle no longer provides a model.
            while True:
                hnd.resume()
                hnd.wait()
                m = hnd.model()
                if m is None:
                    break
                self.mit.on_model(m)
            _check_sat(self, hnd.get())
            self.assertEqual(self.mcb.models, _p(['a', 'c'], ['b', 'c']))
            self.assertEqual(self.mit.models, _p(['a', 'c'], ['b', 'c']))

    def test_solve_interrupt(self):
        '''
        Test interrupting solving.
        '''
        # A pigeon-hole style program (100 pigeons, 99 holes) that is
        # unsatisfiable and hard to refute, so the search is still running
        # when it gets interrupted.
        self.ctl.add("base", [], "1 { p(P,H): H=1..99 } 1 :- P=1..100.\n1 { p(P,H): P=1..100 } 1 :- H=1..99.")
        self.ctl.ground([("base", [])])

        # Interrupt through the solve handle.
        with cast(SolveHandle, self.ctl.solve(async_=True)) as hnd:
            hnd.resume()
            hnd.cancel()
            ret = hnd.get()
            self.assertTrue(ret.interrupted)

        # Interrupt through the control object.
        with cast(SolveHandle, self.ctl.solve(async_=True)) as hnd:
            hnd.resume()
            self.ctl.interrupt()
            ret = hnd.get()
            self.assertTrue(ret.interrupted)

    def test_solve_core(self):
        '''
        Test core retrieval.
        '''
        self.ctl.add("base", [], "3 { p(1..10) } 3.")
        self.ctl.ground([("base", [])])
        # Assume every p/1 atom false; this contradicts "exactly 3 true".
        ass = [-atom.literal for atom in self.ctl.symbolic_atoms.by_signature("p", 1)]
        ret = cast(SolveResult, self.ctl.solve(on_core=self.mcb.on_core, assumptions=ass))
        self.assertTrue(ret.unsatisfiable)
        self.assertGreater(len(self.mcb.core), 7)

    def test_enum(self):
        '''
        Test the cautious and brave enumeration modes.
        '''
        self.ctl = Control(['0', '-e', 'cautious'])
        self.ctl.add("base", [], "1 {a; b} 1. c.")
        self.ctl.ground([("base", [])])
        self.ctl.solve(on_model=self.mcb.on_model)
        self.assertEqual(self.mcb.last[0], ModelType.CautiousConsequences)
        self.assertEqual([self.mcb.last[1]], _p(['c']))

        self.ctl = Control(['0', '-e', 'brave'])
        self.ctl.add("base", [], "1 {a; b} 1. c.")
        self.ctl.ground([("base", [])])
        self.ctl.solve(on_model=self.mcb.on_model)
        self.assertEqual(self.mcb.last[0], ModelType.BraveConsequences)
        self.assertEqual([self.mcb.last[1]], _p(['a', 'b', 'c']))

    def test_model(self):
        '''
        Test the accessors and functions of a model passed to `on_model`.
        '''
        # Model numbers seen by the callback; checked after solving so the test
        # cannot pass without the callback (and its assertions) being run.
        calls = []

        def on_model(m: Model):
            calls.append(m.number)
            self.assertTrue(m.contains(Function('a')))
            self.assertTrue(m.is_true(cast(SymbolicAtom, m.context.symbolic_atoms[Function('a')]).literal))
            self.assertFalse(m.is_true(1000))
            self.assertEqual(m.thread_id, 0)
            self.assertEqual(m.number, 1)
            self.assertFalse(m.optimality_proven)
            self.assertEqual(m.cost, [3])
            m.extend([Function('e')])
            self.assertSequenceEqual(m.symbols(theory=True), [Function('e')])

        self.ctl.add("base", [], "a. b. c. #minimize { 1,a:a; 1,b:b; 1,c:c }.")
        self.ctl.ground([("base", [])])
        self.ctl.solve(on_model=on_model)
        self.assertEqual(calls, [1])

    def test_control_clause(self):
        '''
        Test adding clauses while solving.
        '''
        self.ctl.add("base", [], "1 {a; b; c} 1.")
        self.ctl.ground([("base", [])])
        with cast(SolveHandle, self.ctl.solve(on_model=self.mcb.on_model, yield_=True, async_=False)) as hnd:
            for m in hnd:
                # Add the unit clause "not X" (X = b if the model contains a,
                # otherwise X = a) to cut down the remaining search space.
                blocked = Function('b') if m.contains(Function('a')) else Function('a')
                m.context.add_clause([(blocked, False)])

            _check_sat(self, hnd.get())
            self.assertEqual(len(self.mcb.models), 2)

    def test_control_nogood(self):
        '''
        Test adding nogoods while solving.
        '''
        self.ctl.add("base", [], "1 {a; b; c} 1.")
        self.ctl.ground([("base", [])])
        with cast(SolveHandle, self.ctl.solve(on_model=self.mcb.on_model, yield_=True, async_=False)) as hnd:
            for m in hnd:
                # Add the nogood "X is true" (X = b if the model contains a,
                # otherwise X = a); this mirrors test_control_clause.
                blocked = Function('b') if m.contains(Function('a')) else Function('a')
                m.context.add_nogood([(blocked, True)])

            _check_sat(self, hnd.get())
            self.assertEqual(len(self.mcb.models), 2)