1 /*++
2 Copyright (c) 2020 Microsoft Corporation
3 
4 Module Name:
5 
6     euf_solver.h
7 
8 Abstract:
9 
10     Solver plugin for EUF
11 
12 Author:
13 
14     Nikolaj Bjorner (nbjorner) 2020-08-25
15 
16 --*/
17 #pragma once
18 
19 #include "util/scoped_ptr_vector.h"
20 #include "util/trail.h"
21 #include "ast/ast_translation.h"
22 #include "ast/euf/euf_egraph.h"
23 #include "ast/rewriter/th_rewriter.h"
24 #include "tactic/model_converter.h"
25 #include "sat/sat_extension.h"
26 #include "sat/smt/atom2bool_var.h"
27 #include "sat/smt/sat_th.h"
28 #include "sat/smt/sat_dual_solver.h"
29 #include "sat/smt/euf_ackerman.h"
30 #include "sat/smt/user_solver.h"
31 #include "smt/params/smt_params.h"
32 
33 namespace euf {
34     typedef sat::literal literal;
35     typedef sat::ext_constraint_idx ext_constraint_idx;
36     typedef sat::ext_justification_idx ext_justification_idx;
37     typedef sat::literal_vector literal_vector;
38     typedef sat::bool_var bool_var;
39 
40     class constraint {
41     public:
42         enum class kind_t { conflict, eq, lit };
43     private:
44         kind_t m_kind;
45     public:
constraint(kind_t k)46         constraint(kind_t k) : m_kind(k) {}
kind()47         kind_t kind() const { return m_kind; }
from_idx(size_t z)48         static constraint& from_idx(size_t z) {
49             return *reinterpret_cast<constraint*>(sat::constraint_base::idx2mem(z));
50         }
to_index()51         size_t to_index() const { return sat::constraint_base::mem2base(this); }
52     };
53 
54     class solver : public sat::extension, public th_internalizer, public th_decompile {
55         typedef top_sort<euf::enode> deps_t;
56         friend class ackerman;
57         class user_sort;
58         // friend class sat::ba_solver;
59         struct stats {
60             unsigned m_ackerman;
statsstats61             stats() { reset(); }
resetstats62             void reset() { memset(this, 0, sizeof(*this)); }
63         };
64         struct scope {
65             unsigned m_var_lim;
66         };
67         typedef trail_stack<solver> euf_trail_stack;
68 
69 
to_ptr(sat::literal l)70         size_t* to_ptr(sat::literal l) { return TAG(size_t*, reinterpret_cast<size_t*>((size_t)(l.index() << 4)), 1); }
to_ptr(size_t jst)71         size_t* to_ptr(size_t jst) { return TAG(size_t*, reinterpret_cast<size_t*>(jst), 2); }
is_literal(size_t * p)72         bool is_literal(size_t* p) const { return GET_TAG(p) == 1; }
is_justification(size_t * p)73         bool is_justification(size_t* p) const { return GET_TAG(p) == 2; }
get_literal(size_t * p)74         sat::literal get_literal(size_t* p) const {
75             unsigned idx = static_cast<unsigned>(reinterpret_cast<size_t>(UNTAG(size_t*, p)));
76             return sat::to_literal(idx >> 4);
77         }
get_justification(size_t * p)78         size_t get_justification(size_t* p) const {
79             return reinterpret_cast<size_t>(UNTAG(size_t*, p));
80         }
81 
82         std::function<::solver*(void)>   m_mk_solver;
83         ast_manager&                     m;
84         sat::sat_internalizer& si;
85         smt_params             m_config;
86         euf::egraph            m_egraph;
87         euf_trail_stack        m_trail;
88         stats                  m_stats;
89         th_rewriter            m_rewriter;
90         func_decl_ref_vector   m_unhandled_functions;
91         sat::lookahead*        m_lookahead{ nullptr };
92         ast_manager*           m_to_m;
93         sat::sat_internalizer* m_to_si;
94         scoped_ptr<euf::ackerman>    m_ackerman;
95         scoped_ptr<sat::dual_solver> m_dual_solver;
96         user::solver*          m_user_propagator{ nullptr };
97         th_solver*             m_qsolver { nullptr };
98         unsigned               m_generation { 0 };
99 
100         ptr_vector<expr>                                m_bool_var2expr;
101         ptr_vector<size_t>                              m_explain;
102         unsigned                                        m_num_scopes{ 0 };
103         unsigned_vector                                 m_var_trail;
104         svector<scope>                                  m_scopes;
105         scoped_ptr_vector<th_solver>                    m_solvers;
106         ptr_vector<th_solver>                           m_id2solver;
107 
108         constraint* m_conflict{ nullptr };
109         constraint* m_eq{ nullptr };
110         constraint* m_lit{ nullptr };
111 
112         // internalization
113         bool visit(expr* e) override;
114         bool visited(expr* e) override;
115         bool post_visit(expr* e, bool sign, bool root) override;
116 
117         void add_distinct_axiom(app* e, euf::enode* const* args);
118         void add_not_distinct_axiom(app* e, euf::enode* const* args);
119         void axiomatize_basic(enode* n);
120         bool internalize_root(app* e, bool sign, ptr_vector<enode> const& args);
121         euf::enode* mk_true();
122         euf::enode* mk_false();
123 
124         // replay
125         typedef std::tuple<expr_ref, unsigned, sat::bool_var> reinit_t;
126         vector<reinit_t>    m_reinit;
127 
128         void start_reinit(unsigned num_scopes);
129         void finish_reinit();
130 
131         // extensions
132         th_solver* get_solver(family_id fid, func_decl* f);
sort2solver(sort * s)133         th_solver* sort2solver(sort* s) { return get_solver(s->get_family_id(), nullptr); }
func_decl2solver(func_decl * f)134         th_solver* func_decl2solver(func_decl* f) { return get_solver(f->get_family_id(), f); }
135         th_solver* quantifier2solver();
136         th_solver* expr2solver(expr* e);
137         th_solver* bool_var2solver(sat::bool_var v);
138         void add_solver(th_solver* th);
139         void init_ackerman();
140 
141         // model building
142         expr_ref_vector m_values;
143         obj_map<expr, enode*> m_values2root;
144         bool include_func_interp(func_decl* f);
145         void register_macros(model& mdl);
146         void dependencies2values(user_sort& us, deps_t& deps, model_ref& mdl);
147         void collect_dependencies(user_sort& us, deps_t& deps);
148         void values2model(deps_t const& deps, model_ref& mdl);
149         void validate_model(model& mdl);
150 
151         // solving
152         void propagate_literals();
153         void propagate_th_eqs();
154         bool is_self_propagated(th_eq const& e);
155         void get_antecedents(literal l, constraint& j, literal_vector& r, bool probing);
156         void new_diseq(enode* a, enode* b, literal lit);
157 
158         // proofs
159         void log_antecedents(std::ostream& out, literal l, literal_vector const& r);
160         void log_antecedents(literal l, literal_vector const& r);
161         void log_justification(literal l, th_propagation const& jst);
162         void drat_log_decl(func_decl* f);
163         void drat_log_expr(expr* n);
164         void drat_log_expr1(expr* n);
165         ptr_vector<expr> m_drat_todo;
166         obj_hashtable<ast> m_drat_asts;
167         bool m_drat_initialized{ false };
168         void init_drat();
169 
170         // relevancy
171         bool_vector m_relevant_expr_ids;
172         void ensure_dual_solver();
173         bool init_relevancy();
174 
175 
176         // invariant
177         void check_eqc_bool_assignment() const;
178         void check_missing_bool_enode_propagation() const;
179         void check_missing_eq_propagation() const;
180 
181         // diagnosis
182         std::ostream& display_justification_ptr(std::ostream& out, size_t* j) const;
183 
184         // constraints
185         constraint& mk_constraint(constraint*& c, constraint::kind_t k);
conflict_constraint()186         constraint& conflict_constraint() { return mk_constraint(m_conflict, constraint::kind_t::conflict); }
eq_constraint()187         constraint& eq_constraint() { return mk_constraint(m_eq, constraint::kind_t::eq); }
lit_constraint()188         constraint& lit_constraint() { return mk_constraint(m_lit, constraint::kind_t::lit); }
189 
190         // user propagator
check_for_user_propagator()191         void check_for_user_propagator() {
192             if (!m_user_propagator)
193                 throw default_exception("user propagator must be initialized");
194         }
195 
196     public:
197         solver(ast_manager& m, sat::sat_internalizer& si, params_ref const& p = params_ref());
198 
~solver()199         ~solver() override {
200             if (m_conflict) dealloc(sat::constraint_base::mem2base_ptr(m_conflict));
201             if (m_eq) dealloc(sat::constraint_base::mem2base_ptr(m_eq));
202             if (m_lit) dealloc(sat::constraint_base::mem2base_ptr(m_lit));
203             m_trail.reset();
204         }
205 
206         struct scoped_set_translate {
207             solver& s;
scoped_set_translatescoped_set_translate208             scoped_set_translate(solver& s, ast_manager& m, sat::sat_internalizer& si) :
209                 s(s) {
210                 s.m_to_m = &m;
211                 s.m_to_si = &si;
212             }
~scoped_set_translatescoped_set_translate213             ~scoped_set_translate() {
214                 s.m_to_m = &s.m;
215                 s.m_to_si = &s.si;
216             }
217         };
218 
219         struct scoped_generation {
220             solver& s;
221             unsigned m_g;
scoped_generationscoped_generation222             scoped_generation(solver& s, unsigned g):
223                 s(s),
224                 m_g(s.m_generation) {
225                 s.m_generation = g;
226             }
~scoped_generationscoped_generation227             ~scoped_generation() {
228                 s.m_generation = m_g;
229             }
230         };
231 
232         // accessors
233 
get_si()234         sat::sat_internalizer& get_si() { return si; }
get_manager()235         ast_manager& get_manager() { return m; }
get_enode(expr * e)236         enode* get_enode(expr* e) const { return m_egraph.find(e); }
expr2literal(expr * e)237         sat::literal expr2literal(expr* e) const { return enode2literal(get_enode(e)); }
enode2literal(enode * n)238         sat::literal enode2literal(enode* n) const { return sat::literal(n->bool_var(), false); }
value(enode * n)239         lbool value(enode* n) const { return s().value(enode2literal(n)); }
get_config()240         smt_params const& get_config() const { return m_config; }
get_region()241         region& get_region() { return m_trail.get_region(); }
get_egraph()242         egraph& get_egraph() { return m_egraph; }
fid2solver(family_id fid)243         th_solver* fid2solver(family_id fid) const { return m_id2solver.get(fid, nullptr); }
244 
245         template <typename C>
push(C const & c)246         void push(C const& c) { m_trail.push(c); }
247         template <typename V>
push_vec(ptr_vector<V> & vec,V * val)248         void push_vec(ptr_vector<V>& vec, V* val) {
249             vec.push_back(val);
250             push(push_back_trail<solver, V*, false>(vec));
251         }
252         template <typename V>
push_vec(svector<V> & vec,V val)253         void push_vec(svector<V>& vec, V val) {
254             vec.push_back(val);
255             push(push_back_trail<solver, V, false>(vec));
256         }
get_trail_stack()257         euf_trail_stack& get_trail_stack() { return m_trail; }
258 
259         void updt_params(params_ref const& p);
set_lookahead(sat::lookahead * s)260         void set_lookahead(sat::lookahead* s) override { m_lookahead = s; }
261         void init_search() override;
262         double get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const override;
263         bool is_extended_binary(ext_justification_idx idx, literal_vector& r) override;
264         bool is_external(bool_var v) override;
265         bool propagated(literal l, ext_constraint_idx idx) override;
266         bool unit_propagate() override;
267 
268         void propagate(literal lit, ext_justification_idx idx);
269         bool propagate(enode* a, enode* b, ext_justification_idx idx);
270         void set_conflict(ext_justification_idx idx);
271 
propagate(literal lit,th_propagation * p)272         void propagate(literal lit, th_propagation* p) { propagate(lit, p->to_index()); }
propagate(enode * a,enode * b,th_propagation * p)273         bool propagate(enode* a, enode* b, th_propagation* p) { return propagate(a, b, p->to_index()); }
set_conflict(th_propagation * p)274         void set_conflict(th_propagation* p) { set_conflict(p->to_index()); }
275 
276         bool set_root(literal l, literal r) override;
277         void flush_roots() override;
278 
279         void get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) override;
280         void get_antecedents(literal l, th_propagation& jst, literal_vector& r, bool probing);
281         void add_antecedent(enode* a, enode* b);
282         void asserted(literal l) override;
283         sat::check_result check() override;
284         void push() override;
285         void pop(unsigned n) override;
286         void user_push() override;
287         void user_pop(unsigned n) override;
288         void pre_simplify() override;
289         void simplify() override;
290         // have a way to replace l by r in all constraints
291         void clauses_modifed() override;
292         lbool get_phase(bool_var v) override;
293         std::ostream& display(std::ostream& out) const override;
294         std::ostream& display_justification(std::ostream& out, ext_justification_idx idx) const override;
295         std::ostream& display_constraint(std::ostream& out, ext_constraint_idx idx) const override;
bpp(enode * n)296         euf::egraph::b_pp bpp(enode* n) { return m_egraph.bpp(n); }
297         void collect_statistics(statistics& st) const override;
298         extension* copy(sat::solver* s) override;
299         enode* copy(solver& dst_ctx, enode* src_n);
300         void find_mutexes(literal_vector& lits, vector<literal_vector>& mutexes) override;
301         void gc() override;
302         void pop_reinit() override;
303         bool validate() override;
304         void init_use_list(sat::ext_use_list& ul) override;
305         bool is_blocked(literal l, ext_constraint_idx) override;
306         bool check_model(sat::model const& m) const override;
307         void gc_vars(unsigned num_vars) override;
308 
309         // proof
use_drat()310         bool use_drat() { return s().get_config().m_drat && (init_drat(), true); }
get_drat()311         sat::drat& get_drat() { return s().get_drat(); }
312         void drat_bool_def(sat::bool_var v, expr* n);
313         void drat_eq_def(sat::literal lit, expr* eq);
314 
315         // decompile
316         bool extract_pb(std::function<void(unsigned sz, literal const* c, unsigned k)>& card,
317             std::function<void(unsigned sz, literal const* c, unsigned const* coeffs, unsigned k)>& pb) override;
318 
319         bool to_formulas(std::function<expr_ref(sat::literal)>& l2e, expr_ref_vector& fmls) override;
320 
321         // internalize
322         sat::literal internalize(expr* e, bool sign, bool root, bool learned) override;
323         void internalize(expr* e, bool learned) override;
324         sat::literal mk_literal(expr* e);
attach_th_var(enode * n,th_solver * th,theory_var v)325         void attach_th_var(enode* n, th_solver* th, theory_var v) { m_egraph.add_th_var(n, v, th->get_id()); }
326         void attach_node(euf::enode* n);
327         expr_ref mk_eq(expr* e1, expr* e2);
mk_eq(euf::enode * n1,euf::enode * n2)328         expr_ref mk_eq(euf::enode* n1, euf::enode* n2) { return mk_eq(n1->get_expr(), n2->get_expr()); }
mk_enode(expr * e,unsigned n,enode * const * args)329         euf::enode* mk_enode(expr* e, unsigned n, enode* const* args) { return m_egraph.mk(e, m_generation, n, args); }
bool_var2expr(sat::bool_var v)330         expr* bool_var2expr(sat::bool_var v) const { return m_bool_var2expr.get(v, nullptr); }
literal2expr(sat::literal lit)331         expr_ref literal2expr(sat::literal lit) const { expr* e = bool_var2expr(lit.var()); return lit.sign() ? expr_ref(m.mk_not(e), m) : expr_ref(e, m); }
332 
333         sat::literal attach_lit(sat::literal lit, expr* e);
334         void unhandled_function(func_decl* f);
get_rewriter()335         th_rewriter& get_rewriter() { return m_rewriter; }
336         bool is_shared(euf::enode* n) const;
337 
338         // relevancy
relevancy_enabled()339         bool relevancy_enabled() const { return get_config().m_relevancy_lvl > 0; }
340         void add_root(unsigned n, sat::literal const* lits);
341         void add_aux(unsigned n, sat::literal const* lits);
342         void track_relevancy(sat::bool_var v);
343         bool is_relevant(expr* e) const;
344         bool is_relevant(enode* n) const;
345 
346         // model construction
347         void update_model(model_ref& mdl);
348         obj_map<expr, enode*> const& values2root();
349         expr* node2value(enode* n) const;
350 
351         // diagnostics
unhandled_functions()352         func_decl_ref_vector const& unhandled_functions() { return m_unhandled_functions; }
353 
354         // user propagator
355         void user_propagate_init(
356             void* ctx,
357             ::solver::push_eh_t& push_eh,
358             ::solver::pop_eh_t& pop_eh,
359             ::solver::fresh_eh_t& fresh_eh);
360         bool watches_fixed(enode* n) const;
361         void assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain);
assign_fixed(enode * n,expr * val,literal_vector const & explain)362         void assign_fixed(enode* n, expr* val, literal_vector const& explain) { assign_fixed(n, val, explain.size(), explain.c_ptr()); }
assign_fixed(enode * n,expr * val,literal explain)363         void assign_fixed(enode* n, expr* val, literal explain) { assign_fixed(n, val, 1, &explain); }
364 
user_propagate_register_final(::solver::final_eh_t & final_eh)365         void user_propagate_register_final(::solver::final_eh_t& final_eh) {
366             check_for_user_propagator();
367             m_user_propagator->register_final(final_eh);
368         }
user_propagate_register_fixed(::solver::fixed_eh_t & fixed_eh)369         void user_propagate_register_fixed(::solver::fixed_eh_t& fixed_eh) {
370             check_for_user_propagator();
371             m_user_propagator->register_fixed(fixed_eh);
372         }
user_propagate_register_eq(::solver::eq_eh_t & eq_eh)373         void user_propagate_register_eq(::solver::eq_eh_t& eq_eh) {
374             check_for_user_propagator();
375             m_user_propagator->register_eq(eq_eh);
376         }
user_propagate_register_diseq(::solver::eq_eh_t & diseq_eh)377         void user_propagate_register_diseq(::solver::eq_eh_t& diseq_eh) {
378             check_for_user_propagator();
379             m_user_propagator->register_diseq(diseq_eh);
380         }
user_propagate_register(expr * e)381         unsigned user_propagate_register(expr* e) {
382             check_for_user_propagator();
383             return m_user_propagator->add_expr(e);
384         }
385 
386         // solver factory
mk_solver()387         ::solver* mk_solver() { return m_mk_solver(); }
set_mk_solver(std::function<::solver * (void)> & mk)388         void set_mk_solver(std::function<::solver*(void)>& mk) { m_mk_solver = mk; }
389 
390 
391     };
392 };
393 
394 inline std::ostream& operator<<(std::ostream& out, euf::solver const& s) {
395     return s.display(out);
396 }
397