1#
2# This file is part of pySMT.
3#
4#   Copyright 2014 Andrea Micheli and Marco Gario
5#
6#   Licensed under the Apache License, Version 2.0 (the "License");
7#   you may not use this file except in compliance with the License.
8#   You may obtain a copy of the License at
9#
10#       http://www.apache.org/licenses/LICENSE-2.0
11#
12#   Unless required by applicable law or agreed to in writing, software
13#   distributed under the License is distributed on an "AS IS" BASIS,
14#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15#   See the License for the specific language governing permissions and
16#   limitations under the License.
17#
18
19import warnings
20from collections import namedtuple
21from six.moves import cStringIO
22from six.moves import xrange
23
24import pysmt.smtlib.commands as smtcmd
25from pysmt.exceptions import (UnknownSmtLibCommandError, NoLogicAvailableError,
26                              UndefinedLogicError, PysmtValueError)
27from pysmt.smtlib.printers import SmtPrinter, SmtDagPrinter, quote
28from pysmt.oracles import get_logic
29from pysmt.logics import get_closer_smtlib_logic, Logic, SMTLIB2_LOGICS
30from pysmt.environment import get_env
31
32
def check_sat_filter(log):
    """Return the result of the single check-sat command recorded in a log.

    :param log: iterable of ``(command_name, result)`` pairs, i.e. the
        shape of the list returned by ``SmtLibScript.evaluate``.
    :returns: the result paired with the only check-sat entry of the log.
    :raises AssertionError: if the log does not contain exactly one
        check-sat entry.
    """
    results = [result for name, result in log if name == smtcmd.CHECK_SAT]
    # Raise explicitly rather than through an ``assert`` statement: asserts
    # are removed under ``python -O``, which would let an empty log fail
    # with an IndexError and an ambiguous log silently return its first
    # result.
    if len(results) != 1:
        raise AssertionError("Expected exactly one check-sat result in the "
                             "log, found %d" % len(results))
    return results[0]
42
43
class SmtLibCommand(namedtuple('SmtLibCommand', ['name', 'args'])):
    """An SMT-LIB command stored as a ``(name, args)`` pair.

    ``name`` is compared against the constants of ``pysmt.smtlib.commands``
    and ``args`` is a sequence whose layout depends on the command (see
    the branches of ``serialize``).
    """

    def serialize(self, outstream=None, printer=None, daggify=True):
        """Serializes the SmtLibCommand into outstream using the given printer.

        Exactly one of outstream or printer should be specified. When
        only the printer is given, its associated stream (``printer.stream``)
        is used as outstream. When only outstream is given, daggify
        selects the printer to be created: SmtDagPrinter if true,
        SmtPrinter otherwise.

        If both are given they are used as they are (formulae go through
        the printer, everything else is written to outstream); this is
        kept for backward compatibility and is only meaningful when the
        printer writes to that same stream.

        Raises PysmtValueError if neither outstream nor printer is given,
        NotImplementedError if the command is listed in
        ``smtcmd.ALL_COMMANDS`` but has no serialization here, and
        UnknownSmtLibCommandError for any other command name.
        """
        if outstream is None:
            if printer is None:
                # Fail early with a clear message: otherwise the first
                # write below would die with an AttributeError on None.
                raise PysmtValueError("Exactly one of outstream and printer "
                                      "must be set.")
            outstream = printer.stream
        elif printer is None:
            if daggify:
                printer = SmtDagPrinter(outstream)
            else:
                printer = SmtPrinter(outstream)

        if self.name == smtcmd.SET_OPTION:
            # args: (option, value), both written verbatim
            outstream.write("(%s %s %s)" % (self.name,
                                            self.args[0],
                                            self.args[1]))

        elif self.name == smtcmd.SET_INFO:
            # args: (keyword, value); only the value goes through quote()
            outstream.write("(%s %s %s)" % (self.name,
                                            self.args[0],
                                            quote(self.args[1])))

        elif self.name == smtcmd.ASSERT:
            # args: (formula,)
            outstream.write("(%s " % self.name)
            printer.printer(self.args[0])
            outstream.write(")")

        elif self.name == smtcmd.GET_VALUE:
            # args: the terms to evaluate; each one is followed by a
            # single space, including the last one.
            outstream.write("(%s (" % self.name)
            for term in self.args:
                printer.printer(term)
                outstream.write(" ")
            outstream.write("))")

        elif self.name in [smtcmd.CHECK_SAT, smtcmd.EXIT,
                           smtcmd.RESET_ASSERTIONS, smtcmd.GET_UNSAT_CORE,
                           smtcmd.GET_ASSIGNMENT, smtcmd.GET_MODEL]:
            # Commands serialized without arguments
            outstream.write("(%s)" % self.name)

        elif self.name == smtcmd.SET_LOGIC:
            # args: (logic,)
            outstream.write("(%s %s)" % (self.name, self.args[0]))

        elif self.name in [smtcmd.DECLARE_FUN, smtcmd.DECLARE_CONST]:
            # args: (symbol,)
            symbol = self.args[0]
            type_str = symbol.symbol_type().as_smtlib()
            outstream.write("(%s %s %s)" % (self.name,
                                            quote(symbol.symbol_name()),
                                            type_str))

        elif self.name == smtcmd.DEFINE_FUN:
            # args: (name, formal parameters, return type, body)
            # NOTE(review): parameter and return types are rendered with
            # "%s" here, whereas the declare-fun branch uses as_smtlib();
            # confirm this produces valid SMT-LIB for every type.
            name = self.args[0]
            params_list = self.args[1]
            params = " ".join(["(%s %s)" % (v, v.symbol_type())
                               for v in params_list])
            rtype = self.args[2]
            expr = self.args[3]
            outstream.write("(%s %s (%s) %s " % (self.name,
                                                name,
                                                params,
                                                rtype))
            printer.printer(expr)
            outstream.write(")")

        elif self.name in [smtcmd.PUSH, smtcmd.POP]:
            # args: (number of levels,), formatted as an integer
            outstream.write("(%s %d)" % (self.name, self.args[0]))

        elif self.name == smtcmd.DEFINE_SORT:
            # args: (name, list of parameter strings, resulting type)
            name = self.args[0]
            params_list = self.args[1]
            params = " ".join(params_list)
            rtype = self.args[2]
            outstream.write("(%s %s (%s) %s)" % (self.name,
                                                 name,
                                                 params,
                                                 rtype))

        elif self.name == smtcmd.DECLARE_SORT:
            # args: (type declaration,) exposing .name and .arity
            type_decl = self.args[0]
            outstream.write("(%s %s %d)" % (self.name,
                                            type_decl.name,
                                            type_decl.arity))

        elif self.name in smtcmd.ALL_COMMANDS:
            raise NotImplementedError("'%s' is a valid SMT-LIB command "\
                                      "but it is currently not supported. "\
                                      "Please open a bug-report." % self.name)
        else:
            raise UnknownSmtLibCommandError(self.name)

    def serialize_to_string(self, daggify=True):
        """Returns the serialization of the command as a string.

        daggify is forwarded to ``serialize`` and selects between the
        DAG printer and the tree printer.
        """
        buf = cStringIO()
        self.serialize(buf, daggify=daggify)
        return buf.getvalue()
144
145
class SmtLibScript(object):
    """An ordered list of SMT-LIB commands with a few query helpers."""

    def __init__(self):
        self.annotations = None
        self.commands = []

    def add(self, name, args):
        """Builds an SmtLibCommand from name and args and appends it."""
        command = SmtLibCommand(name=name, args=args)
        self.add_command(command)

    def add_command(self, command):
        """Appends an already built command to the script."""
        self.commands.append(command)

    def evaluate(self, solver):
        """Runs every command on the solver, in order.

        Returns the list of (command name, result) pairs, where result is
        what evaluate_command returned for that command.
        """
        outcomes = []
        for command in self.commands:
            result = evaluate_command(command, solver)
            outcomes.append((command.name, result))
        return outcomes

    def contains_command(self, command_name):
        """Tells whether at least one command has the given name."""
        for command in self.commands:
            if command.name == command_name:
                return True
        return False

    def count_command_occurrences(self, command_name):
        """Returns how many commands have the given name."""
        matching = [command for command in self.commands
                    if command.name == command_name]
        return len(matching)

    def filter_by_command_name(self, command_name_set):
        """Lazily yields the commands whose name is in command_name_set."""
        return (command for command in self.commands
                if command.name in command_name_set)

    def get_strict_formula(self, mgr=None):
        """Returns the conjunction of all the assertions of the script.

        The script must not contain push/pop commands and must contain
        exactly one check-sat command, otherwise a PysmtValueError is
        raised. The conjunction is built with mgr.And if mgr is given,
        and with the formula manager of the current environment otherwise.
        """
        has_push_pop = (self.contains_command(smtcmd.PUSH) or
                        self.contains_command(smtcmd.POP))
        if has_push_pop:
            raise PysmtValueError("Was not expecting push-pop commands")

        if self.count_command_occurrences(smtcmd.CHECK_SAT) != 1:
            raise PysmtValueError("Was expecting exactly one check-sat command")

        if mgr:
            conjoin = mgr.And
        else:
            conjoin = get_env().formula_manager.And

        asserted = []
        for command in self.filter_by_command_name([smtcmd.ASSERT]):
            asserted.append(command.args[0])
        return conjoin(asserted)

    def get_declared_symbols(self):
        """Returns the set of symbols introduced by declare-const/declare-fun."""
        declarations = self.filter_by_command_name([smtcmd.DECLARE_CONST,
                                                    smtcmd.DECLARE_FUN])
        return set(command.args[0] for command in declarations)

    def get_define_fun_parameter_symbols(self):
        """Returns the set of formal parameters of all define-fun commands."""
        parameters = set()
        for command in self.filter_by_command_name([smtcmd.DEFINE_FUN]):
            parameters.update(command.args[1])
        return parameters

    def get_last_formula(self, mgr=None):
        """Returns the last formula of the execution of the Script.

        This is the conjunction of the assertions that are still on the
        assertion stack once all the commands have been replayed, taking
        push, pop and reset-assertions into account. The conjunction is
        built with mgr.And if mgr is given, and with the formula manager
        of the current environment otherwise.
        """
        assertions = []
        # Size of the assertion list recorded at each open push level
        levels = []
        if mgr:
            conjoin = mgr.And
        else:
            conjoin = get_env().formula_manager.And

        for command in self.commands:
            kind = command.name
            if kind == smtcmd.ASSERT:
                assertions.append(command.args[0])
            elif kind == smtcmd.RESET_ASSERTIONS:
                del assertions[:]
                del levels[:]
            elif kind == smtcmd.PUSH:
                for _ in xrange(command.args[0]):
                    levels.append(len(assertions))
            elif kind == smtcmd.POP:
                for _ in xrange(command.args[0]):
                    size = levels.pop()
                    # Drop everything asserted after the matching push
                    del assertions[size:]

        return conjoin(assertions)

    def to_file(self, fname, daggify=True):
        """Writes the serialization of the script to the file fname."""
        with open(fname, "w") as script_file:
            self.serialize(script_file, daggify=daggify)

    def serialize(self, outstream, daggify=True):
        """Writes every command to outstream, one per line.

        A single printer is shared by all the commands: SmtDagPrinter if
        daggify is true, SmtPrinter otherwise.
        """
        printer_class = SmtDagPrinter if daggify else SmtPrinter
        printer = printer_class(outstream)

        for command in self.commands:
            command.serialize(printer=printer)
            outstream.write("\n")

    def __len__(self):
        return len(self.commands)

    def __iter__(self):
        return iter(self.commands)

    def __str__(self):
        return "\n".join(map(str, self.commands))
247
248
def smtlibscript_from_formula(formula, logic=None):
    """Builds the SmtLibScript that checks the satisfiability of formula.

    The script contains, in order: set-logic, one declare-sort for each
    custom type of the formula, one declare-fun for each free variable,
    the assertion of the formula and a final check-sat.

    If logic is None, the logic is computed from the formula with
    get_logic and then mapped through get_closer_smtlib_logic; if that
    raises NoLogicAvailableError, a warning is issued and the computed
    logic is used as it is. Otherwise logic must be a Logic or a str
    (UndefinedLogicError is raised for anything else) and is used as
    given, with a warning if it is not in SMTLIB2_LOGICS.
    """
    script = SmtLibScript()
    non_standard_msg = ("The logic %s is not reducible to any SMTLib2 "
                        "standard logic. Proceeding with non-standard "
                        "logic '%s'")

    if logic is None:
        # Get the simplest SmtLib logic that contains the formula
        formula_logic = get_logic(formula)
        try:
            smt_logic = get_closer_smtlib_logic(formula_logic)
        except NoLogicAvailableError:
            warnings.warn(non_standard_msg % (formula_logic, formula_logic),
                          stacklevel=3)
            smt_logic = formula_logic
    elif isinstance(logic, (Logic, str)):
        if logic not in SMTLIB2_LOGICS:
            warnings.warn(non_standard_msg % (logic, logic),
                          stacklevel=3)
        smt_logic = logic
    else:
        raise UndefinedLogicError(str(logic))

    script.add(name=smtcmd.SET_LOGIC, args=[smt_logic])

    # Custom sorts have to be declared before the symbols using them
    custom_types = get_env().typeso.get_types(formula, custom_only=True)
    for custom_type in custom_types:
        script.add(name=smtcmd.DECLARE_SORT, args=[custom_type.decl])

    # One declaration for each free variable of the formula
    free_symbols = formula.get_free_variables()
    for symbol in free_symbols:
        assert symbol.is_symbol()
        script.add(name=smtcmd.DECLARE_FUN, args=[symbol])

    # Finally, the assertion of the formula followed by check-sat
    script.add(name=smtcmd.ASSERT, args=[formula])
    script.add(name=smtcmd.CHECK_SAT, args=[])
    return script
296
297
def evaluate_command(cmd, solver):
    """Executes a single command on the solver and returns its result.

    Each supported command is forwarded to the corresponding method of
    solver and the value returned by that method is returned. The echo
    command is handled locally: its argument is printed and None is
    returned.

    Raises NotImplementedError if the command is listed in
    smtcmd.ALL_COMMANDS but is not handled here, and
    UnknownSmtLibCommandError for any other command name.
    """
    kind = cmd.name

    if kind == smtcmd.SET_INFO:
        return solver.set_info(cmd.args[0], cmd.args[1])
    if kind == smtcmd.SET_OPTION:
        return solver.set_option(cmd.args[0], cmd.args[1])
    if kind == smtcmd.ASSERT:
        return solver.assert_(cmd.args[0])
    if kind == smtcmd.CHECK_SAT:
        return solver.check_sat()
    if kind == smtcmd.RESET_ASSERTIONS:
        return solver.reset_assertions()
    if kind == smtcmd.GET_VALUE:
        return solver.get_values(cmd.args)
    if kind == smtcmd.PUSH:
        return solver.push(cmd.args[0])
    if kind == smtcmd.POP:
        return solver.pop(cmd.args[0])
    if kind == smtcmd.EXIT:
        return solver.exit()
    if kind == smtcmd.SET_LOGIC:
        logic_name = cmd.args[0]
        return solver.set_logic(logic_name)
    if kind == smtcmd.DECLARE_FUN:
        return solver.declare_fun(cmd.args[0])
    if kind == smtcmd.DECLARE_CONST:
        return solver.declare_const(cmd.args[0])
    if kind == smtcmd.DEFINE_FUN:
        # args must be exactly (symbol, formal parameters, type, body)
        (symbol, parameters, return_type, definition) = cmd.args
        return solver.define_fun(symbol, parameters, return_type, definition)
    if kind == smtcmd.ECHO:
        print(cmd.args[0])
        return None
    if kind == smtcmd.CHECK_SAT_ASSUMING:
        # The whole argument list is handed over to check_sat
        return solver.check_sat(cmd.args)
    if kind == smtcmd.GET_UNSAT_CORE:
        return solver.get_unsat_core()
    if kind == smtcmd.DECLARE_SORT:
        declaration = cmd.args[0]
        sort_name = declaration.name
        sort_arity = declaration.arity
        return solver.declare_sort(sort_name, sort_arity)

    if kind in smtcmd.ALL_COMMANDS:
        raise NotImplementedError("'%s' is a valid SMT-LIB command "\
                                  "but it is currently not supported. "\
                                  "Please open a bug-report." % cmd.name)
    raise UnknownSmtLibCommandError(cmd.name)
361