1 /*++
2 Copyright (c) 2012 Microsoft Corporation
3 
4 Module Name:
5 
6     smt_solver.cpp
7 
8 Abstract:
9 
10     Wraps smt::kernel as a solver for the external API and cmd_context.
11 
12 Author:
13 
14     Leonardo (leonardo) 2012-10-21
15 
16 Notes:
17 
18 --*/
19 
20 #include "util/dec_ref_util.h"
21 #include "ast/reg_decl_plugins.h"
22 #include "ast/for_each_expr.h"
23 #include "ast/ast_smt2_pp.h"
24 #include "ast/func_decl_dependencies.h"
25 #include "smt/smt_kernel.h"
26 #include "smt/params/smt_params.h"
27 #include "smt/params/smt_params_helper.hpp"
28 #include "solver/solver_na2as.h"
29 #include "solver/mus.h"
30 
31 namespace {
32 
33     class smt_solver : public solver_na2as {
34 
35         struct cuber {
36             smt_solver& m_solver;
37             unsigned    m_round;
38             expr_ref_vector  m_result;
39             unsigned    m_depth;
cuber__anona0a2e2e20111::smt_solver::cuber40             cuber(smt_solver& s):
41                 m_solver(s),
42                 m_round(0),
43                 m_result(s.get_manager()),
44                 m_depth(s.m_smt_params.m_cube_depth) {}
cube__anona0a2e2e20111::smt_solver::cuber45             expr_ref cube() {
46                 if (m_round == 0) {
47                     m_result = m_solver.m_context.cubes(m_depth);
48                 }
49                 expr_ref r(m_result.m());
50                 if (m_round < m_result.size()) {
51                     r = m_result.get(m_round);
52                 }
53                 else {
54                     r = m_result.m().mk_false();
55                 }
56                 ++m_round;
57                 return r;
58             }
59         };
60 
61         smt_params           m_smt_params;
62         smt::kernel          m_context;
63         cuber*               m_cuber;
64         symbol               m_logic;
65         bool                 m_minimizing_core;
66         bool                 m_core_extend_patterns;
67         unsigned             m_core_extend_patterns_max_distance;
68         bool                 m_core_extend_nonlocal_patterns;
69         obj_map<expr, expr*> m_name2assertion;
70 
71     public:
smt_solver(ast_manager & m,params_ref const & p,symbol const & l)72         smt_solver(ast_manager & m, params_ref const & p, symbol const & l) :
73             solver_na2as(m),
74             m_smt_params(p),
75             m_context(m, m_smt_params),
76             m_cuber(nullptr),
77             m_minimizing_core(false),
78             m_core_extend_patterns(false),
79             m_core_extend_patterns_max_distance(UINT_MAX),
80             m_core_extend_nonlocal_patterns(false) {
81             m_logic = l;
82             if (m_logic != symbol::null)
83                 m_context.set_logic(m_logic);
84             updt_params(p);
85         }
86 
translate(ast_manager & m,params_ref const & p)87         solver * translate(ast_manager & m, params_ref const & p) override {
88             ast_translation translator(get_manager(), m);
89 
90             smt_solver * result = alloc(smt_solver, m, p, m_logic);
91             smt::kernel::copy(m_context, result->m_context);
92 
93             if (mc0())
94                 result->set_model_converter(mc0()->translate(translator));
95 
96             for (auto & kv : m_name2assertion) {
97                 expr* val = translator(kv.m_value);
98                 expr* key = translator(kv.m_key);
99                 result->assert_expr(val, key);
100             }
101 
102             return result;
103         }
104 
~smt_solver()105         ~smt_solver() override {
106             dealloc(m_cuber);
107             for (auto& kv : m_name2assertion) {
108                 get_manager().dec_ref(kv.m_key);
109                 get_manager().dec_ref(kv.m_value);
110             }
111         }
112 
updt_params(params_ref const & p)113         void updt_params(params_ref const & p) override {
114             solver::updt_params(p);
115             m_smt_params.updt_params(solver::get_params());
116             m_context.updt_params(solver::get_params());
117             smt_params_helper smth(solver::get_params());
118             m_core_extend_patterns = smth.core_extend_patterns();
119             m_core_extend_patterns_max_distance = smth.core_extend_patterns_max_distance();
120             m_core_extend_nonlocal_patterns = smth.core_extend_nonlocal_patterns();
121         }
122 
123         params_ref m_params_save;
124         smt_params m_smt_params_save;
125 
push_params()126         void push_params() override {
127             m_params_save = params_ref();
128             m_params_save.copy(solver::get_params());
129             m_smt_params_save = m_smt_params;
130         }
131 
pop_params()132         void pop_params() override {
133             m_smt_params = m_smt_params_save;
134             solver::reset_params(m_params_save);
135         }
136 
collect_param_descrs(param_descrs & r)137         void collect_param_descrs(param_descrs & r) override {
138             m_context.collect_param_descrs(r);
139             insert_timeout(r);
140             insert_rlimit(r);
141             insert_max_memory(r);
142             insert_ctrl_c(r);
143         }
144 
collect_statistics(statistics & st) const145         void collect_statistics(statistics & st) const override {
146             m_context.collect_statistics(st);
147         }
148 
get_consequences_core(expr_ref_vector const & assumptions,expr_ref_vector const & vars,expr_ref_vector & conseq)149         lbool get_consequences_core(expr_ref_vector const& assumptions, expr_ref_vector const& vars, expr_ref_vector& conseq) override {
150             expr_ref_vector unfixed(m_context.m());
151             return m_context.get_consequences(assumptions, vars, conseq, unfixed);
152         }
153 
find_mutexes(expr_ref_vector const & vars,vector<expr_ref_vector> & mutexes)154         lbool find_mutexes(expr_ref_vector const& vars, vector<expr_ref_vector>& mutexes) override {
155             return m_context.find_mutexes(vars, mutexes);
156         }
157 
assert_expr_core(expr * t)158         void assert_expr_core(expr * t) override {
159             m_context.assert_expr(t);
160         }
161 
assert_expr_core2(expr * t,expr * a)162         void assert_expr_core2(expr * t, expr * a) override {
163             if (m_name2assertion.contains(a)) {
164                 throw default_exception("named assertion defined twice");
165             }
166             solver_na2as::assert_expr_core2(t, a);
167             get_manager().inc_ref(t);
168             get_manager().inc_ref(a);
169             m_name2assertion.insert(a, t);
170         }
171 
push_core()172         void push_core() override {
173             m_context.push();
174         }
175 
pop_core(unsigned n)176         void pop_core(unsigned n) override {
177             unsigned cur_sz = m_assumptions.size();
178             if (n > 0 && cur_sz > 0) {
179                 unsigned lvl = m_scopes.size();
180                 SASSERT(n <= lvl);
181                 unsigned new_lvl = lvl - n;
182                 unsigned old_sz = m_scopes[new_lvl];
183                 for (unsigned i = cur_sz; i-- > old_sz; ) {
184                     expr * key = m_assumptions.get(i);
185                     expr * value = m_name2assertion.find(key);
186                     m_name2assertion.erase(key);
187                     m.dec_ref(value);
188                     m.dec_ref(key);
189                 }
190             }
191             m_context.pop(n);
192         }
193 
check_sat_core2(unsigned num_assumptions,expr * const * assumptions)194         lbool check_sat_core2(unsigned num_assumptions, expr * const * assumptions) override {
195             TRACE("solver_na2as", tout << "smt_solver::check_sat_core: " << num_assumptions << "\n";);
196             return m_context.check(num_assumptions, assumptions);
197         }
198 
199 
check_sat_cc_core(expr_ref_vector const & cube,vector<expr_ref_vector> const & clauses)200         lbool check_sat_cc_core(expr_ref_vector const& cube, vector<expr_ref_vector> const& clauses) override {
201             return m_context.check(cube, clauses);
202         }
203 
get_levels(ptr_vector<expr> const & vars,unsigned_vector & depth)204         void get_levels(ptr_vector<expr> const& vars, unsigned_vector& depth) override {
205             m_context.get_levels(vars, depth);
206         }
207 
get_trail()208         expr_ref_vector get_trail() override {
209             return m_context.get_trail();
210         }
211 
user_propagate_init(void * ctx,solver::push_eh_t & push_eh,solver::pop_eh_t & pop_eh,solver::fresh_eh_t & fresh_eh)212         void user_propagate_init(
213             void*                ctx,
214             solver::push_eh_t&   push_eh,
215             solver::pop_eh_t&    pop_eh,
216             solver::fresh_eh_t&  fresh_eh) override {
217             m_context.user_propagate_init(ctx, push_eh, pop_eh, fresh_eh);
218         }
219 
user_propagate_register_fixed(solver::fixed_eh_t & fixed_eh)220         void user_propagate_register_fixed(solver::fixed_eh_t& fixed_eh) override {
221             m_context.user_propagate_register_fixed(fixed_eh);
222         }
223 
user_propagate_register_final(solver::final_eh_t & final_eh)224         void user_propagate_register_final(solver::final_eh_t& final_eh) override {
225             m_context.user_propagate_register_final(final_eh);
226         }
227 
user_propagate_register_eq(solver::eq_eh_t & eq_eh)228         void user_propagate_register_eq(solver::eq_eh_t& eq_eh) override {
229             m_context.user_propagate_register_eq(eq_eh);
230         }
231 
user_propagate_register_diseq(solver::eq_eh_t & diseq_eh)232         void user_propagate_register_diseq(solver::eq_eh_t& diseq_eh) override {
233             m_context.user_propagate_register_diseq(diseq_eh);
234         }
235 
user_propagate_register(expr * e)236         unsigned user_propagate_register(expr* e) override {
237             return m_context.user_propagate_register(e);
238         }
239 
240         struct scoped_minimize_core {
241             smt_solver& s;
242             expr_ref_vector m_assumptions;
scoped_minimize_core__anona0a2e2e20111::smt_solver::scoped_minimize_core243             scoped_minimize_core(smt_solver& s) : s(s), m_assumptions(s.m_assumptions) {
244                 s.m_minimizing_core = true;
245                 s.m_assumptions.reset();
246             }
247 
~scoped_minimize_core__anona0a2e2e20111::smt_solver::scoped_minimize_core248             ~scoped_minimize_core() {
249                 s.m_minimizing_core = false;
250                 s.m_assumptions.append(m_assumptions);
251             }
252         };
253 
get_unsat_core(expr_ref_vector & r)254         void get_unsat_core(expr_ref_vector & r) override {
255             unsigned sz = m_context.get_unsat_core_size();
256             for (unsigned i = 0; i < sz; i++) {
257                 r.push_back(m_context.get_unsat_core_expr(i));
258             }
259 
260             if (!m_minimizing_core && smt_params_helper(get_params()).core_minimize()) {
261                 scoped_minimize_core scm(*this);
262                 mus mus(*this);
263                 mus.add_soft(r.size(), r.c_ptr());
264                 expr_ref_vector r2(m);
265                 if (l_true == mus.get_mus(r2)) {
266                     r.reset();
267                     r.append(r2);
268                 }
269             }
270 
271             if (m_core_extend_patterns)
272                 add_pattern_literals_to_core(r);
273             if (m_core_extend_nonlocal_patterns)
274                 add_nonlocal_pattern_literals_to_core(r);
275         }
276 
get_model_core(model_ref & m)277         void get_model_core(model_ref & m) override {
278             m_context.get_model(m);
279         }
280 
get_proof()281         proof * get_proof() override {
282             return m_context.get_proof();
283         }
284 
reason_unknown() const285         std::string reason_unknown() const override {
286             return m_context.last_failure_as_string();
287         }
288 
set_reason_unknown(char const * msg)289         void set_reason_unknown(char const* msg) override {
290             m_context.set_reason_unknown(msg);
291         }
292 
get_labels(svector<symbol> & r)293         void get_labels(svector<symbol> & r) override {
294             buffer<symbol> tmp;
295             m_context.get_relevant_labels(nullptr, tmp);
296             r.append(tmp.size(), tmp.c_ptr());
297         }
298 
get_manager() const299         ast_manager & get_manager() const override { return m_context.m(); }
300 
set_progress_callback(progress_callback * callback)301         void set_progress_callback(progress_callback * callback) override {
302             m_context.set_progress_callback(callback);
303         }
304 
get_num_assertions() const305         unsigned get_num_assertions() const override {
306             return m_context.size();
307         }
308 
get_assertion(unsigned idx) const309         expr * get_assertion(unsigned idx) const override {
310             SASSERT(idx < get_num_assertions());
311             return m_context.get_formula(idx);
312         }
313 
cube(expr_ref_vector & vars,unsigned cutoff)314         expr_ref_vector cube(expr_ref_vector& vars, unsigned cutoff) override {
315             ast_manager& m = get_manager();
316             if (!m_cuber) {
317                 m_cuber = alloc(cuber, *this);
318                 // force propagation
319                 push_core();
320                 pop_core(1);
321             }
322             expr_ref result = m_cuber->cube();
323             expr_ref_vector lits(m);
324             if (m.is_false(result)) {
325                 dealloc(m_cuber);
326                 m_cuber = nullptr;
327             }
328             if (m.is_true(result)) {
329                 dealloc(m_cuber);
330                 m_cuber = nullptr;
331                 return lits;
332             }
333             lits.push_back(result);
334             return lits;
335         }
336 
337         struct collect_fds_proc {
338             ast_manager & m;
339             func_decl_set & m_fds;
collect_fds_proc__anona0a2e2e20111::smt_solver::collect_fds_proc340             collect_fds_proc(ast_manager & m, func_decl_set & fds) :
341                 m(m), m_fds(fds) {
342             }
operator ()__anona0a2e2e20111::smt_solver::collect_fds_proc343             void operator()(var * n) {}
operator ()__anona0a2e2e20111::smt_solver::collect_fds_proc344             void operator()(app * n) {
345                 func_decl * fd = n->get_decl();
346                 if (fd->get_family_id() == null_family_id)
347                     m_fds.insert_if_not_there(fd);
348             }
operator ()__anona0a2e2e20111::smt_solver::collect_fds_proc349             void operator()(quantifier * n) {}
350         };
351 
352         struct collect_pattern_fds_proc {
353             ast_manager & m;
354             expr_fast_mark1 m_visited;
355             func_decl_set & m_fds;
collect_pattern_fds_proc__anona0a2e2e20111::smt_solver::collect_pattern_fds_proc356             collect_pattern_fds_proc(ast_manager & m, func_decl_set & fds) :
357                 m(m), m_fds(fds) {
358                 m_visited.reset();
359             }
operator ()__anona0a2e2e20111::smt_solver::collect_pattern_fds_proc360             void operator()(var * n) {}
operator ()__anona0a2e2e20111::smt_solver::collect_pattern_fds_proc361             void operator()(app * n) {}
operator ()__anona0a2e2e20111::smt_solver::collect_pattern_fds_proc362             void operator()(quantifier * n) {
363                 collect_fds_proc p(m, m_fds);
364 
365                 unsigned sz = n->get_num_patterns();
366                 for (unsigned i = 0; i < sz; i++)
367                     quick_for_each_expr(p, m_visited, n->get_pattern(i));
368 
369                 sz = n->get_num_no_patterns();
370                 for (unsigned i = 0; i < sz; i++)
371                     quick_for_each_expr(p, m_visited, n->get_no_pattern(i));
372             }
373         };
374 
collect_pattern_fds(expr_ref & e,func_decl_set & fds)375         void collect_pattern_fds(expr_ref & e, func_decl_set & fds) {
376             collect_pattern_fds_proc p(get_manager(), fds);
377             expr_mark visited;
378             for_each_expr(p, visited, e);
379         }
380 
compute_assrtn_fds(expr_ref_vector & core,vector<func_decl_set> & assrtn_fds)381         void compute_assrtn_fds(expr_ref_vector & core, vector<func_decl_set> & assrtn_fds) {
382             assrtn_fds.resize(m_name2assertion.size());
383             unsigned i = 0;
384             for (auto & kv : m_name2assertion) {
385                 if (!core.contains(kv.m_key)) {
386                     collect_fds_proc p(m, assrtn_fds[i]);
387                     expr_fast_mark1 visited;
388                     quick_for_each_expr(p, visited, kv.m_value);
389                 }
390                 ++i;
391             }
392         }
393 
fds_intersect(func_decl_set & pattern_fds,func_decl_set & assrtn_fds)394         bool fds_intersect(func_decl_set & pattern_fds, func_decl_set & assrtn_fds) {
395             for (func_decl * fd : pattern_fds) {
396                 if (assrtn_fds.contains(fd))
397                     return true;
398             }
399             return false;
400         }
401 
add_pattern_literals_to_core(expr_ref_vector & core)402         void add_pattern_literals_to_core(expr_ref_vector & core) {
403             ast_manager & m = get_manager();
404             expr_ref_vector new_core_literals(m);
405 
406             func_decl_set pattern_fds;
407             vector<func_decl_set> assrtn_fds;
408 
409             for (unsigned d = 0; d < m_core_extend_patterns_max_distance; d++) {
410                 new_core_literals.reset();
411 
412                 for (expr* c : core) {
413                     expr_ref name(c, m);
414                     expr* f = nullptr;
415                     if (m_name2assertion.find(name, f)) {
416                         expr_ref assrtn(f, m);
417                         collect_pattern_fds(assrtn, pattern_fds);
418                     }
419                 }
420 
421                 if (!pattern_fds.empty()) {
422                     if (assrtn_fds.empty())
423                         compute_assrtn_fds(core, assrtn_fds);
424 
425                     unsigned i = 0;
426                     for (auto & kv : m_name2assertion) {
427                         if (!core.contains(kv.m_key) &&
428                             fds_intersect(pattern_fds, assrtn_fds[i]))
429                             new_core_literals.push_back(kv.m_key);
430                         ++i;
431                     }
432                 }
433 
434                 core.append(new_core_literals.size(), new_core_literals.c_ptr());
435 
436                 if (new_core_literals.empty())
437                     break;
438             }
439         }
440 
441         struct collect_body_fds_proc {
442             ast_manager & m;
443             func_decl_set & m_fds;
collect_body_fds_proc__anona0a2e2e20111::smt_solver::collect_body_fds_proc444             collect_body_fds_proc(ast_manager & m, func_decl_set & fds) :
445                 m(m), m_fds(fds) {
446             }
operator ()__anona0a2e2e20111::smt_solver::collect_body_fds_proc447             void operator()(var * n) {}
operator ()__anona0a2e2e20111::smt_solver::collect_body_fds_proc448             void operator()(app * n) {}
operator ()__anona0a2e2e20111::smt_solver::collect_body_fds_proc449             void operator()(quantifier * n) {
450                 collect_fds_proc p(m, m_fds);
451                 expr_fast_mark1 visited;
452                 quick_for_each_expr(p, visited, n->get_expr());
453             }
454         };
455 
collect_body_func_decls(expr_ref & e,func_decl_set & fds)456         void collect_body_func_decls(expr_ref & e, func_decl_set & fds) {
457             ast_manager & m = get_manager();
458             collect_body_fds_proc p(m, fds);
459             expr_mark visited;
460             for_each_expr(p, visited, e);
461         }
462 
add_nonlocal_pattern_literals_to_core(expr_ref_vector & core)463         void add_nonlocal_pattern_literals_to_core(expr_ref_vector & core) {
464             ast_manager & m = get_manager();
465             for (auto const& kv : m_name2assertion) {
466                 expr_ref name(kv.m_key, m);
467                 expr_ref assrtn(kv.m_value, m);
468 
469                 if (!core.contains(name)) {
470                     func_decl_set pattern_fds, body_fds;
471                     collect_pattern_fds(assrtn, pattern_fds);
472                     collect_body_func_decls(assrtn, body_fds);
473 
474                     for (func_decl *fd : pattern_fds) {
475                         if (!body_fds.contains(fd) && !core.contains(name)) {
476                             core.push_back(name);
477                             break;
478                         }
479                     }
480                 }
481             }
482         }
483     };
484 }
485 
mk_smt_solver(ast_manager & m,params_ref const & p,symbol const & logic)486 solver * mk_smt_solver(ast_manager & m, params_ref const & p, symbol const & logic) {
487     return alloc(smt_solver, m, p, logic);
488 }
489 
490 namespace {
491 class smt_solver_factory : public solver_factory {
492 public:
operator ()(ast_manager & m,params_ref const & p,bool proofs_enabled,bool models_enabled,bool unsat_core_enabled,symbol const & logic)493     solver * operator()(ast_manager & m, params_ref const & p, bool proofs_enabled, bool models_enabled, bool unsat_core_enabled, symbol const & logic) override {
494         return mk_smt_solver(m, p, logic);
495     }
496 };
497 }
498 
mk_smt_solver_factory()499 solver_factory * mk_smt_solver_factory() {
500     return alloc(smt_solver_factory);
501 }
502 
503