1import sys
2from typing import Deque, Dict, Set, List, Tuple
3from collections import deque
4
5from networkx import DiGraph # type: ignore
6from networkx.algorithms.components import strongly_connected_components # type: ignore
7from networkx.algorithms.dag import topological_sort # type: ignore
8
9from clingo.control import Control
10from clingo.symbol import Symbol
11from clingo.application import Application, clingo_main
12from clingox.program import Program, ProgramObserver, Rule
13
14Atom = int
15Literal = int
16RuleIndex = int
17
18
def _analyze(rules: List[Rule]) -> List[List[Rule]]:
    '''
    Group the given rules into the strongly connected components of the rule
    dependency graph and return the components in topological order.

    The dependency graph has an edge from rule `u` to rule `v` if the head
    atom of `u` occurs (positively or negatively) in the body of `v`. In the
    returned list, a component comes before all components whose rules depend
    on it.

    Every rule must have exactly one head atom; unpacking the head raises a
    `ValueError` otherwise.
    '''
    # build rule dependency graph
    # (occ maps an atom to the indices of the rules having it in their body)
    occ: Dict[Atom, Set[RuleIndex]] = {}
    dep_graph = DiGraph()
    for u, rule in enumerate(rules):
        dep_graph.add_node(u)
        for lit in rule.body:
            occ.setdefault(abs(lit), set()).add(u)

    for u, rule in enumerate(rules):
        atm, = rule.head
        for v in occ.get(atm, ()):
            dep_graph.add_edge(u, v)

    sccs = list(strongly_connected_components(dep_graph))

    # build scc dependency graph
    # (this part only exists because the networkx library does not document the
    # order of components; in principle, the tarjan algorithm guarantees a
    # topological order)
    scc_rule: Dict[RuleIndex, int] = {}
    scc_graph = DiGraph()
    for i, scc in enumerate(sccs):
        scc_graph.add_node(i)
        for u in scc:
            scc_rule[u] = i

    for i, scc in enumerate(sccs):
        for u in scc:
            for v in dep_graph[u]:
                # The component of the successor v is needed here. Looking up
                # u instead always yields i, so no edge would ever be added
                # and the topological sort below would not order anything.
                j = scc_rule[v]
                if i != j:
                    scc_graph.add_edge(i, j)

    return [[rules[j] for j in sccs[i]] for i in topological_sort(scc_graph)]
54
55
def _well_founded(interpretation: Set[Literal], scc: List[Rule]) -> None:
    '''
    Extend the given interpretation with the literals that follow from the
    rules in the given component.

    The interpretation is a set of literals that is modified in place: a
    positive literal `a` in the set marks atom `a` as true and a negative
    literal `-a` marks it as false. Rules whose body is already false or whose
    head is already true are ignored.

    Two propagations are interleaved until no work is left. Fact propagation
    adds the head of a rule once all its body literals are true. Source
    propagation tracks which head atoms currently have a rule with only
    supported body literals (a "source"); atoms occurring in rule bodies that
    end up without a source are added negatively.

    Every rule must have exactly one head atom.
    '''
    # body literal -> rules in whose body it was found not yet true during
    # initialization
    watches: Dict[Literal, List[RuleIndex]] = {}
    # rule -> number of its body literals that are not yet true
    counters: Dict[RuleIndex, int] = {}
    # literals whose consequences still have to be propagated
    todo: List[Literal] = []
    # negative literals of atoms that are candidates for being set to false
    unfounded: List[Literal] = []
    # atoms occurring in the body of some considered rule
    need_source: Set[Atom] = set()
    # head atoms that currently have a source
    has_source: Set[Atom] = set()
    # head atom -> considered rules with this head
    can_source: Dict[Atom, List[RuleIndex]] = {}
    # rule -> number of its body literals that are currently not supported
    counters_source: Dict[RuleIndex, int] = dict()
    # atoms that obtained a source and still have to be propagated
    todo_source: Deque[Atom] = deque()
    # rules currently acting as the source of their head
    is_source: Set[RuleIndex] = set()

    def is_true(*args: Literal) -> bool:
        # all given literals are in the interpretation
        return all(lit in interpretation for lit in args)

    def is_false(*args: Literal) -> bool:
        # the complement of at least one given literal is in the interpretation
        return any(-lit in interpretation for lit in args)

    def is_supported(lit: Literal) -> bool:
        # a literal that is not false is supported if it is negative, already
        # true, or an atom that currently has a source
        return not is_false(lit) and (lit < 0 or is_true(lit) or lit in has_source)

    def enqueue_source(idx: RuleIndex) -> None:
        # make the rule the source of its head if all its body literals are
        # supported and the head does not have a source yet
        atm, = scc[idx].head
        if counters_source[idx] == 0 and atm not in has_source:
            has_source.add(atm)
            is_source.add(idx)
            todo_source.append(atm)

    def enqueue_lit(lit: Literal) -> None:
        # add a new literal to the interpretation and schedule its propagation
        if lit not in interpretation:
            interpretation.add(lit)
            todo.append(lit)

    # initialize the above data structures
    for i, rule in enumerate(scc):
        atm, = rule.head

        # rules with a false body or an already true head cannot contribute
        if is_false(*rule.body) or is_true(atm):
            continue

        # initialize fact propagation
        count = 0
        for lit in rule.body:
            if not is_true(lit):
                count += 1
                watches.setdefault(lit, []).append(i)
        counters[i] = count
        if count == 0:
            enqueue_lit(atm)

        # initialize source propagation
        count = 0
        for lit in rule.body:
            if not is_supported(lit):
                count += 1
            # each body atom is registered once as a candidate for being
            # set to false
            if abs(lit) not in need_source:
                need_source.add(abs(lit))
                unfounded.append(-abs(lit))
        counters_source[i] = count
        enqueue_source(i)
        can_source.setdefault(atm, []).append(i)

    while todo or unfounded:
        # forward propagate facts
        # (todo may grow while it is being traversed, hence the index loop)
        idx = 0
        while idx < len(todo):
            lit = todo[idx]
            idx += 1
            for i in watches.get(lit, []):
                counters[i] -= 1
                if counters[i] == 0:
                    enqueue_lit(*scc[i].head)

        # remove sources
        # (rules watching the complement of a literal in todo lose the support
        # for that body literal)
        idx = 0
        while idx < len(todo):
            lit = todo[idx]
            idx += 1
            # Note that in this case, the literal already lost its source earlier
            # and has already been made false at the end of the loop.
            if lit < 0 and lit in interpretation:
                continue
            for i in watches.get(-lit, []):
                counters_source[i] += 1
                if i in is_source:
                    atm, = scc[i].head
                    is_source.remove(i)
                    has_source.remove(atm)
                    # schedule the atom without adding it to the
                    # interpretation so that the loss of its source is
                    # propagated further by this loop
                    if -atm not in interpretation:
                        todo.append(-atm)

        # initialize sources
        # (try to find a replacement source for atoms whose negative literal
        # is in todo)
        for lit in todo:
            for i in can_source.get(-lit, []):
                enqueue_source(i)

        # forward propagate sources
        while todo_source:
            atm = todo_source.popleft()
            for i in watches.get(atm, []):
                counters_source[i] -= 1
                enqueue_source(i)

        # set literals without sources to false
        # (if no candidates are pending, the literals processed above become
        # the candidates; todo is emptied either way)
        if not unfounded:
            unfounded, todo = todo, unfounded
        todo.clear()
        # NOTE(review): need_source contains every body atom of the considered
        # rules while has_source only ever receives head atoms of rules in
        # scc. A body atom whose rules are not part of scc is therefore added
        # negatively here even if it is already true in the interpretation --
        # confirm that this is intended.
        for lit in unfounded:
            if lit < 0 and -lit in need_source and -lit not in has_source:
                enqueue_lit(lit)
        unfounded.clear()
167
168
def well_founded(prg: Program) -> Tuple[List[Symbol], List[Symbol]]:
    '''
    Compute the well-founded model of the given program.

    Returns a pair of sorted lists: the symbols of the atoms that are facts
    and the symbols of the atoms that are unknown.

    Only normal rules are supported. A `RuntimeError` is raised if some rule
    does not have exactly one head atom, is a choice rule, or if the program
    contains weight rules.
    '''
    has_unsupported_rule = any(len(rule.head) != 1 or rule.choice for rule in prg.rules)
    if has_unsupported_rule or prg.weight_rules:
        raise RuntimeError('only normal rules are supported')

    # process the components one after another, accumulating the derived
    # literals in a single interpretation
    interpretation: Set[Literal] = set()
    for component in _analyze(prg.rules):
        _well_founded(interpretation, component)

    # facts are the facts of the program plus the positive literals of the
    # interpretation that have an output symbol
    facts = [atm.symbol for atm in prg.facts]
    facts += [prg.output_atoms[lit]
              for lit in interpretation
              if lit > 0 and lit in prg.output_atoms]

    # an atom with an output symbol is unknown if it is not true and is the
    # head of a rule none of whose body literals is false
    unknown = set()
    for rule in prg.rules:
        head, = rule.head
        if head in interpretation or head not in prg.output_atoms:
            continue
        if not any(-lit in interpretation for lit in rule.body):
            unknown.add(prg.output_atoms[head])

    return sorted(facts), sorted(unknown)
198
199
class LevelApp(Application):
    '''
    Application that prints the facts and unknown atoms of the well-founded
    model of the grounded input program and then solves the program.
    '''

    def __init__(self):
        # name and version of the application
        self.version = "1.0"
        self.program_name = "level"

    def main(self, ctl: Control, files):
        '''
        Load and ground the input, print the well-founded model, and solve.

        If no files are given, '-' is loaded instead.
        '''
        # record the ground program while it is passed to the solver
        prg = Program()
        observer = ProgramObserver(prg)
        ctl.register_observer(observer)

        for path in files or ['-']:
            ctl.load(path)

        ctl.ground([("base", [])])

        facts, unknown = well_founded(prg)
        for title, symbols in (('Facts:', facts), ('Unknown:', unknown)):
            print(title)
            print(' '.join(str(sym) for sym in symbols))

        ctl.solve()
223
224
# Run the application on the command line arguments (without the program
# name) and exit with the value returned by clingo_main.
raise SystemExit(clingo_main(LevelApp(), sys.argv[1:]))
226