1 /*++
2 Copyright (c) 2020 Microsoft Corporation
3 
4 Module Name:
5 
6     euf_solver.cpp
7 
8 Abstract:
9 
10     Solver plugin for EUF
11 
12 Author:
13 
14     Nikolaj Bjorner (nbjorner) 2020-08-25
15 
16 --*/
17 
18 #include "ast/pb_decl_plugin.h"
19 #include "ast/ast_ll_pp.h"
20 #include "sat/sat_solver.h"
21 #include "sat/smt/sat_smt.h"
22 #include "sat/smt/pb_solver.h"
23 #include "sat/smt/bv_solver.h"
24 #include "sat/smt/euf_solver.h"
25 #include "sat/smt/array_solver.h"
26 #include "sat/smt/arith_solver.h"
27 #include "sat/smt/q_solver.h"
28 #include "sat/smt/fpa_solver.h"
29 #include "sat/smt/dt_solver.h"
30 #include "sat/smt/recfun_solver.h"
31 
32 namespace euf {
33 
display(std::ostream & out) const34     std::ostream& clause_pp::display(std::ostream& out) const {
35         for (auto lit : lits)
36             out << s.literal2expr(lit) << " ";
37         return out;
38     }
39 
solver(ast_manager & m,sat::sat_internalizer & si,params_ref const & p)40     solver::solver(ast_manager& m, sat::sat_internalizer& si, params_ref const& p) :
41         extension(symbol("euf"), m.mk_family_id("euf")),
42         m(m),
43         si(si),
44         m_egraph(m),
45         m_trail(),
46         m_rewriter(m),
47         m_unhandled_functions(m),
48         m_lookahead(nullptr),
49         m_to_m(&m),
50         m_to_si(&si),
51         m_values(m)
52     {
53         updt_params(p);
54 
55         std::function<void(std::ostream&, void*)> disp =
56             [&](std::ostream& out, void* j) {
57             display_justification_ptr(out, reinterpret_cast<size_t*>(j));
58         };
59         m_egraph.set_display_justification(disp);
60     }
61 
updt_params(params_ref const & p)62     void solver::updt_params(params_ref const& p) {
63         m_config.updt_params(p);
64     }
65 
66     /**
67     * retrieve extension that is associated with Boolean variable.
68     */
bool_var2solver(sat::bool_var v)69     th_solver* solver::bool_var2solver(sat::bool_var v) {
70         if (v >= m_bool_var2expr.size())
71             return nullptr;
72         expr* e = m_bool_var2expr[v];
73         if (!e)
74             return nullptr;
75         return expr2solver(e);
76     }
77 
expr2solver(expr * e)78     th_solver* solver::expr2solver(expr* e) {
79         if (is_app(e))
80             return func_decl2solver(to_app(e)->get_decl());
81         if (is_forall(e) || is_exists(e))
82             return quantifier2solver();
83         return nullptr;
84     }
85 
quantifier2solver()86     th_solver* solver::quantifier2solver() {
87         family_id fid = m.mk_family_id(symbol("quant"));
88         auto* ext = m_id2solver.get(fid, nullptr);
89         if (ext)
90             return ext;
91         ext = alloc(q::solver, *this, fid);
92         m_qsolver = ext;
93         add_solver(ext);
94         return ext;
95     }
96 
get_solver(family_id fid,func_decl * f)97     th_solver* solver::get_solver(family_id fid, func_decl* f) {
98         if (fid == null_family_id)
99             return nullptr;
100         auto* ext = m_id2solver.get(fid, nullptr);
101         if (ext)
102             return ext;
103         if (fid == m.get_basic_family_id())
104             return nullptr;
105         if (fid == m.get_user_sort_family_id())
106             return nullptr;
107         pb_util pb(m);
108         bv_util bvu(m);
109         array_util au(m);
110         fpa_util fpa(m);
111         arith_util arith(m);
112         datatype_util dt(m);
113         recfun::util rf(m);
114         if (pb.get_family_id() == fid)
115             ext = alloc(pb::solver, *this, fid);
116         else if (bvu.get_family_id() == fid)
117             ext = alloc(bv::solver, *this, fid);
118         else if (au.get_family_id() == fid)
119             ext = alloc(array::solver, *this, fid);
120         else if (fpa.get_family_id() == fid)
121             ext = alloc(fpa::solver, *this);
122         else if (arith.get_family_id() == fid)
123             ext = alloc(arith::solver, *this, fid);
124         else if (dt.get_family_id() == fid)
125             ext = alloc(dt::solver, *this, fid);
126         else if (rf.get_family_id() == fid)
127             ext = alloc(recfun::solver, *this);
128 
129         if (ext)
130             add_solver(ext);
131         else if (f)
132             unhandled_function(f);
133         return ext;
134     }
135 
add_solver(th_solver * th)136     void solver::add_solver(th_solver* th) {
137         family_id fid = th->get_id();
138         if (use_drat())
139             s().get_drat().add_theory(fid, th->name());
140         th->set_solver(m_solver);
141         th->push_scopes(s().num_scopes() + s().num_user_scopes());
142         m_solvers.push_back(th);
143         m_id2solver.setx(fid, th, nullptr);
144         if (th->use_diseqs())
145             m_egraph.set_th_propagates_diseqs(fid);
146     }
147 
unhandled_function(func_decl * f)148     void solver::unhandled_function(func_decl* f) {
149         if (m_unhandled_functions.contains(f))
150             return;
151         if (m.is_model_value(f))
152             return;
153         m_unhandled_functions.push_back(f);
154         m_trail.push(push_back_vector<func_decl_ref_vector>(m_unhandled_functions));
155         IF_VERBOSE(0, verbose_stream() << mk_pp(f, m) << " not handled\n");
156     }
157 
init_search()158     void solver::init_search() {
159         TRACE("before_search", s().display(tout););
160         for (auto* s : m_solvers)
161             s->init_search();
162     }
163 
is_external(bool_var v)164     bool solver::is_external(bool_var v) {
165         if (s().is_external(v))
166             return true;
167         if (nullptr != m_bool_var2expr.get(v, nullptr))
168             return true;
169         for (auto* s : m_solvers)
170             if (s->is_external(v))
171                 return true;
172         return false;
173     }
174 
propagated(literal l,ext_constraint_idx idx)175     bool solver::propagated(literal l, ext_constraint_idx idx) {
176         auto* ext = sat::constraint_base::to_extension(idx);
177         SASSERT(ext != this);
178         return ext->propagated(l, idx);
179     }
180 
set_conflict(ext_constraint_idx idx)181     void solver::set_conflict(ext_constraint_idx idx) {
182         s().set_conflict(sat::justification::mk_ext_justification(s().scope_lvl(), idx));
183     }
184 
propagate(literal lit,ext_justification_idx idx)185     void solver::propagate(literal lit, ext_justification_idx idx) {
186         add_auto_relevant(bool_var2expr(lit.var()));
187         s().assign(lit, sat::justification::mk_ext_justification(s().scope_lvl(), idx));
188     }
189 
get_antecedents(literal l,ext_justification_idx idx,literal_vector & r,bool probing)190     void solver::get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) {
191         m_egraph.begin_explain();
192         m_explain.reset();
193         auto* ext = sat::constraint_base::to_extension(idx);
194         if (ext == this)
195             get_antecedents(l, constraint::from_idx(idx), r, probing);
196         else
197             ext->get_antecedents(l, idx, r, probing);
198         for (unsigned qhead = 0; qhead < m_explain.size(); ++qhead) {
199             size_t* e = m_explain[qhead];
200             if (is_literal(e))
201                 r.push_back(get_literal(e));
202             else {
203                 size_t idx = get_justification(e);
204                 auto* ext = sat::constraint_base::to_extension(idx);
205                 SASSERT(ext != this);
206                 sat::literal lit = sat::null_literal;
207                 ext->get_antecedents(lit, idx, r, probing);
208             }
209         }
210         m_egraph.end_explain();
211         unsigned j = 0;
212         for (sat::literal lit : r)
213             if (s().lvl(lit) > 0) r[j++] = lit;
214         r.shrink(j);
215         TRACE("euf", tout << "explain " << l << " <- " << r << " " << probing << "\n";);
216         DEBUG_CODE(for (auto lit : r) SASSERT(s().value(lit) == l_true););
217 
218         if (!probing)
219             log_antecedents(l, r);
220     }
221 
get_antecedents(literal l,th_explain & jst,literal_vector & r,bool probing)222     void solver::get_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing) {
223         for (auto lit : euf::th_explain::lits(jst))
224             r.push_back(lit);
225         for (auto eq : euf::th_explain::eqs(jst))
226             add_antecedent(eq.first, eq.second);
227 
228         if (!probing && use_drat())
229             log_justification(l, jst);
230     }
231 
add_antecedent(enode * a,enode * b)232     void solver::add_antecedent(enode* a, enode* b) {
233         m_egraph.explain_eq<size_t>(m_explain, a, b);
234     }
235 
add_diseq_antecedent(enode * a,enode * b)236     void solver::add_diseq_antecedent(enode* a, enode* b) {
237         sat::bool_var v = get_egraph().explain_diseq(m_explain, a, b);
238         SASSERT(v == sat::null_bool_var || s().value(v) == l_false);
239         if (v != sat::null_bool_var)
240             m_explain.push_back(to_ptr(sat::literal(v, true)));
241     }
242 
propagate(enode * a,enode * b,ext_justification_idx idx)243     bool solver::propagate(enode* a, enode* b, ext_justification_idx idx) {
244         if (a->get_root() == b->get_root())
245             return false;
246         m_egraph.merge(a, b, to_ptr(idx));
247         return true;
248     }
249 
get_antecedents(literal l,constraint & j,literal_vector & r,bool probing)250     void solver::get_antecedents(literal l, constraint& j, literal_vector& r, bool probing) {
251         expr* e = nullptr;
252         euf::enode* n = nullptr;
253 
254         if (!probing && !m_drating)
255             init_ackerman();
256 
257         switch (j.kind()) {
258         case constraint::kind_t::conflict:
259             SASSERT(m_egraph.inconsistent());
260             m_egraph.explain<size_t>(m_explain);
261             break;
262         case constraint::kind_t::eq:
263             e = m_bool_var2expr[l.var()];
264             n = m_egraph.find(e);
265             SASSERT(n);
266             SASSERT(n->is_equality());
267             SASSERT(!l.sign());
268             m_egraph.explain_eq<size_t>(m_explain, n->get_arg(0), n->get_arg(1));
269             break;
270         case constraint::kind_t::lit:
271             e = m_bool_var2expr[l.var()];
272             n = m_egraph.find(e);
273             SASSERT(n);
274             SASSERT(m.is_bool(n->get_expr()));
275             m_egraph.explain_eq<size_t>(m_explain, n, (l.sign() ? mk_false() : mk_true()));
276             break;
277         default:
278             IF_VERBOSE(0, verbose_stream() << (unsigned)j.kind() << "\n");
279             UNREACHABLE();
280         }
281     }
282 
set_eliminated(bool_var v)283     void solver::set_eliminated(bool_var v) {
284         si.uncache(literal(v, false));
285         si.uncache(literal(v, true));
286     }
287 
asserted(literal l)288     void solver::asserted(literal l) {
289         expr* e = m_bool_var2expr.get(l.var(), nullptr);
290         TRACE("euf", tout << "asserted: " << l << "@" << s().scope_lvl() << " := " << mk_bounded_pp(e, m) << "\n";);
291         if (!e)
292             return;
293         euf::enode* n = m_egraph.find(e);
294         if (!n)
295             return;
296         bool sign = l.sign();
297         m_egraph.set_value(n, sign ? l_false : l_true);
298         for (auto th : enode_th_vars(n))
299             m_id2solver[th.get_id()]->asserted(l);
300 
301         size_t* c = to_ptr(l);
302         SASSERT(is_literal(c));
303         SASSERT(l == get_literal(c));
304         if (n->value_conflict()) {
305             euf::enode* nb = sign ? mk_false() : mk_true();
306             euf::enode* r = n->get_root();
307             euf::enode* rb = sign ? mk_true() : mk_false();
308             sat::literal rl(r->bool_var(), r->value() == l_false);
309             m_egraph.merge(n, nb, c);
310             m_egraph.merge(r, rb, to_ptr(rl));
311             SASSERT(m_egraph.inconsistent());
312             return;
313 	    }
314         if (n->merge_tf()) {
315             euf::enode* nb = sign ? mk_false() : mk_true();
316             m_egraph.merge(n, nb, c);
317         }
318         if (n->is_equality()) {
319             SASSERT(!m.is_iff(e));
320             SASSERT(m.is_eq(e));
321             if (sign)
322                 m_egraph.new_diseq(n);
323             else
324                 m_egraph.merge(n->get_arg(0), n->get_arg(1), c);
325         }
326     }
327 
328 
unit_propagate()329     bool solver::unit_propagate() {
330         bool propagated = false;
331         while (!s().inconsistent()) {
332             if (m_egraph.inconsistent()) {
333                 unsigned lvl = s().scope_lvl();
334                 s().set_conflict(sat::justification::mk_ext_justification(lvl, conflict_constraint().to_index()));
335                 return true;
336             }
337             bool propagated1 = false;
338             if (m_egraph.propagate()) {
339                 propagate_literals();
340                 propagate_th_eqs();
341                 propagated1 = true;
342             }
343 
344             for (unsigned i = 0; i < m_solvers.size(); ++i)
345                 if (m_solvers[i]->unit_propagate())
346                     propagated1 = true;
347 
348             if (!propagated1)
349                 break;
350             propagated = true;
351         }
352         DEBUG_CODE(if (!propagated && !s().inconsistent()) check_missing_eq_propagation(););
353         return propagated;
354     }
355 
propagate_literals()356     void solver::propagate_literals() {
357         for (; m_egraph.has_literal() && !s().inconsistent() && !m_egraph.inconsistent(); m_egraph.next_literal()) {
358             auto [n, is_eq] = m_egraph.get_literal();
359             expr* e = n->get_expr();
360             expr* a = nullptr, *b = nullptr;
361             bool_var v = n->bool_var();
362             SASSERT(m.is_bool(e));
363             size_t cnstr;
364             literal lit;
365             if (is_eq) {
366                 VERIFY(m.is_eq(e, a, b));
367                 cnstr = eq_constraint().to_index();
368                 lit = literal(v, false);
369             }
370             else {
371                 lbool val = n->get_root()->value();
372                 if (val == l_undef && m.is_false(n->get_root()->get_expr()))
373                     val = l_false;
374                 if (val == l_undef && m.is_true(n->get_root()->get_expr()))
375                     val = l_true;
376                 a = e;
377                 b = (val == l_true) ? m.mk_true() : m.mk_false();
378                 SASSERT(val != l_undef);
379                 cnstr = lit_constraint().to_index();
380                 lit = literal(v, val == l_false);
381             }
382             unsigned lvl = s().scope_lvl();
383 
384             CTRACE("euf", s().value(lit) != l_true, tout << lit << " " << s().value(lit) << "@" << lvl << " " << is_eq << " " << mk_bounded_pp(a, m) << " = " << mk_bounded_pp(b, m) << "\n";);
385             if (s().value(lit) == l_false && m_ackerman)
386                 m_ackerman->cg_conflict_eh(a, b);
387             switch (s().value(lit)) {
388             case l_true:
389                 break;
390             case l_undef:
391             case l_false:
392                 s().assign(lit, sat::justification::mk_ext_justification(lvl, cnstr));
393                 break;
394             }
395         }
396     }
397 
is_self_propagated(th_eq const & e)398     bool solver::is_self_propagated(th_eq const& e) {
399         if (!e.is_eq())
400             return false;
401 
402         m_egraph.begin_explain();
403         m_explain.reset();
404         m_egraph.explain_eq<size_t>(m_explain, e.child(), e.root());
405         m_egraph.end_explain();
406         if (m_egraph.uses_congruence())
407             return false;
408         for (auto p : m_explain) {
409             if (is_literal(p))
410                 return false;
411 
412             size_t idx = get_justification(p);
413             auto* ext = sat::constraint_base::to_extension(idx);
414             if (ext->get_id() != e.id())
415                 return false;
416             if (ext->enable_self_propagate())
417                 return false;
418         }
419         return true;
420     }
421 
propagate_th_eqs()422     void solver::propagate_th_eqs() {
423         for (; m_egraph.has_th_eq() && !s().inconsistent() && !m_egraph.inconsistent(); m_egraph.next_th_eq()) {
424             th_eq eq = m_egraph.get_th_eq();
425             if (eq.is_eq()) {
426                 if (!is_self_propagated(eq))
427                     m_id2solver[eq.id()]->new_eq_eh(eq);
428             }
429             else
430                 m_id2solver[eq.id()]->new_diseq_eh(eq);
431         }
432     }
433 
mk_constraint(constraint * & c,constraint::kind_t k)434     constraint& solver::mk_constraint(constraint*& c, constraint::kind_t k) {
435         if (!c) {
436             void* mem = memory::allocate(sat::constraint_base::obj_size(sizeof(constraint)));
437             c = new (sat::constraint_base::ptr2mem(mem)) constraint(k);
438             sat::constraint_base::initialize(mem, this);
439         }
440         return *c;
441     }
442 
mk_true()443     enode* solver::mk_true() {
444         VERIFY(visit(m.mk_true()));
445         return m_egraph.find(m.mk_true());
446     }
447 
mk_false()448     enode* solver::mk_false() {
449         VERIFY(visit(m.mk_false()));
450         return m_egraph.find(m.mk_false());
451     }
452 
check()453     sat::check_result solver::check() {
454         ++m_stats.m_final_checks;
455         TRACE("euf", s().display(tout););
456         bool give_up = false;
457         bool cont = false;
458 
459         if (unit_propagate())
460             return sat::check_result::CR_CONTINUE;
461 
462         if (!init_relevancy())
463             give_up = true;
464 
465         unsigned num_nodes = m_egraph.num_nodes();
466         auto apply_solver = [&](th_solver* e) {
467             switch (e->check()) {
468             case sat::check_result::CR_CONTINUE: cont = true; break;
469             case sat::check_result::CR_GIVEUP: give_up = true; break;
470             default: break;
471             }
472         };
473         if (merge_shared_bools())
474             cont = true;
475         for (auto* e : m_solvers) {
476             if (!m.inc())
477                 return sat::check_result::CR_GIVEUP;
478             if (e == m_qsolver)
479                 continue;
480             apply_solver(e);
481             if (s().inconsistent())
482                 return sat::check_result::CR_CONTINUE;
483         }
484 
485 
486         if (s().inconsistent())
487             return sat::check_result::CR_CONTINUE;
488         if (cont)
489             return sat::check_result::CR_CONTINUE;
490         if (m_qsolver)
491             apply_solver(m_qsolver);
492         if (num_nodes < m_egraph.num_nodes())
493             return sat::check_result::CR_CONTINUE;
494         if (cont)
495             return sat::check_result::CR_CONTINUE;
496         TRACE("after_search", s().display(tout););
497         if (give_up)
498             return sat::check_result::CR_GIVEUP;
499         return sat::check_result::CR_DONE;
500     }
501 
merge_shared_bools()502     bool solver::merge_shared_bools() {
503         bool merged = false;
504         for (unsigned i = m_egraph.nodes().size(); i-- > 0; ) {
505             euf::enode* n = m_egraph.nodes()[i];
506             if (!is_shared(n) || !m.is_bool(n->get_expr()))
507                 continue;
508             if (n->value() == l_true && !m.is_true(n->get_root()->get_expr())) {
509                 m_egraph.merge(n, mk_true(), to_ptr(sat::literal(n->bool_var())));
510                 merged = true;
511             }
512             if (n->value() == l_false && !m.is_false(n->get_root()->get_expr())) {
513                 m_egraph.merge(n, mk_false(), to_ptr(~sat::literal(n->bool_var())));
514                 merged = true;
515             }
516         }
517         return merged;
518     }
519 
push()520     void solver::push() {
521         si.push();
522         scope s;
523         s.m_var_lim = m_var_trail.size();
524         m_scopes.push_back(s);
525         m_trail.push_scope();
526         for (auto* e : m_solvers)
527             e->push();
528         m_egraph.push();
529         if (m_dual_solver)
530             m_dual_solver->push();
531         push_relevant();
532     }
533 
pop(unsigned n)534     void solver::pop(unsigned n) {
535         start_reinit(n);
536         m_trail.pop_scope(n);
537         for (auto* e : m_solvers)
538             e->pop(n);
539         si.pop(n);
540         m_egraph.pop(n);
541         pop_relevant(n);
542         scope const & sc = m_scopes[m_scopes.size() - n];
543         for (unsigned i = m_var_trail.size(); i-- > sc.m_var_lim; ) {
544             bool_var v = m_var_trail[i];
545             m_bool_var2expr[v] = nullptr;
546             s().set_non_external(v);
547         }
548         m_var_trail.shrink(sc.m_var_lim);
549         m_scopes.shrink(m_scopes.size() - n);
550         if (m_dual_solver)
551             m_dual_solver->pop(n);
552         SASSERT(m_egraph.num_scopes() == m_scopes.size());
553         TRACE("euf_verbose", display(tout << "pop to: " << m_scopes.size() << "\n"););
554     }
555 
user_push()556     void solver::user_push() {
557         push();
558     }
559 
user_pop(unsigned n)560     void solver::user_pop(unsigned n) {
561         pop(n);
562     }
563 
start_reinit(unsigned n)564     void solver::start_reinit(unsigned n) {
565         m_reinit.reset();
566         for (sat::bool_var v : s().get_vars_to_reinit()) {
567             expr* e = bool_var2expr(v);
568             if (e)
569                 m_reinit.push_back(reinit_t(expr_ref(e, m), get_enode(e)?get_enode(e)->generation():0, v));
570         }
571     }
572 
573     /**
574     * After a pop has completed, re-initialize the association between Boolean variables
575     * and the theories by re-creating the expression/variable mapping used for Booleans
576     * and replaying internalization.
577     */
finish_reinit()578     void solver::finish_reinit() {
579         if (m_reinit.empty())
580             return;
581 
582         struct scoped_set_replay {
583             solver& s;
584             obj_map<expr, sat::bool_var> m;
585             scoped_set_replay(solver& s) :s(s) {
586                 s.si.set_expr2var_replay(&m);
587             }
588             ~scoped_set_replay() {
589                 s.si.set_expr2var_replay(nullptr);
590             }
591         };
592         scoped_set_replay replay(*this);
593         scoped_suspend_rlimit suspend_rlimit(m.limit());
594 
595         for (auto const& [e, generation, v] : m_reinit)
596             replay.m.insert(e, v);
597 
598         TRACE("euf", for (auto const& kv : replay.m) tout << kv.m_value << "\n";);
599         for (auto const& [e, generation, v] : m_reinit) {
600             scoped_generation _sg(*this, generation);
601             TRACE("euf", tout << "replay: " << v << " " << e->get_id() << " " << mk_bounded_pp(e, m) << " " << si.is_bool_op(e) << "\n";);
602             sat::literal lit;
603             if (si.is_bool_op(e))
604                 lit = literal(replay.m[e], false);
605             else
606                 lit = si.internalize(e, true);
607             VERIFY(lit.var() == v);
608             if (!m_egraph.find(e) && (!m.is_iff(e) && !m.is_or(e) && !m.is_and(e) && !m.is_not(e))) {
609                 ptr_buffer<euf::enode> args;
610                 if (is_app(e))
611                     for (expr* arg : *to_app(e))
612                         args.push_back(e_internalize(arg));
613                 if (!m_egraph.find(e))
614                     mk_enode(e, args.size(), args.data());
615             }
616             attach_lit(lit, e);
617         }
618 
619         if (relevancy_enabled())
620             for (auto const& [e, generation, v] : m_reinit)
621                 if (si.is_bool_op(e))
622                     relevancy_reinit(e);
623         TRACE("euf", display(tout << "replay done\n"););
624     }
625 
626     /**
627     * Boolean structure needs to be replayed for relevancy tracking.
628     * Main cases for replaying Boolean functions are included. When a replay
629     * is not supported, we just disable relevancy.
630     */
relevancy_reinit(expr * e)631     void solver::relevancy_reinit(expr* e) {
632         TRACE("euf", tout << "internalize again " << mk_pp(e, m) << "\n";);
633         if (to_app(e)->get_family_id() != m.get_basic_family_id()) {
634             disable_relevancy(e);
635             return;
636         }
637         auto lit = si.internalize(e, true);
638         switch (to_app(e)->get_decl_kind()) {
639         case OP_NOT: {
640             auto lit2 = si.internalize(to_app(e)->get_arg(0), true);
641             add_aux(lit, lit2);
642             add_aux(~lit, ~lit2);
643             break;
644         }
645         case OP_EQ: {
646             if (to_app(e)->get_num_args() != 2) {
647                 disable_relevancy(e);
648                 return;
649             }
650             auto lit1 = si.internalize(to_app(e)->get_arg(0), true);
651             auto lit2 = si.internalize(to_app(e)->get_arg(1), true);
652             add_aux(~lit, ~lit1, lit2);
653             add_aux(~lit, lit1, ~lit2);
654             add_aux(lit, lit1, lit2);
655             add_aux(lit, ~lit1, ~lit2);
656             break;
657         }
658         case OP_OR: {
659             sat::literal_vector lits;
660             for (expr* arg : *to_app(e))
661                 lits.push_back(si.internalize(arg, true));
662             for (auto lit2 : lits)
663                 add_aux(~lit2, lit);
664             lits.push_back(~lit);
665             add_aux(lits);
666             break;
667         }
668         case OP_AND: {
669             sat::literal_vector lits;
670             for (expr* arg : *to_app(e))
671                 lits.push_back(~si.internalize(arg, true));
672             for (auto nlit2 : lits)
673                 add_aux(~lit, ~nlit2);
674             lits.push_back(lit);
675             add_aux(lits);
676             break;
677         }
678         case OP_TRUE:
679             add_aux(lit);
680             break;
681         case OP_FALSE:
682             add_aux(~lit);
683             break;
684         case OP_ITE: {
685             auto lit1 = si.internalize(to_app(e)->get_arg(0), true);
686             auto lit2 = si.internalize(to_app(e)->get_arg(1), true);
687             auto lit3 = si.internalize(to_app(e)->get_arg(2), true);
688             add_aux(~lit, ~lit1, lit2);
689             add_aux(~lit, lit1, lit3);
690             add_aux(lit, ~lit1, ~lit2);
691             add_aux(lit, lit1, ~lit3);
692             break;
693         }
694         case OP_XOR: {
695             if (to_app(e)->get_num_args() != 2) {
696                 disable_relevancy(e);
697                 break;
698             }
699             auto lit1 = si.internalize(to_app(e)->get_arg(0), true);
700             auto lit2 = si.internalize(to_app(e)->get_arg(1), true);
701             add_aux(lit, ~lit1, lit2);
702             add_aux(lit, lit1, ~lit2);
703             add_aux(~lit, lit1, lit2);
704             add_aux(~lit, ~lit1, ~lit2);
705             break;
706         }
707         case OP_IMPLIES: {
708             if (to_app(e)->get_num_args() != 2) {
709                 disable_relevancy(e);
710                 break;
711             }
712             auto lit1 = si.internalize(to_app(e)->get_arg(0), true);
713             auto lit2 = si.internalize(to_app(e)->get_arg(1), true);
714             add_aux(~lit, ~lit1, lit2);
715             add_aux(lit, lit1);
716             add_aux(lit, ~lit2);
717             break;
718         }
719         default:
720             UNREACHABLE();
721         }
722     }
723 
pre_simplify()724     void solver::pre_simplify() {
725         for (auto* e : m_solvers)
726             e->pre_simplify();
727     }
728 
simplify()729     void solver::simplify() {
730         for (auto* e : m_solvers)
731             e->simplify();
732         if (m_ackerman)
733             m_ackerman->propagate();
734     }
735 
should_research(sat::literal_vector const & core)736     bool solver::should_research(sat::literal_vector const& core) {
737         bool result = false;
738         for (auto* e : m_solvers)
739             if (e->should_research(core))
740                 result = true;
741         return result;
742     }
743 
add_assumptions(sat::literal_set & assumptions)744     void solver::add_assumptions(sat::literal_set& assumptions) {
745         for (auto* e : m_solvers)
746             e->add_assumptions(assumptions);
747     }
748 
tracking_assumptions()749     bool solver::tracking_assumptions() {
750         for (auto* e : m_solvers)
751             if (e->tracking_assumptions())
752                 return true;
753         return false;
754     }
755 
clauses_modifed()756     void solver::clauses_modifed() {
757         for (auto* e : m_solvers)
758             e->clauses_modifed();
759     }
760 
get_phase(bool_var v)761     lbool solver::get_phase(bool_var v) {
762         auto* ext = bool_var2solver(v);
763         if (ext)
764             return ext->get_phase(v);
765         return l_undef;
766     }
767 
set_root(literal l,literal r)768     bool solver::set_root(literal l, literal r) {
769         expr* e = bool_var2expr(l.var());
770         if (!e)
771             return true;
772         bool ok = true;
773         for (auto* s : m_solvers)
774             if (!s->set_root(l, r))
775                 ok = false;
776         if (m.is_eq(e) && !m.is_iff(e))
777             ok = false;
778         euf::enode* n = get_enode(e);
779         if (n && n->merge_enabled())
780             ok = false;
781 
782         (void)ok;
783         TRACE("euf", tout << ok << " " << l << " -> " << r << "\n";);
784         // roots cannot be eliminated as long as the egraph contains the expressions.
785         return false;
786     }
787 
flush_roots()788     void solver::flush_roots() {
789         for (auto* s : m_solvers)
790             s->flush_roots();
791     }
792 
display(std::ostream & out) const793     std::ostream& solver::display(std::ostream& out) const {
794         m_egraph.display(out);
795         out << "bool-vars\n";
796         for (unsigned v : m_var_trail) {
797             expr* e = m_bool_var2expr[v];
798             out << v << ": " << e->get_id() << " " << m_solver->value(v) << " " << mk_bounded_pp(e, m, 1) << "\n";
799         }
800         for (auto* e : m_solvers)
801             e->display(out);
802         return out;
803     }
804 
display_justification_ptr(std::ostream & out,size_t * j) const805     std::ostream& solver::display_justification_ptr(std::ostream& out, size_t* j) const {
806         if (is_literal(j))
807             return out << "sat: " << get_literal(j);
808         else
809             return display_justification(out, get_justification(j));
810     }
811 
display_justification(std::ostream & out,ext_justification_idx idx) const812     std::ostream& solver::display_justification(std::ostream& out, ext_justification_idx idx) const {
813         auto* ext = sat::constraint_base::to_extension(idx);
814         if (ext == this) {
815             constraint& c = constraint::from_idx(idx);
816             switch (c.kind()) {
817             case constraint::kind_t::conflict:
818                 return out << "euf conflict";
819             case constraint::kind_t::eq:
820                 return out << "euf equality propagation";
821             case constraint::kind_t::lit:
822                 return out << "euf literal propagation";
823             default:
824                 UNREACHABLE();
825                 return out;
826             }
827         }
828         else
829             return ext->display_justification(out, idx);
830         return out;
831     }
832 
display_constraint(std::ostream & out,ext_constraint_idx idx) const833     std::ostream& solver::display_constraint(std::ostream& out, ext_constraint_idx idx) const {
834         auto* ext = sat::constraint_base::to_extension(idx);
835         if (ext != this)
836             return ext->display_constraint(out, idx);
837         return display_justification(out, idx);
838     }
839 
collect_statistics(statistics & st) const840     void solver::collect_statistics(statistics& st) const {
841         m_egraph.collect_statistics(st);
842         for (auto* e : m_solvers)
843             e->collect_statistics(st);
844         st.update("euf ackerman", m_stats.m_ackerman);
845         st.update("euf final check", m_stats.m_final_checks);
846     }
847 
copy(solver & dst_ctx,enode * src_n)848     enode* solver::copy(solver& dst_ctx, enode* src_n) {
849         if (!src_n)
850             return nullptr;
851         ast_translation tr(m, dst_ctx.get_manager(), false);
852         expr* e1 = src_n->get_expr();
853         expr* e2 = tr(e1);
854         euf::enode* n2 = dst_ctx.get_enode(e2);
855         SASSERT(n2);
856         return n2;
857     }
858 
copy(sat::solver * s)859     sat::extension* solver::copy(sat::solver* s) {
860         auto* r = alloc(solver, *m_to_m, *m_to_si);
861         r->m_config = m_config;
862         sat::literal true_lit = sat::null_literal;
863         if (s->init_trail_size() > 0)
864             true_lit = s->trail_literal(0);
865         std::function<void* (void*)> copy_justification = [&](void* x) {
866             SASSERT(true_lit != sat::null_literal);
867             return (void*)(r->to_ptr(true_lit));
868         };
869         r->m_egraph.copy_from(m_egraph, copy_justification);
870         r->set_solver(s);
871         for (euf::enode* n : r->m_egraph.nodes()) {
872             auto b = n->bool_var();
873             if (b != sat::null_bool_var) {
874                 r->m_bool_var2expr.setx(b, n->get_expr(), nullptr);
875                 SASSERT(r->m.is_bool(n->get_sort()));
876                 IF_VERBOSE(11, verbose_stream() << "set bool_var " << b << " " << r->bpp(n) << " " << mk_bounded_pp(n->get_expr(), m) << "\n");
877             }
878         }
879         for (auto* s_orig : m_id2solver) {
880             if (s_orig) {
881                 auto* s_clone = s_orig->clone(*r);
882                 r->add_solver(s_clone);
883                 s_clone->set_solver(s);
884             }
885         }
886         return r;
887     }
888 
find_mutexes(literal_vector & lits,vector<literal_vector> & mutexes)889     void solver::find_mutexes(literal_vector& lits, vector<literal_vector> & mutexes) {
890         for (auto* e : m_solvers)
891             e->find_mutexes(lits, mutexes);
892     }
893 
gc()894     void solver::gc() {
895         for (auto* e : m_solvers)
896             e->gc();
897     }
898 
pop_reinit()899     void solver::pop_reinit() {
900         finish_reinit();
901         for (auto* e : m_solvers)
902             e->pop_reinit();
903 
904 #if 0
905         for (enode* n : m_egraph.nodes()) {
906             if (n->bool_var() != sat::null_bool_var && s().is_free(n->bool_var()))
907                 std::cout << "has free " << n->bool_var() << "\n";
908         }
909 #endif
910     }
911 
validate()912     bool solver::validate() {
913         for (auto* e : m_solvers)
914             if (!e->validate())
915                 return false;
916         check_eqc_bool_assignment();
917         check_missing_bool_enode_propagation();
918         check_missing_eq_propagation();
919         m_egraph.invariant();
920         return true;
921     }
922 
init_use_list(sat::ext_use_list & ul)923     void solver::init_use_list(sat::ext_use_list& ul) {
924         for (auto* e : m_solvers)
925             e->init_use_list(ul);
926     }
927 
is_blocked(literal l,ext_constraint_idx idx)928     bool solver::is_blocked(literal l, ext_constraint_idx idx) {
929         auto* ext = sat::constraint_base::to_extension(idx);
930         if (ext != this)
931             return ext->is_blocked(l, idx);
932         return false;
933     }
934 
check_model(sat::model const & m) const935     bool solver::check_model(sat::model const& m) const {
936         for (auto* e : m_solvers)
937             if (!e->check_model(m))
938                 return false;
939         return true;
940     }
941 
gc_vars(unsigned num_vars)942     void solver::gc_vars(unsigned num_vars) {
943         for (auto* e : m_solvers)
944             e->gc_vars(num_vars);
945     }
946 
get_reward(literal l,ext_constraint_idx idx,sat::literal_occs_fun & occs) const947     double solver::get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const {
948         auto* ext = sat::constraint_base::to_extension(idx);
949         SASSERT(ext);
950         return (ext == this) ? 0 : ext->get_reward(l, idx, occs);
951     }
952 
is_extended_binary(ext_justification_idx idx,literal_vector & r)953     bool solver::is_extended_binary(ext_justification_idx idx, literal_vector& r) {
954         auto* ext = sat::constraint_base::to_extension(idx);
955         SASSERT(ext);
956         return (ext != this) && ext->is_extended_binary(idx, r);
957     }
958 
init_ackerman()959     void solver::init_ackerman() {
960         if (m_ackerman)
961             return;
962         if (m_config.m_dack == dyn_ack_strategy::DACK_DISABLED)
963             return;
964         m_ackerman = alloc(ackerman, *this, m);
965         std::function<void(expr*,expr*,expr*)> used_eq = [&](expr* a, expr* b, expr* lca) {
966             m_ackerman->used_eq_eh(a, b, lca);
967         };
968         std::function<void(app*,app*)> used_cc = [&](app* a, app* b) {
969             m_ackerman->used_cc_eh(a, b);
970         };
971         m_egraph.set_used_eq(used_eq);
972         m_egraph.set_used_cc(used_cc);
973     }
974 
to_formulas(std::function<expr_ref (sat::literal)> & l2e,expr_ref_vector & fmls)975     bool solver::to_formulas(std::function<expr_ref(sat::literal)>& l2e, expr_ref_vector& fmls) {
976         for (auto* th : m_solvers) {
977             if (!th->to_formulas(l2e, fmls))
978                 return false;
979         }
980         for (euf::enode* n : m_egraph.nodes()) {
981             if (!n->is_root())
982                 fmls.push_back(m.mk_eq(n->get_expr(), n->get_root()->get_expr()));
983         }
984         return true;
985     }
986 
extract_pb(std::function<void (unsigned sz,literal const * c,unsigned k)> & card,std::function<void (unsigned sz,literal const * c,unsigned const * coeffs,unsigned k)> & pb)987     bool solver::extract_pb(std::function<void(unsigned sz, literal const* c, unsigned k)>& card,
988                             std::function<void(unsigned sz, literal const* c, unsigned const* coeffs, unsigned k)>& pb) {
989         for (auto* e : m_solvers)
990             if (!e->extract_pb(card, pb))
991                 return false;
992         return true;
993     }
994 
user_propagate_init(void * ctx,::solver::push_eh_t & push_eh,::solver::pop_eh_t & pop_eh,::solver::fresh_eh_t & fresh_eh)995     void solver::user_propagate_init(
996         void* ctx,
997         ::solver::push_eh_t& push_eh,
998         ::solver::pop_eh_t& pop_eh,
999         ::solver::fresh_eh_t& fresh_eh) {
1000         m_user_propagator = alloc(user_solver::solver, *this);
1001         m_user_propagator->add(ctx, push_eh, pop_eh, fresh_eh);
1002         for (unsigned i = m_scopes.size(); i-- > 0; )
1003             m_user_propagator->push();
1004         m_solvers.push_back(m_user_propagator);
1005         m_id2solver.setx(m_user_propagator->get_id(), m_user_propagator, nullptr);
1006     }
1007 
watches_fixed(enode * n) const1008     bool solver::watches_fixed(enode* n) const {
1009         return m_user_propagator && m_user_propagator->has_fixed() && n->get_th_var(m_user_propagator->get_id()) != null_theory_var;
1010     }
1011 
assign_fixed(enode * n,expr * val,unsigned sz,literal const * explain)1012     void solver::assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain) {
1013         theory_var v = n->get_th_var(m_user_propagator->get_id());
1014         m_user_propagator->new_fixed_eh(v, val, sz, explain);
1015     }
1016 
1017 
1018 }
1019