1from sys import stdout, exit
2from textwrap import dedent
3from copy import copy
4
5from clingo.application import Application
6from clingo import SymbolType, Number, Function, ast, clingo_main
7
class TermTransformer(ast.Transformer):
    """
    Time-stamp the term of a symbolic atom with a step argument.

    The step (``parameter``) is appended as the last argument of the
    outermost function term; argument terms nested inside it are not
    visited.  Each prime (``'``) in the function name moves the step one
    back and all primes are removed from the name, so ``p'(X)`` becomes
    ``p(X, T-1)`` where ``T`` is ``parameter``.
    """

    def __init__(self, parameter):
        # Symbol standing for the current step; inserted as a SymbolicTerm.
        self.parameter = parameter

    def __get_param(self, name, location):
        """
        Split ``name`` into its prime-free form and the step term it refers to.

        Returns ``(name_without_primes, term)`` where ``term`` is the step
        parameter, minus the number of primes found in ``name`` if there are any.
        """
        n = name.replace('\'', '')
        # NOTE(review): primes are counted anywhere in the name, not only at
        # the end, so ``p'q`` and ``pq'`` both map to ``pq`` one step back.
        primes = len(name) - len(n)
        param = ast.SymbolicTerm(location, self.parameter)
        if primes > 0:
            param = ast.BinaryOperation(location, ast.BinaryOperator.Minus, param, ast.SymbolicTerm(location, Number(primes)))
        return n, param

    def visit_Function(self, term):
        """Return ``term`` with primes stripped from its name and the step appended."""
        name, param = self.__get_param(term.name, term.location)
        # Assumes clingo's ``AST.update`` with keyword arguments returns a
        # copy, so appending below does not modify the input node.
        term = term.update(name=name)
        term.arguments.append(param)
        return term

    def visit_SymbolicTerm(self, term):
        """
        Time-stamp an atom whose term is a precomputed symbol.

        gringo's parser never produces this case (atoms are ``Function``
        nodes), but a hand-built AST may.  A function symbol is unfolded
        into an equivalent ``Function`` node and extended like any other
        function term; a classically negated symbol is wrapped in a unary
        minus again afterwards.

        Raises:
            RuntimeError: if the symbol is not a function symbol, since no
                step argument can be appended to it.
        """
        sym = term.symbol
        if sym.type != SymbolType.Function:
            raise RuntimeError("cannot add a time step to non-function symbol {}".format(sym))
        loc = term.location
        fun = ast.Function(loc, sym.name, [ast.SymbolicTerm(loc, arg) for arg in sym.arguments], False)
        ret = self.visit_Function(fun)
        if sym.negative:
            ret = ast.UnaryOperation(loc, ast.UnaryOperator.Minus, ret)
        return ret
30
class ProgramTransformer(ast.Transformer):
    """
    Rewrite a temporal program into step-parametrized clingo programs.

    * every ``#program`` directive gains the step parameter and a ``final``
      section is renamed to ``static``;
    * every symbolic atom is time-stamped through a ``TermTransformer``;
    * ``#show``/``#project`` signatures get one extra argument for the step;
    * while inside a ``final`` section, every node that has a ``body`` gets
      the extra body literal ``finally(parameter)``.
    """

    def __init__(self, parameter):
        # ``parameter`` is the symbol used as the step placeholder.
        self.parameter = parameter
        # True while the statements being visited belong to ``#program final``.
        self.final = False
        self.term_transformer = TermTransformer(parameter)

    def visit(self, x, *args, **kwargs):
        result = super().visit(x, *args, **kwargs)
        if not (self.final and hasattr(result, "body")):
            return result
        # The base transformer hands back the very same node when nothing
        # changed; copy it so the caller's AST is never modified in place.
        if result is x:
            result = copy(x)
        where = result.location
        step = ast.SymbolicTerm(where, self.parameter)
        guard = ast.SymbolicAtom(ast.Function(where, "finally", [step], False))
        result.body.append(ast.Literal(where, ast.Sign.NoSign, guard))
        return result

    def visit_SymbolicAtom(self, atom):
        stamped = self.term_transformer(atom.symbol)
        return atom.update(symbol=stamped)

    def visit_Program(self, prg):
        # Remember the section so that ``visit`` can guard its statements.
        self.final = prg.name == "final"
        renamed = copy(prg)
        if self.final:
            renamed.name = "static"
        step_param = ast.Id(renamed.location, self.parameter.name)
        renamed.parameters.append(step_param)
        return renamed

    @staticmethod
    def _with_step_arity(sig):
        # Account for the step argument appended to every atom.
        return sig.update(arity=sig.arity + 1)

    def visit_ShowSignature(self, sig):
        return self._with_step_arity(sig)

    def visit_ProjectSignature(self, sig):
        return self._with_step_arity(sig)
65
class TModeApp(Application):
    """
    clingo application that grounds and solves a temporal program step by step.

    ``main`` rewrites the input with ``ProgramTransformer`` so that every atom
    carries a step argument.  ``_main`` then grounds the ``base`` and ``static``
    parts plus ``initial`` (step 0) or ``dynamic`` (step > 0) for steps
    0, 1, 2, ... until the criterion given by ``--imin``, ``--imax`` and
    ``--istop`` is met.
    """

    def __init__(self):
        self._imin = 0          # minimum number of steps (--imin)
        self._imax = None       # maximum number of steps; None means unbounded (--imax)
        self._istop = "SAT"     # stop criterion: "SAT", "UNSAT" or "UNKNOWN" (--istop)
        self._horizon = 0       # NOTE(review): not read anywhere in this file

    def _parse_imin(self, value):
        """Parse ``--imin``; return False (rejecting the option) unless it is an int >= 0."""
        try:
            self._imin = int(value)
        except ValueError:
            return False
        return self._imin >= 0

    def _parse_imax(self, value):
        """Parse ``--imax``: ``inf``/``infinity`` (any case) or an int >= 0."""
        if value.upper() in ("INF", "INFINITY"):
            self._imax = None
            return True
        try:
            self._imax = int(value)
        except ValueError:
            return False
        return self._imax >= 0

    def _parse_istop(self, value):
        """Parse ``--istop``; accepts ``sat``, ``unsat`` or ``unknown`` in any case."""
        self._istop = value.upper()
        return self._istop in ["SAT", "UNSAT", "UNKNOWN"]

    def register_options(self, options):
        """Register the incremental solving options with clingo's option parser."""
        group = "Incremental Options"
        options.add(group, "imin", "Minimum number of solving steps [0]",
                    self._parse_imin, argument="<n>")
        options.add(group, "imax", "Maximum number of solving steps [infinity]",
                    self._parse_imax, argument="<n>")
        options.add(group, "istop", dedent("""\
            Stop criterion [sat]
                  <arg>: {sat|unsat|unknown}"""), self._parse_istop)

    def print_model(self, model, printer):
        """
        Print the shown atoms of ``model`` grouped by time step.

        Every shown function symbol with at least one argument is read as
        ``name(args..., step)``.  It is listed under ``State <step>`` with the
        step argument removed.  Within a state, symbols are sorted and each
        change of name/arity starts a new output line.  Shown symbols that are
        not functions with arguments are not printed.  ``printer`` is unused.
        """
        table = {}
        for sym in model.symbols(shown=True):
            if sym.type == SymbolType.Function and len(sym.arguments) > 0:
                # Keep the sign so that a classically negated atom such as
                # ``-p(1,0)`` is printed as ``-p(1)`` rather than ``p(1)``.
                stripped = Function(sym.name, sym.arguments[:-1], sym.positive)
                table.setdefault(sym.arguments[-1], []).append(stripped)
        for step, symbols in sorted(table.items()):
            stdout.write(" State {}:".format(step))
            sig = None
            for sym in sorted(symbols):
                if (sym.name, len(sym.arguments)) != sig:
                    stdout.write("\n ")
                    sig = (sym.name, len(sym.arguments))
                stdout.write(" {}".format(sym))
            stdout.write("\n")

    def _keep_solving(self, step, ret):
        """
        Decide whether step ``step`` should be grounded and solved.

        ``ret`` is the solve result of the previous step, or None before step 0.
        """
        if self._imax is not None and step >= self._imax:
            return False
        if step == 0 or step < self._imin:
            return True
        if self._istop == "SAT":
            return not ret.satisfiable
        if self._istop == "UNSAT":
            return not ret.unsatisfiable
        if self._istop == "UNKNOWN":
            return not ret.unknown
        return False

    def _main(self, ctl):
        """Ground and solve one step at a time until ``_keep_solving`` says stop."""
        step, ret = 0, None
        while self._keep_solving(step, ret):
            parts = [("base", [Number(step)]), ("static", [Number(step)])]
            if step > 0:
                # The previous step is no longer the last one: drop its
                # finally/1 external so its final-section rules stop applying.
                ctl.release_external(Function("finally", [Number(step - 1)]))
                parts.append(("dynamic", [Number(step)]))
            else:
                parts.append(("initial", [Number(0)]))
            ctl.ground(parts)
            # Statements of the ``#program final`` section are conditioned on
            # finally(step); enable them for the current last step.
            ctl.assign_external(Function("finally", [Number(step)]), True)
            ret, step = ctl.solve(), step + 1

    def main(self, ctl, files):
        """Rewrite and load ``files``, add the helper programs, then solve incrementally."""
        with ast.ProgramBuilder(ctl) as bld:
            ptf = ProgramTransformer(Function("__t"))
            ast.parse_files(files, lambda stm: bld.add(ptf(stm)))
        # initially/1 is only grounded at step 0 (via the ``initial`` part);
        # finally/1 is the external toggled per step by ``_main``.
        ctl.add("initial", ["t"], "initially(t).")
        ctl.add("static", ["t"], "#external finally(t).")
        self._main(ctl)
145
if __name__ == "__main__":
    # Run the application only when executed as a script.  Importing the
    # module (e.g. to reuse the transformers) then neither starts clingo
    # nor exits the interpreter.
    exit(clingo_main(TModeApp()))
147