import sys
from typing import Deque, Dict, Set, List, Tuple
from collections import deque

from networkx import DiGraph  # type: ignore
from networkx.algorithms.components import strongly_connected_components  # type: ignore
from networkx.algorithms.dag import topological_sort  # type: ignore

from clingo.control import Control
from clingo.symbol import Symbol
from clingo.application import Application, clingo_main
from clingox.program import Program, ProgramObserver, Rule

Atom = int
Literal = int
RuleIndex = int


def _analyze(rules: List[Rule]) -> List[List[Rule]]:
    '''
    Split the given normal rules into components that can be evaluated one
    after the other.

    Components are the strongly connected components of the atom dependency
    graph. That graph has an edge from atom `b` to atom `a` whenever `b` occurs
    (positively or negatively) in the body of a rule with head `a`. Every rule
    is placed into the component of its head atom, so all rules defining an
    atom end up in the same component. Components are returned in topological
    order: each component comes after all components it depends on.
    '''
    # group rule indices by their (single) head atom
    defining: Dict[Atom, List[RuleIndex]] = {}
    for idx, rule in enumerate(rules):
        atm, = rule.head
        defining.setdefault(atm, []).append(idx)

    # build atom dependency graph; atoms without any defining rule cannot be
    # part of a cycle and are therefore left out
    dep_graph = DiGraph()
    dep_graph.add_nodes_from(defining)
    for rule in rules:
        atm, = rule.head
        for lit in rule.body:
            if abs(lit) in defining:
                dep_graph.add_edge(abs(lit), atm)

    sccs = list(strongly_connected_components(dep_graph))

    # build scc dependency graph
    # (this part only exists because the networkx library does not document the
    # order of components; in principle, the tarjan algorithm guarantees a
    # topological order)
    component: Dict[Atom, int] = {}
    scc_graph = DiGraph()
    for i, scc in enumerate(sccs):
        scc_graph.add_node(i)
        for atm in scc:
            component[atm] = i

    for u, v in dep_graph.edges():
        # map BOTH endpoints to their components; an edge between different
        # components orders them
        i, j = component[u], component[v]
        if i != j:
            scc_graph.add_edge(i, j)

    return [[rules[idx] for atm in sccs[i] for idx in defining[atm]]
            for i in topological_sort(scc_graph)]


def _well_founded(interpretation: Set[Literal], scc: List[Rule]) -> None:
    '''
    Extend `interpretation` in place with the well-founded model of the given
    component.

    The interpretation is a set of literals: `a` means atom `a` is true, `-a`
    means it is false, and atoms with neither literal are unknown. The caller
    must have processed all components that `scc` depends on, and `scc` must
    contain every rule of each of its head atoms (as produced by `_analyze`).
    Body atoms that are not heads of `scc` are treated as fixed: if they are
    neither true nor false, they stay unknown.

    Two steps alternate until nothing changes:

    1. Derive the heads of rules whose bodies became true. This is done
       incrementally with one counter of not-yet-true body literals per rule.
    2. Make false every head atom of the component that is not in the least
       fixpoint of "possibly derivable" atoms, i.e., the greatest unfounded
       set restricted to the component.

    Step 2 is recomputed from scratch in every round, so the worst case is
    quadratic in the size of the component.
    '''
    heads: Set[Atom] = {rule.head[0] for rule in scc}
    watches: Dict[Literal, List[RuleIndex]] = {}
    counters: Dict[RuleIndex, int] = {}
    todo: Deque[Literal] = deque()

    def assign(lit: Literal) -> None:
        # record a newly decided literal and queue it for fact propagation
        if lit not in interpretation:
            interpretation.add(lit)
            todo.append(lit)

    def propagate() -> None:
        # forward propagate facts: a rule fires once all body literals are true
        while todo:
            lit = todo.popleft()
            for i in watches.get(lit, []):
                counters[i] -= 1
                if counters[i] == 0:
                    assign(scc[i].head[0])

    def falsify_unfounded() -> bool:
        # Compute the atoms of the component that still have a possible
        # derivation. A rule can only support its head if no body literal is
        # false and all positive body atoms of this component that are not
        # yet true are supported themselves.
        supported: Set[Atom] = set()
        pending: Dict[RuleIndex, int] = {}
        waiting: Dict[Atom, List[RuleIndex]] = {}
        queue: List[Atom] = []
        for i, rule in enumerate(scc):
            if any(-lit in interpretation for lit in rule.body):
                continue
            internal = [lit for lit in rule.body
                        if lit > 0 and lit in heads and lit not in interpretation]
            pending[i] = len(internal)
            for atm in internal:
                waiting.setdefault(atm, []).append(i)
            if not internal:
                queue.append(rule.head[0])
        while queue:
            atm = queue.pop()
            if atm in supported:
                continue
            supported.add(atm)
            for i in waiting.get(atm, []):
                pending[i] -= 1
                if pending[i] == 0:
                    queue.append(scc[i].head[0])

        # atoms without a possible derivation are unfounded and thus false
        changed = False
        for atm in heads - supported:
            if atm not in interpretation and -atm not in interpretation:
                assign(-atm)
                changed = True
        return changed

    # initialize fact propagation: watch every body literal that is not true yet
    for i, rule in enumerate(scc):
        missing = [lit for lit in rule.body if lit not in interpretation]
        counters[i] = len(missing)
        for lit in missing:
            watches.setdefault(lit, []).append(i)
        if not missing:
            assign(rule.head[0])

    propagate()
    while falsify_unfounded():
        propagate()


def well_founded(prg: Program) -> Tuple[List[Symbol], List[Symbol]]:
    '''
    Computes the well-founded model of the given program returning a pair of
    facts and unknown atoms.

    This function assumes that the program contains only normal rules. Atoms
    that occur in rule bodies without being the head of any rule are
    considered false. NOTE(review): #external atoms are not treated specially.

    Both returned lists are sorted. They contain only atoms with an output
    symbol; the program's facts are always part of the first list.

    Raises:
        RuntimeError: if the program contains a rule whose head does not
            consist of exactly one atom, a choice rule, or a weight rule.
    '''
    for rule in prg.rules:
        if len(rule.head) != 1 or rule.choice:
            raise RuntimeError('only normal rules are supported')
    if prg.weight_rules:
        raise RuntimeError('only normal rules are supported')

    # atoms without rules cannot be derived and are false from the start
    defined: Set[Atom] = {rule.head[0] for rule in prg.rules}
    interpretation: Set[Literal] = {
        -abs(lit) for rule in prg.rules for lit in rule.body if abs(lit) not in defined}

    # analyze program and compute well-founded model component by component
    for scc in _analyze(prg.rules):
        _well_founded(interpretation, scc)

    # compute facts
    fct = [fact.symbol for fact in prg.facts]
    fct.extend(prg.output_atoms[lit] for lit in interpretation
               if lit > 0 and lit in prg.output_atoms)

    # compute unknowns: defined atoms that are neither true nor false
    ukn = {prg.output_atoms[atm] for atm in defined
           if atm in prg.output_atoms
           and atm not in interpretation and -atm not in interpretation}
    return sorted(fct), sorted(ukn)


class LevelApp(Application):
    '''
    Clingo application that prints the well-founded model of a normal logic
    program and then solves the program.
    '''

    def __init__(self):
        self.program_name = "level"
        self.version = "1.0"

    def main(self, ctl: Control, files):
        '''
        Load the given files (standard input if there are none), ground the
        `base` program, print facts and unknown atoms of the well-founded
        model, and finally solve.
        '''
        prg = Program()
        ctl.register_observer(ProgramObserver(prg))

        for f in files:
            ctl.load(f)
        if not files:
            ctl.load('-')

        ctl.ground([("base", [])])

        fct, ukn = well_founded(prg)
        print('Facts:')
        print(" ".join(map(str, fct)))
        print('Unknown:')
        print(" ".join(map(str, ukn)))

        ctl.solve()


if __name__ == '__main__':
    sys.exit(clingo_main(LevelApp(), sys.argv[1:]))