1'''
2Tests for the backend/observer.
3'''
4
5from unittest import TestCase
6from typing import Sequence, Tuple
7from clingo import Control, Function, HeuristicType, Observer, Symbol, TruthValue
8
class TestObserverBackend(Observer):
    '''
    Observer checking the statements added through the backend in
    `TestBackend.test_backend`.

    Every callback records its name in `called`. Callbacks that receive
    statement data additionally assert, via the wrapped test case, that they
    got exactly the arguments passed to the backend by that test.
    '''
    # Helper, not a test class: stop pytest from trying to collect it just
    # because its name starts with "Test".
    __test__ = False

    def __init__(self, case: TestCase):
        '''
        Create the observer.

        Parameters
        ----------
        case
            Test case whose assertion methods check the callback arguments.
        '''
        super().__init__()
        self._case = case
        # Names of the callbacks invoked so far.
        self.called = set()

    def init_program(self, incremental: bool) -> None:
        self.called.add('init_program')

    def begin_step(self) -> None:
        self.called.add('begin_step')

    def rule(self, choice: bool, head: Sequence[int], body: Sequence[int]) -> None:
        self.called.add('rule')
        self._case.assertTrue(choice)
        self._case.assertEqual(head, [1])
        self._case.assertEqual(body, [2, 3])

    def weight_rule(self, choice: bool, head: Sequence[int], lower_bound: int,
                    body: Sequence[Tuple[int,int]]) -> None:
        self.called.add('weight_rule')
        # The test does not pass `choice`, so the backend default is expected.
        self._case.assertFalse(choice)
        self._case.assertEqual(head, [2])
        self._case.assertEqual(lower_bound, 1)
        self._case.assertEqual(body, [(2, 3), (4, 5)])

    def minimize(self, priority: int, literals: Sequence[Tuple[int,int]]) -> None:
        self.called.add('minimize')
        self._case.assertEqual(priority, 0)
        self._case.assertEqual(literals, [(2, 3), (4, 5)])

    def project(self, atoms: Sequence[int]) -> None:
        self.called.add('project')
        self._case.assertEqual(atoms, [2, 4])

    def output_atom(self, symbol: Symbol, atom: int) -> None:
        self.called.add('output_atom')
        # `a` is the second atom added by the test (after an anonymous one).
        self._case.assertEqual(symbol, Function('a'))
        self._case.assertEqual(atom, 2)

    def external(self, atom: int, value: TruthValue) -> None:
        self.called.add('external')
        self._case.assertEqual(atom, 3)
        self._case.assertEqual(value, TruthValue.Release)

    def assume(self, literals: Sequence[int]) -> None:
        self.called.add('assume')
        self._case.assertEqual(literals, [2, 3])

    def heuristic(self, atom: int, type_: HeuristicType, bias: int,
                  priority: int, condition: Sequence[int]) -> None:
        self.called.add('heuristic')
        self._case.assertEqual(atom, 2)
        self._case.assertEqual(type_, HeuristicType.Level)
        self._case.assertEqual(bias, 5)
        self._case.assertEqual(priority, 7)
        self._case.assertEqual(condition, [1, 3])

    def acyc_edge(self, node_u: int, node_v: int,
                  condition: Sequence[int]) -> None:
        self.called.add('acyc_edge')
        self._case.assertEqual(node_u, 1)
        self._case.assertEqual(node_v, 2)
        self._case.assertEqual(condition, [3, 4])

    def end_step(self) -> None:
        self.called.add('end_step')
78
class TestObserverTheory(Observer):
    '''
    Observer checking the output and theory data produced while grounding
    the program in `TestBackend.test_theory`.

    Every callback records its name in `called` and asserts, via the wrapped
    test case, properties of the arguments it receives for that program.
    '''
    # Helper, not a test class: stop pytest from trying to collect it just
    # because its name starts with "Test".
    __test__ = False

    def __init__(self, case: TestCase):
        '''
        Create the observer.

        Parameters
        ----------
        case
            Test case whose assertion methods check the callback arguments.
        '''
        super().__init__()
        self._case = case
        # Names of the callbacks invoked so far.
        self.called = set()

    def output_term(self, symbol: Symbol, condition: Sequence[int]) -> None:
        self.called.add('output_term')
        # Comes from `#show t : a, b.`; the condition must not be empty.
        self._case.assertEqual(symbol, Function('t'))
        self._case.assertGreaterEqual(len(condition), 1)

    def theory_term_number(self, term_id: int, number: int) -> None:
        self.called.add('theory_term_number')
        self._case.assertEqual(number, 1)

    def theory_term_string(self, term_id : int, name : str) -> None:
        self.called.add('theory_term_string')
        self._case.assertEqual(name, "a")

    def theory_term_compound(self, term_id: int, name_id_or_type: int,
                             arguments: Sequence[int]) -> None:
        self.called.add('theory_term_compound')
        # -1 marks a tuple term; the only compound in the program is `(1,a)`.
        self._case.assertEqual(name_id_or_type, -1)
        self._case.assertGreaterEqual(len(arguments), 2)

    def theory_element(self, element_id: int, terms: Sequence[int],
                       condition: Sequence[int]) -> None:
        self.called.add('theory_element')
        # Element `(1,a): a,b` has one term and a two-literal condition.
        self._case.assertEqual(len(terms), 1)
        self._case.assertEqual(len(condition), 2)

    def theory_atom(self, atom_id_or_zero: int, term_id: int,
                    elements: Sequence[int]) -> None:
        self.called.add('theory_atom')
        self._case.assertEqual(len(elements), 1)
116
class TestObserverTheoryWithGuard(Observer):
    '''
    Observer checking the theory data produced while grounding the guarded
    theory atom in `TestBackend.test_theory_with_guard`.

    String terms are recorded in `called` as ``'theory_term_string: <name>'``
    so the test can check which strings were reported. Guarded theory atoms
    are recorded as ``'theory_atom_with_guard'``.
    '''
    # Helper, not a test class: stop pytest from trying to collect it just
    # because its name starts with "Test".
    __test__ = False

    def __init__(self, case: TestCase):
        '''
        Create the observer.

        Parameters
        ----------
        case
            Test case whose assertion methods check the callback arguments.
        '''
        super().__init__()
        self._case = case
        # Names (with string-term payloads) of the callbacks invoked so far.
        self.called = set()

    def theory_term_string(self, term_id : int, name : str) -> None:
        self.called.add(f'theory_term_string: {name}')

    def theory_atom_with_guard(self, atom_id_or_zero: int, term_id: int,
                               elements: Sequence[int], operator_id: int,
                               right_hand_side_id: int) -> None:
        self.called.add('theory_atom_with_guard')
        # The atom `&a { } = a` has an empty element set.
        self._case.assertEqual(len(elements), 0)
133
class TestBackend(TestCase):
    '''
    Tests the backend and the grounder by checking the statements they
    report to registered observers.
    '''
    def test_backend(self):
        '''
        Test that statements added via the backend are passed to an observer
        with the arguments they were added with.
        '''
        ctl = Control()
        obs = TestObserverBackend(self)
        ctl.register_observer(obs)
        with ctl.backend() as backend:
            self.assertIn('init_program', obs.called)
            self.assertIn('begin_step', obs.called)
            backend.add_atom()
            backend.add_atom(Function('a'))
            backend.add_rule([1], [2, 3], True)
            self.assertIn('rule', obs.called)
            backend.add_weight_rule([2], 1, [(2, 3), (4, 5)])
            self.assertIn('weight_rule', obs.called)
            backend.add_minimize(0, [(2, 3), (4, 5)])
            self.assertIn('minimize', obs.called)
            backend.add_project([2, 4])
            self.assertIn('project', obs.called)
            backend.add_heuristic(2, HeuristicType.Level, 5, 7, [1, 3])
            self.assertIn('heuristic', obs.called)
            backend.add_assume([2, 3])
            self.assertIn('assume', obs.called)
            backend.add_acyc_edge(1, 2, [3, 4])
            self.assertIn('acyc_edge', obs.called)
            backend.add_external(3, TruthValue.Release)
            self.assertIn('external', obs.called)
        # Output atoms are only checked once the backend context is closed.
        self.assertIn('output_atom', obs.called)
        ctl.solve()
        self.assertIn('end_step', obs.called)

    def test_theory(self):
        '''
        Test that grounding a program with a conditional `#show` term and an
        unguarded theory atom reports output terms and theory terms,
        elements and atoms to an observer.
        '''
        ctl = Control()
        obs = TestObserverTheory(self)
        ctl.register_observer(obs)
        ctl.add('base', [], '''\
        #theory test {
            t { };
            &a/0 : t, head
        }.
        {a; b}.
        #show t : a, b.
        &a { (1,a): a,b }.
        ''')
        ctl.ground([('base', [])])
        self.assertIn('output_term', obs.called)
        self.assertIn('theory_term_number', obs.called)
        self.assertIn('theory_term_string', obs.called)
        self.assertIn('theory_term_compound', obs.called)
        self.assertIn('theory_element', obs.called)
        self.assertIn('theory_atom', obs.called)
        ctl.solve()

    def test_theory_with_guard(self):
        '''
        Test that grounding a theory atom with a guard (`&a { } = a`) reports
        the guard operator and right-hand side as string terms and the atom
        itself via `theory_atom_with_guard`.
        '''
        ctl = Control()
        obs = TestObserverTheoryWithGuard(self)
        ctl.register_observer(obs)
        ctl.add('base', [], '''\
        #theory test {
            t { };
            &a/0 : t, {=}, t, head
        }.
        &a { } = a.
        ''')
        ctl.ground([('base', [])])
        self.assertIn('theory_term_string: a', obs.called)
        self.assertIn('theory_term_string: =', obs.called)
        self.assertIn('theory_atom_with_guard', obs.called)
        ctl.solve()
214