1 /*++
2 Copyright (c) 2020 Microsoft Corporation
3 
4 Module Name:
5 
6     sat_th.cpp
7 
8 Abstract:
9 
10     Theory plugin base classes
11 
12 Author:
13 
14     Nikolaj Bjorner (nbjorner) 2020-08-25
15 
16 --*/
17 
18 #include "sat/smt/sat_th.h"
19 #include "sat/smt/euf_solver.h"
20 #include "tactic/tactic_exception.h"
21 
22 namespace euf {
23 
visit_rec(ast_manager & m,expr * a,bool sign,bool root,bool redundant)24     bool th_internalizer::visit_rec(ast_manager& m, expr* a, bool sign, bool root, bool redundant) {
25         IF_VERBOSE(110, verbose_stream() << "internalize: " << mk_pp(a, m) << "\n");
26         flet<bool> _is_learned(m_is_redundant, redundant);
27         svector<sat::eframe>::scoped_stack _sc(m_stack);
28         unsigned sz = m_stack.size();
29         visit(a);
30         while (m_stack.size() > sz) {
31         loop:
32             if (!m.inc())
33                 throw tactic_exception(m.limit().get_cancel_msg());
34             unsigned fsz = m_stack.size();
35             expr* e = m_stack[fsz-1].m_e;
36             if (visited(e)) {
37                 m_stack.pop_back();
38                 continue;
39             }
40             unsigned num = is_app(e) ? to_app(e)->get_num_args() : 0;
41 
42             while (m_stack[fsz - 1].m_idx < num) {
43                 expr* arg = to_app(e)->get_arg(m_stack[fsz - 1].m_idx);
44                 m_stack[fsz - 1].m_idx++;
45                 if (!visit(arg))
46                     goto loop;
47             }
48             if (!visited(e) && !post_visit(e, sign, root && a == e))
49                 return false;
50             m_stack.pop_back();
51         }
52         return true;
53     }
54 
th_euf_solver(euf::solver & ctx,symbol const & name,euf::theory_id id)55     th_euf_solver::th_euf_solver(euf::solver& ctx, symbol const& name, euf::theory_id id):
56         th_solver(ctx.get_manager(), name, id),
57         ctx(ctx)
58     {}
59 
get_config() const60     smt_params const& th_euf_solver::get_config() const {
61         return ctx.get_config();
62     }
63 
get_region()64     region& th_euf_solver::get_region() {
65         return ctx.get_region();
66     }
67 
get_trail_stack()68     trail_stack& th_euf_solver::get_trail_stack() {
69         return ctx.get_trail_stack();
70     }
71 
expr2enode(expr * e) const72     enode* th_euf_solver::expr2enode(expr* e) const {
73         return ctx.get_enode(e);
74     }
75 
expr2literal(expr * e) const76     sat::literal th_euf_solver::expr2literal(expr* e) const {
77         return ctx.expr2literal(e);
78     }
79 
bool_var2expr(sat::bool_var v) const80     expr* th_euf_solver::bool_var2expr(sat::bool_var v) const {
81         return ctx.bool_var2expr(v);
82     }
83 
literal2expr(sat::literal lit) const84     expr_ref th_euf_solver::literal2expr(sat::literal lit) const {
85         return ctx.literal2expr(lit);
86     }
87 
mk_var(enode * n)88     theory_var th_euf_solver::mk_var(enode * n) {
89         force_push();
90         SASSERT(!is_attached_to_var(n));
91         euf::theory_var v = m_var2enode.size();
92         m_var2enode.push_back(n);
93         return v;
94     }
95 
is_attached_to_var(enode * n) const96     bool th_euf_solver::is_attached_to_var(enode* n) const {
97         theory_var v = n->get_th_var(get_id());
98         return v != null_theory_var && var2enode(v) == n;
99     }
100 
get_th_var(expr * e) const101     theory_var th_euf_solver::get_th_var(expr* e) const {
102         return get_th_var(ctx.get_enode(e));
103     }
104 
get_representative(theory_var v) const105     theory_var th_euf_solver::get_representative(theory_var v) const {
106         euf::enode* r = var2enode(v)->get_root();
107         return get_th_var(r);
108     }
109 
push_core()110     void th_euf_solver::push_core() {
111         m_var2enode_lim.push_back(m_var2enode.size());
112     }
113 
pop_core(unsigned num_scopes)114     void th_euf_solver::pop_core(unsigned num_scopes) {
115         unsigned new_lvl = m_var2enode_lim.size() - num_scopes;
116         m_var2enode.shrink(m_var2enode_lim[new_lvl]);
117         m_var2enode_lim.shrink(new_lvl);
118     }
119 
pop(unsigned n)120     void th_euf_solver::pop(unsigned n) {
121         unsigned k = std::min(m_num_scopes, n);
122         m_num_scopes -= k;
123         n -= k;
124         if (n > 0)
125             pop_core(n);
126     }
127 
mk_status()128     sat::status th_euf_solver::mk_status() {
129         return sat::status::th(m_is_redundant, get_id());
130     }
131 
add_unit(sat::literal lit)132     bool th_euf_solver::add_unit(sat::literal lit) {
133         bool was_true = is_true(lit);
134         ctx.s().add_clause(1, &lit, mk_status());
135         return !was_true;
136     }
137 
add_units(sat::literal_vector const & lits)138     bool th_euf_solver::add_units(sat::literal_vector const& lits) {
139         bool is_new = false;
140         for (auto lit : lits)
141             if (add_unit(lit))
142                 is_new = true;
143         return is_new;
144     }
145 
add_clause(sat::literal a,sat::literal b)146     bool th_euf_solver::add_clause(sat::literal a, sat::literal b) {
147         bool was_true = is_true(a, b);
148         sat::literal lits[2] = { a, b };
149         ctx.s().add_clause(2, lits, mk_status());
150         return !was_true;
151     }
152 
add_clause(sat::literal a,sat::literal b,sat::literal c)153     bool th_euf_solver::add_clause(sat::literal a, sat::literal b, sat::literal c) {
154         bool was_true = is_true(a, b, c);
155         sat::literal lits[3] = { a, b, c };
156         ctx.s().add_clause(3, lits, mk_status());
157         return !was_true;
158     }
159 
add_clause(sat::literal a,sat::literal b,sat::literal c,sat::literal d)160     bool th_euf_solver::add_clause(sat::literal a, sat::literal b, sat::literal c, sat::literal d) {
161         bool was_true = is_true(a, b, c, d);
162         sat::literal lits[4] = { a, b, c, d };
163         ctx.s().add_clause(4, lits, mk_status());
164         return !was_true;
165     }
166 
add_clause(sat::literal_vector const & lits)167     bool th_euf_solver::add_clause(sat::literal_vector const& lits) {
168         bool was_true = false;
169         for (auto lit : lits)
170             was_true |= is_true(lit);
171         s().add_clause(lits.size(), lits.data(), mk_status());
172         return !was_true;
173     }
174 
add_equiv(sat::literal a,sat::literal b)175     void th_euf_solver::add_equiv(sat::literal a, sat::literal b) {
176         add_clause(~a, b);
177         add_clause(a, ~b);
178     }
179 
add_equiv_and(sat::literal a,sat::literal_vector const & bs)180     void th_euf_solver::add_equiv_and(sat::literal a, sat::literal_vector const& bs) {
181         for (auto b : bs)
182             add_clause(~a, b);
183         sat::literal_vector _bs;
184         for (auto b : bs)
185             _bs.push_back(~b);
186         _bs.push_back(a);
187         add_clause(_bs);
188     }
189 
is_true(sat::literal lit)190     bool th_euf_solver::is_true(sat::literal lit) {
191         return ctx.s().value(lit) == l_true;
192     }
193 
mk_enode(expr * e,bool suppress_args)194     euf::enode* th_euf_solver::mk_enode(expr* e, bool suppress_args) {
195         m_args.reset();
196         if (!suppress_args)
197             for (expr* arg : *to_app(e))
198                 m_args.push_back(expr2enode(arg));
199         euf::enode* n = ctx.mk_enode(e, m_args.size(), m_args.data());
200         ctx.attach_node(n);
201         return n;
202     }
203 
rewrite(expr_ref & a)204     void th_euf_solver::rewrite(expr_ref& a) {
205         ctx.get_rewriter()(a);
206     }
207 
mk_eq(expr * e1,expr * e2)208     expr_ref th_euf_solver::mk_eq(expr* e1, expr* e2) {
209         return ctx.mk_eq(e1, e2);
210     }
211 
mk_literal(expr * e) const212     sat::literal th_euf_solver::mk_literal(expr* e) const {
213         return ctx.mk_literal(e);
214     }
215 
eq_internalize(expr * a,expr * b)216     sat::literal th_euf_solver::eq_internalize(expr* a, expr* b) {
217         return mk_literal(ctx.mk_eq(a, b));
218     }
219 
e_internalize(expr * e)220     euf::enode* th_euf_solver::e_internalize(expr* e) {
221         return ctx.e_internalize(e);
222     }
223 
random()224     unsigned th_euf_solver::random() {
225         return ctx.s().rand()();
226     }
227 
get_obj_size(unsigned num_lits,unsigned num_eqs)228     size_t th_explain::get_obj_size(unsigned num_lits, unsigned num_eqs) {
229         return sat::constraint_base::obj_size(sizeof(th_explain) + sizeof(sat::literal) * num_lits + sizeof(enode_pair) * num_eqs);
230     }
231 
th_explain(unsigned n_lits,sat::literal const * lits,unsigned n_eqs,enode_pair const * eqs,sat::literal c,enode_pair const & p)232     th_explain::th_explain(unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode_pair const& p) {
233         m_consequent = c;
234         m_eq = p;
235         m_num_literals = n_lits;
236         m_num_eqs = n_eqs;
237         m_literals = reinterpret_cast<literal*>(reinterpret_cast<char*>(this) + sizeof(th_explain));
238         for (unsigned i = 0; i < n_lits; ++i)
239             m_literals[i] = lits[i];
240         m_eqs = reinterpret_cast<enode_pair*>(reinterpret_cast<char*>(this) + sizeof(th_explain) + sizeof(literal) * n_lits);
241         for (unsigned i = 0; i < n_eqs; ++i)
242             m_eqs[i] = eqs[i];
243 
244     }
245 
mk(th_euf_solver & th,unsigned n_lits,sat::literal const * lits,unsigned n_eqs,enode_pair const * eqs,sat::literal c,enode * x,enode * y)246     th_explain* th_explain::mk(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode* x, enode* y) {
247         region& r = th.ctx.get_region();
248         void* mem = r.allocate(get_obj_size(n_lits, n_eqs));
249         sat::constraint_base::initialize(mem, &th);
250         return new (sat::constraint_base::ptr2mem(mem)) th_explain(n_lits, lits, n_eqs, eqs, c, enode_pair(x, y));
251     }
252 
propagate(th_euf_solver & th,sat::literal_vector const & lits,enode_pair_vector const & eqs,sat::literal consequent)253     th_explain* th_explain::propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, sat::literal consequent) {
254         return mk(th, lits.size(), lits.data(), eqs.size(), eqs.data(), consequent, nullptr, nullptr);
255     }
256 
propagate(th_euf_solver & th,sat::literal_vector const & lits,enode_pair_vector const & eqs,euf::enode * x,euf::enode * y)257     th_explain* th_explain::propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y) {
258         return mk(th, lits.size(), lits.data(), eqs.size(), eqs.data(), sat::null_literal, x, y);
259     }
260 
propagate(th_euf_solver & th,sat::literal lit,euf::enode * x,euf::enode * y)261     th_explain* th_explain::propagate(th_euf_solver& th, sat::literal lit, euf::enode* x, euf::enode* y) {
262         return mk(th, 1, &lit, 0, nullptr, sat::null_literal, x, y);
263     }
264 
conflict(th_euf_solver & th,sat::literal_vector const & lits,enode_pair_vector const & eqs)265     th_explain* th_explain::conflict(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs) {
266         return conflict(th, lits.size(), lits.data(), eqs.size(), eqs.data());
267     }
268 
conflict(th_euf_solver & th,unsigned n_lits,sat::literal const * lits,unsigned n_eqs,enode_pair const * eqs)269     th_explain* th_explain::conflict(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs) {
270         return mk(th, n_lits, lits, n_eqs, eqs, sat::null_literal, nullptr, nullptr);
271     }
272 
conflict(th_euf_solver & th,enode_pair_vector const & eqs)273     th_explain* th_explain::conflict(th_euf_solver& th, enode_pair_vector const& eqs) {
274         return conflict(th, 0, nullptr, eqs.size(), eqs.data());
275     }
276 
conflict(th_euf_solver & th,sat::literal lit)277     th_explain* th_explain::conflict(th_euf_solver& th, sat::literal lit) {
278         return conflict(th, 1, &lit, 0, nullptr);
279     }
280 
conflict(th_euf_solver & th,sat::literal lit,euf::enode * x,euf::enode * y)281     th_explain* th_explain::conflict(th_euf_solver& th, sat::literal lit, euf::enode* x, euf::enode* y) {
282         enode_pair eq(x, y);
283         return conflict(th, 1, &lit, 1, &eq);
284     }
285 
conflict(th_euf_solver & th,euf::enode * x,euf::enode * y)286     th_explain* th_explain::conflict(th_euf_solver& th, euf::enode* x, euf::enode* y) {
287         enode_pair eq(x, y);
288         return conflict(th, 0, nullptr, 1, &eq);
289     }
290 
display(std::ostream & out) const291     std::ostream& th_explain::display(std::ostream& out) const {
292         for (auto lit : euf::th_explain::lits(*this))
293             out << lit << " ";
294         for (auto eq : euf::th_explain::eqs(*this))
295             out << eq.first->get_expr_id() << " == " << eq.second->get_expr_id() << " ";
296         if (m_consequent != sat::null_literal)
297             out << "--> " << m_consequent;
298         if (m_eq.first != nullptr)
299             out << "--> " << m_eq.first->get_expr_id() << " == " << m_eq.second->get_expr_id();
300         return out;
301     }
302 
303 
304 }
305