1 /*++
2 Copyright (c) 2020 Microsoft Corporation
3 
4 Module Name:
5 
6     sat_th.h
7 
8 Abstract:
9 
10     Theory plugins
11 
12 Author:
13 
14     Nikolaj Bjorner (nbjorner) 2020-08-25
15 
16 --*/
17 #pragma once
18 
19 #include "util/top_sort.h"
20 #include "sat/smt/sat_smt.h"
21 #include "ast/euf/euf_egraph.h"
22 #include "model/model.h"
23 #include "smt/params/smt_params.h"
24 
25 namespace euf {
26 
27     class solver;
28 
29     class th_internalizer {
30     protected:
31         euf::enode_vector     m_args;
32         svector<sat::eframe>  m_stack;
33         bool                  m_is_redundant{ false };
34 
35         bool visit_rec(ast_manager& m, expr* e, bool sign, bool root, bool redundant);
36 
visit(expr * e)37         virtual bool visit(expr* e) { return false; }
visited(expr * e)38         virtual bool visited(expr* e) { return false; }
post_visit(expr * e,bool sign,bool root)39         virtual bool post_visit(expr* e, bool sign, bool root) { return false; }
40 
41     public:
~th_internalizer()42         virtual ~th_internalizer() {}
43 
44         virtual sat::literal internalize(expr* e, bool sign, bool root, bool redundant) = 0;
45 
46         virtual void internalize(expr* e, bool redundant) = 0;
47 
48 
49         /**
50            \brief Apply (interpreted) sort constraints on the given enode.
51         */
apply_sort_cnstr(enode * n,sort * s)52         virtual void apply_sort_cnstr(enode* n, sort* s) {}
53 
54         /**
55            \brief Record that an equality has been internalized.
56          */
eq_internalized(enode * n)57         virtual void eq_internalized(enode* n) {}
58 
59     };
60 
61     class th_decompile {
62     public:
~th_decompile()63         virtual ~th_decompile() {}
64 
to_formulas(std::function<expr_ref (sat::literal)> & lit2expr,expr_ref_vector & fmls)65         virtual bool to_formulas(std::function<expr_ref(sat::literal)>& lit2expr, expr_ref_vector& fmls) { return false; }
66     };
67 
68     class th_model_builder {
69     public:
70 
~th_model_builder()71         virtual ~th_model_builder() {}
72 
73         /**
74            \brief compute the value for enode \c n and store the value in \c values
75            for the root of the class of \c n.
76          */
add_value(euf::enode * n,model & mdl,expr_ref_vector & values)77         virtual void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) {}
78 
79         /**
80            \brief compute dependencies for node n
81          */
add_dep(euf::enode * n,top_sort<euf::enode> & dep)82         virtual bool add_dep(euf::enode* n, top_sort<euf::enode>& dep) { dep.insert(n, nullptr); return true; }
83 
84         /**
85            \brief should function be included in model.
86         */
include_func_interp(func_decl * f)87         virtual bool include_func_interp(func_decl* f) const { return false; }
88 
89         /**
90           \brief initialize model building
91         */
init_model()92         virtual void init_model() {}
93 
94         /**
95           \brief conclude model building
96         */
finalize_model(model & mdl)97         virtual void finalize_model(model& mdl) {}
98     };
99 
100     class th_solver : public sat::extension, public th_model_builder, public th_decompile, public th_internalizer {
101     protected:
102         ast_manager& m;
103     public:
th_solver(ast_manager & m,symbol const & name,euf::theory_id id)104         th_solver(ast_manager& m, symbol const& name, euf::theory_id id) : extension(name, id), m(m) {}
105 
106         virtual th_solver* clone(euf::solver& ctx) = 0;
107 
new_eq_eh(euf::th_eq const & eq)108         virtual void new_eq_eh(euf::th_eq const& eq) {}
109 
use_diseqs()110         virtual bool use_diseqs() const { return false; }
111 
new_diseq_eh(euf::th_eq const & eq)112         virtual void new_diseq_eh(euf::th_eq const& eq) {}
113 
114         /**
115            \brief Parametric theories (e.g. Arrays) should implement this method.
116         */
is_shared(theory_var v)117         virtual bool is_shared(theory_var v) const { return false; }
118 
status()119         sat::status status() const { return sat::status::th(m_is_redundant, get_id()); }
120 
121     };
122 
123     class th_euf_solver : public th_solver {
124     protected:
125         solver& ctx;
126         euf::enode_vector   m_var2enode;
127         unsigned_vector     m_var2enode_lim;
128         unsigned            m_num_scopes{ 0 };
129 
130         smt_params const& get_config() const;
131         sat::literal expr2literal(expr* e) const;
132         region& get_region();
133 
134 
135         sat::status mk_status();
136         bool add_unit(sat::literal lit);
137         bool add_units(sat::literal_vector const& lits);
add_clause(sat::literal lit)138         bool add_clause(sat::literal lit) { return add_unit(lit); }
139         bool add_clause(sat::literal a, sat::literal b);
140         bool add_clause(sat::literal a, sat::literal b, sat::literal c);
141         bool add_clause(sat::literal a, sat::literal b, sat::literal c, sat::literal d);
142         bool add_clause(sat::literal_vector const& lits);
143         void add_equiv(sat::literal a, sat::literal b);
144         void add_equiv_and(sat::literal a, sat::literal_vector const& bs);
145 
146 
147         bool is_true(sat::literal lit);
is_true(sat::literal a,sat::literal b)148         bool is_true(sat::literal a, sat::literal b) { return is_true(a) || is_true(b); }
is_true(sat::literal a,sat::literal b,sat::literal c)149         bool is_true(sat::literal a, sat::literal b, sat::literal c) { return is_true(a) || is_true(b, c); }
is_true(sat::literal a,sat::literal b,sat::literal c,sat::literal d)150         bool is_true(sat::literal a, sat::literal b, sat::literal c, sat::literal d) { return is_true(a) || is_true(b, c, c); }
151 
152         sat::literal eq_internalize(expr* a, expr* b);
eq_internalize(enode * a,enode * b)153         sat::literal eq_internalize(enode* a, enode* b) { return eq_internalize(a->get_expr(), b->get_expr()); }
154 
155         euf::enode* mk_enode(expr* e, bool suppress_args = false);
156         expr_ref mk_eq(expr* e1, expr* e2);
mk_var_eq(theory_var v1,theory_var v2)157         expr_ref mk_var_eq(theory_var v1, theory_var v2) { return mk_eq(var2expr(v1), var2expr(v2)); }
158 
159         void rewrite(expr_ref& a);
160 
161         virtual void push_core();
162         virtual void pop_core(unsigned n);
force_push()163         void force_push() {
164             CTRACE("euf_verbose", m_num_scopes > 0, tout << "push-core " << m_num_scopes << "\n";);
165             for (; m_num_scopes > 0; --m_num_scopes) push_core();
166         }
167 
168         friend class th_explain;
169 
170     public:
171         th_euf_solver(euf::solver& ctx, symbol const& name, euf::theory_id id);
~th_euf_solver()172         virtual ~th_euf_solver() {}
173         virtual theory_var mk_var(enode* n);
get_num_vars()174         unsigned get_num_vars() const { return m_var2enode.size(); }
175         euf::enode* e_internalize(expr* e);
176         enode* expr2enode(expr* e) const;
var2enode(theory_var v)177         enode* var2enode(theory_var v) const { return m_var2enode[v]; }
var2expr(theory_var v)178         expr* var2expr(theory_var v) const { return var2enode(v)->get_expr(); }
179         expr* bool_var2expr(sat::bool_var v) const;
180         expr_ref literal2expr(sat::literal lit) const;
bool_var2enode(sat::bool_var v)181         enode* bool_var2enode(sat::bool_var v) const { expr* e = bool_var2expr(v); return e ? expr2enode(e) : nullptr; }
182         sat::literal mk_literal(expr* e) const;
get_th_var(enode * n)183         theory_var get_th_var(enode* n) const { return n->get_th_var(get_id()); }
184         theory_var get_th_var(expr* e) const;
185         theory_var get_representative(theory_var v) const;
186         trail_stack& get_trail_stack();
187         bool is_attached_to_var(enode* n) const;
is_root(theory_var v)188         bool is_root(theory_var v) const { return var2enode(v)->is_root(); }
push()189         void push() override { m_num_scopes++; }
190         void pop(unsigned n) override;
191 
192         unsigned random();
193     };
194 
195     /**
196     * General purpose, eager explanation object. Explanations are conjunctions of literals and equalities.
197     * Used literals and equalities are stored in the object and retrieved on demand for conflict resolution
198     * It is "eager" in the sense that relevant literals are accumulated when the explanation is created.
199     * This is not a real problem for conflicts, but a theory has an option to implement custom lazy explanations
200     * that retrieve literals on demand.
201     */
202     class th_explain {
203         sat::literal   m_consequent { sat::null_literal }; // literal consequent for propagations
204         enode_pair     m_eq { enode_pair() };              // equality consequent for propagations
205         unsigned       m_num_literals;
206         unsigned       m_num_eqs;
207         sat::literal*  m_literals;
208         enode_pair*    m_eqs;
209         static size_t get_obj_size(unsigned num_lits, unsigned num_eqs);
210         th_explain(unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode_pair const& eq);
211         static th_explain* mk(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs, sat::literal c, enode* x, enode* y);
212 
213     public:
214         static th_explain* conflict(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs);
conflict(th_euf_solver & th,sat::literal_vector const & lits)215         static th_explain* conflict(th_euf_solver& th, sat::literal_vector const& lits) { return conflict(th, lits.size(), lits.data(), 0, nullptr); }
216         static th_explain* conflict(th_euf_solver& th, unsigned n_lits, sat::literal const* lits, unsigned n_eqs, enode_pair const* eqs);
217         static th_explain* conflict(th_euf_solver& th, enode_pair_vector const& eqs);
218         static th_explain* conflict(th_euf_solver& th, sat::literal lit);
219         static th_explain* conflict(th_euf_solver& th, sat::literal lit, euf::enode* x, euf::enode* y);
220         static th_explain* conflict(th_euf_solver& th, euf::enode* x, euf::enode* y);
221         static th_explain* propagate(th_euf_solver& th, sat::literal lit, euf::enode* x, euf::enode* y);
222         static th_explain* propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, sat::literal consequent);
223         static th_explain* propagate(th_euf_solver& th, sat::literal_vector const& lits, enode_pair_vector const& eqs, euf::enode* x, euf::enode* y);
224 
to_index()225         sat::ext_constraint_idx to_index() const {
226             return sat::constraint_base::mem2base(this);
227         }
from_index(size_t idx)228         static th_explain& from_index(size_t idx) {
229             return *reinterpret_cast<th_explain*>(sat::constraint_base::from_index(idx)->mem());
230         }
231 
ext()232         sat::extension& ext() const {
233             return *sat::constraint_base::to_extension(to_index());
234         }
235 
236         std::ostream& display(std::ostream& out) const;
237 
238         class lits {
239             th_explain const& th;
240         public:
lits(th_explain const & th)241             lits(th_explain const& th) : th(th) {}
begin()242             sat::literal const* begin() const { return th.m_literals; }
end()243             sat::literal const* end() const { return th.m_literals + th.m_num_literals; }
244         };
245 
246         class eqs {
247             th_explain const& th;
248         public:
eqs(th_explain const & th)249             eqs(th_explain const& th) : th(th) {}
begin()250             enode_pair const* begin() const { return th.m_eqs; }
end()251             enode_pair const* end() const { return th.m_eqs + th.m_num_eqs; }
252         };
253 
lit_consequent()254         sat::literal lit_consequent() const { return m_consequent; }
255 
eq_consequent()256         enode_pair eq_consequent() const { return m_eq; }
257 
258     };
259 
260 
261 }
262