1 /*++ 2 Copyright (c) 2020 Microsoft Corporation 3 4 Module Name: 5 6 euf_solver.h 7 8 Abstract: 9 10 Solver plugin for EUF 11 12 Author: 13 14 Nikolaj Bjorner (nbjorner) 2020-08-25 15 16 --*/ 17 #pragma once 18 19 #include "util/scoped_ptr_vector.h" 20 #include "util/trail.h" 21 #include "ast/ast_translation.h" 22 #include "ast/euf/euf_egraph.h" 23 #include "ast/rewriter/th_rewriter.h" 24 #include "tactic/model_converter.h" 25 #include "sat/sat_extension.h" 26 #include "sat/smt/atom2bool_var.h" 27 #include "sat/smt/sat_th.h" 28 #include "sat/smt/sat_dual_solver.h" 29 #include "sat/smt/euf_ackerman.h" 30 #include "sat/smt/user_solver.h" 31 #include "smt/params/smt_params.h" 32 33 namespace euf { 34 typedef sat::literal literal; 35 typedef sat::ext_constraint_idx ext_constraint_idx; 36 typedef sat::ext_justification_idx ext_justification_idx; 37 typedef sat::literal_vector literal_vector; 38 typedef sat::bool_var bool_var; 39 40 class constraint { 41 public: 42 enum class kind_t { conflict, eq, lit }; 43 private: 44 kind_t m_kind; 45 public: constraint(kind_t k)46 constraint(kind_t k) : m_kind(k) {} kind()47 kind_t kind() const { return m_kind; } from_idx(size_t z)48 static constraint& from_idx(size_t z) { 49 return *reinterpret_cast<constraint*>(sat::constraint_base::idx2mem(z)); 50 } to_index()51 size_t to_index() const { return sat::constraint_base::mem2base(this); } 52 }; 53 54 class solver : public sat::extension, public th_internalizer, public th_decompile { 55 typedef top_sort<euf::enode> deps_t; 56 friend class ackerman; 57 class user_sort; 58 // friend class sat::ba_solver; 59 struct stats { 60 unsigned m_ackerman; statsstats61 stats() { reset(); } resetstats62 void reset() { memset(this, 0, sizeof(*this)); } 63 }; 64 struct scope { 65 unsigned m_var_lim; 66 }; 67 typedef trail_stack<solver> euf_trail_stack; 68 69 to_ptr(sat::literal l)70 size_t* to_ptr(sat::literal l) { return TAG(size_t*, reinterpret_cast<size_t*>((size_t)(l.index() << 4)), 1); } to_ptr(size_t jst)71 size_t* to_ptr(size_t jst) { return TAG(size_t*, reinterpret_cast<size_t*>(jst), 2); } is_literal(size_t * p)72 bool is_literal(size_t* p) const { return GET_TAG(p) == 1; } is_justification(size_t * p)73 bool is_justification(size_t* p) const { return GET_TAG(p) == 2; } get_literal(size_t * p)74 sat::literal get_literal(size_t* p) const { 75 unsigned idx = static_cast<unsigned>(reinterpret_cast<size_t>(UNTAG(size_t*, p))); 76 return sat::to_literal(idx >> 4); 77 } get_justification(size_t * p)78 size_t get_justification(size_t* p) const { 79 return reinterpret_cast<size_t>(UNTAG(size_t*, p)); 80 } 81 82 std::function<::solver*(void)> m_mk_solver; 83 ast_manager& m; 84 sat::sat_internalizer& si; 85 smt_params m_config; 86 euf::egraph m_egraph; 87 euf_trail_stack m_trail; 88 stats m_stats; 89 th_rewriter m_rewriter; 90 func_decl_ref_vector m_unhandled_functions; 91 sat::lookahead* m_lookahead{ nullptr }; 92 ast_manager* m_to_m; 93 sat::sat_internalizer* m_to_si; 94 scoped_ptr<euf::ackerman> m_ackerman; 95 scoped_ptr<sat::dual_solver> m_dual_solver; 96 user::solver* m_user_propagator{ nullptr }; 97 th_solver* m_qsolver { nullptr }; 98 unsigned m_generation { 0 }; 99 100 ptr_vector<expr> m_bool_var2expr; 101 ptr_vector<size_t> m_explain; 102 unsigned m_num_scopes{ 0 }; 103 unsigned_vector m_var_trail; 104 svector<scope> m_scopes; 105 scoped_ptr_vector<th_solver> m_solvers; 106 ptr_vector<th_solver> m_id2solver; 107 108 constraint* m_conflict{ nullptr }; 109 constraint* m_eq{ nullptr }; 110 constraint* m_lit{ nullptr }; 111 112 // internalization 113 bool visit(expr* e) override; 114 bool visited(expr* e) override; 115 bool post_visit(expr* e, bool sign, bool root) override; 116 117 void add_distinct_axiom(app* e, euf::enode* const* args); 118 void add_not_distinct_axiom(app* e, euf::enode* const* args); 119 void axiomatize_basic(enode* n); 120 bool internalize_root(app* e, bool sign, ptr_vector<enode> const& args); 121 euf::enode* mk_true(); 122 euf::enode* mk_false(); 123 124 // replay 125 typedef std::tuple<expr_ref, unsigned, sat::bool_var> reinit_t; 126 vector<reinit_t> m_reinit; 127 128 void start_reinit(unsigned num_scopes); 129 void finish_reinit(); 130 131 // extensions 132 th_solver* get_solver(family_id fid, func_decl* f); sort2solver(sort * s)133 th_solver* sort2solver(sort* s) { return get_solver(s->get_family_id(), nullptr); } func_decl2solver(func_decl * f)134 th_solver* func_decl2solver(func_decl* f) { return get_solver(f->get_family_id(), f); } 135 th_solver* quantifier2solver(); 136 th_solver* expr2solver(expr* e); 137 th_solver* bool_var2solver(sat::bool_var v); 138 void add_solver(th_solver* th); 139 void init_ackerman(); 140 141 // model building 142 expr_ref_vector m_values; 143 obj_map<expr, enode*> m_values2root; 144 bool include_func_interp(func_decl* f); 145 void register_macros(model& mdl); 146 void dependencies2values(user_sort& us, deps_t& deps, model_ref& mdl); 147 void collect_dependencies(user_sort& us, deps_t& deps); 148 void values2model(deps_t const& deps, model_ref& mdl); 149 void validate_model(model& mdl); 150 151 // solving 152 void propagate_literals(); 153 void propagate_th_eqs(); 154 bool is_self_propagated(th_eq const& e); 155 void get_antecedents(literal l, constraint& j, literal_vector& r, bool probing); 156 void new_diseq(enode* a, enode* b, literal lit); 157 158 // proofs 159 void log_antecedents(std::ostream& out, literal l, literal_vector const& r); 160 void log_antecedents(literal l, literal_vector const& r); 161 void log_justification(literal l, th_propagation const& jst); 162 void drat_log_decl(func_decl* f); 163 void drat_log_expr(expr* n); 164 void drat_log_expr1(expr* n); 165 ptr_vector<expr> m_drat_todo; 166 obj_hashtable<ast> m_drat_asts; 167 bool m_drat_initialized{ false }; 168 void init_drat(); 169 170 // relevancy 171 bool_vector m_relevant_expr_ids; 172 void ensure_dual_solver(); 173 bool init_relevancy(); 174 175 176 // invariant 177 void check_eqc_bool_assignment() const; 178 void check_missing_bool_enode_propagation() const; 179 void check_missing_eq_propagation() const; 180 181 // diagnosis 182 std::ostream& display_justification_ptr(std::ostream& out, size_t* j) const; 183 184 // constraints 185 constraint& mk_constraint(constraint*& c, constraint::kind_t k); conflict_constraint()186 constraint& conflict_constraint() { return mk_constraint(m_conflict, constraint::kind_t::conflict); } eq_constraint()187 constraint& eq_constraint() { return mk_constraint(m_eq, constraint::kind_t::eq); } lit_constraint()188 constraint& lit_constraint() { return mk_constraint(m_lit, constraint::kind_t::lit); } 189 190 // user propagator check_for_user_propagator()191 void check_for_user_propagator() { 192 if (!m_user_propagator) 193 throw default_exception("user propagator must be initialized"); 194 } 195 196 public: 197 solver(ast_manager& m, sat::sat_internalizer& si, params_ref const& p = params_ref()); 198 ~solver()199 ~solver() override { 200 if (m_conflict) dealloc(sat::constraint_base::mem2base_ptr(m_conflict)); 201 if (m_eq) dealloc(sat::constraint_base::mem2base_ptr(m_eq)); 202 if (m_lit) dealloc(sat::constraint_base::mem2base_ptr(m_lit)); 203 m_trail.reset(); 204 } 205 206 struct scoped_set_translate { 207 solver& s; scoped_set_translatescoped_set_translate208 scoped_set_translate(solver& s, ast_manager& m, sat::sat_internalizer& si) : 209 s(s) { 210 s.m_to_m = &m; 211 s.m_to_si = &si; 212 } ~scoped_set_translatescoped_set_translate213 ~scoped_set_translate() { 214 s.m_to_m = &s.m; 215 s.m_to_si = &s.si; 216 } 217 }; 218 219 struct scoped_generation { 220 solver& s; 221 unsigned m_g; scoped_generationscoped_generation222 scoped_generation(solver& s, unsigned g): 223 s(s), 224 m_g(s.m_generation) { 225 s.m_generation = g; 226 } ~scoped_generationscoped_generation227 ~scoped_generation() { 228 s.m_generation = m_g; 229 } 230 }; 231 232 // accessors 233 get_si()234 sat::sat_internalizer& get_si() { return si; } get_manager()235 ast_manager& get_manager() { return m; } get_enode(expr * e)236 enode* get_enode(expr* e) const { return m_egraph.find(e); } expr2literal(expr * e)237 sat::literal expr2literal(expr* e) const { return enode2literal(get_enode(e)); } enode2literal(enode * n)238 sat::literal enode2literal(enode* n) const { return sat::literal(n->bool_var(), false); } value(enode * n)239 lbool value(enode* n) const { return s().value(enode2literal(n)); } get_config()240 smt_params const& get_config() const { return m_config; } get_region()241 region& get_region() { return m_trail.get_region(); } get_egraph()242 egraph& get_egraph() { return m_egraph; } fid2solver(family_id fid)243 th_solver* fid2solver(family_id fid) const { return m_id2solver.get(fid, nullptr); } 244 245 template <typename C> push(C const & c)246 void push(C const& c) { m_trail.push(c); } 247 template <typename V> push_vec(ptr_vector<V> & vec,V * val)248 void push_vec(ptr_vector<V>& vec, V* val) { 249 vec.push_back(val); 250 push(push_back_trail<solver, V*, false>(vec)); 251 } 252 template <typename V> push_vec(svector<V> & vec,V val)253 void push_vec(svector<V>& vec, V val) { 254 vec.push_back(val); 255 push(push_back_trail<solver, V, false>(vec)); 256 } get_trail_stack()257 euf_trail_stack& get_trail_stack() { return m_trail; } 258 259 void updt_params(params_ref const& p); set_lookahead(sat::lookahead * s)260 void set_lookahead(sat::lookahead* s) override { m_lookahead = s; } 261 void init_search() override; 262 double get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const override; 263 bool is_extended_binary(ext_justification_idx idx, literal_vector& r) override; 264 bool is_external(bool_var v) override; 265 bool propagated(literal l, ext_constraint_idx idx) override; 266 bool unit_propagate() override; 267 268 void propagate(literal lit, ext_justification_idx idx); 269 bool propagate(enode* a, enode* b, ext_justification_idx idx); 270 void set_conflict(ext_justification_idx idx); 271 propagate(literal lit,th_propagation * p)272 void propagate(literal lit, th_propagation* p) { propagate(lit, p->to_index()); } propagate(enode * a,enode * b,th_propagation * p)273 bool propagate(enode* a, enode* b, th_propagation* p) { return propagate(a, b, p->to_index()); } set_conflict(th_propagation * p)274 void set_conflict(th_propagation* p) { set_conflict(p->to_index()); } 275 276 bool set_root(literal l, literal r) override; 277 void flush_roots() override; 278 279 void get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) override; 280 void get_antecedents(literal l, th_propagation& jst, literal_vector& r, bool probing); 281 void add_antecedent(enode* a, enode* b); 282 void asserted(literal l) override; 283 sat::check_result check() override; 284 void push() override; 285 void pop(unsigned n) override; 286 void user_push() override; 287 void user_pop(unsigned n) override; 288 void pre_simplify() override; 289 void simplify() override; 290 // have a way to replace l by r in all constraints 291 void clauses_modifed() override; 292 lbool get_phase(bool_var v) override; 293 std::ostream& display(std::ostream& out) const override; 294 std::ostream& display_justification(std::ostream& out, ext_justification_idx idx) const override; 295 std::ostream& display_constraint(std::ostream& out, ext_constraint_idx idx) const override; bpp(enode * n)296 euf::egraph::b_pp bpp(enode* n) { return m_egraph.bpp(n); } 297 void collect_statistics(statistics& st) const override; 298 extension* copy(sat::solver* s) override; 299 enode* copy(solver& dst_ctx, enode* src_n); 300 void find_mutexes(literal_vector& lits, vector<literal_vector>& mutexes) override; 301 void gc() override; 302 void pop_reinit() override; 303 bool validate() override; 304 void init_use_list(sat::ext_use_list& ul) override; 305 bool is_blocked(literal l, ext_constraint_idx) override; 306 bool check_model(sat::model const& m) const override; 307 void gc_vars(unsigned num_vars) override; 308 309 // proof use_drat()310 bool use_drat() { return s().get_config().m_drat && (init_drat(), true); } get_drat()311 sat::drat& get_drat() { return s().get_drat(); } 312 void drat_bool_def(sat::bool_var v, expr* n); 313 void drat_eq_def(sat::literal lit, expr* eq); 314 315 // decompile 316 bool extract_pb(std::function<void(unsigned sz, literal const* c, unsigned k)>& card, 317 std::function<void(unsigned sz, literal const* c, unsigned const* coeffs, unsigned k)>& pb) override; 318 319 bool to_formulas(std::function<expr_ref(sat::literal)>& l2e, expr_ref_vector& fmls) override; 320 321 // internalize 322 sat::literal internalize(expr* e, bool sign, bool root, bool learned) override; 323 void internalize(expr* e, bool learned) override; 324 sat::literal mk_literal(expr* e); attach_th_var(enode * n,th_solver * th,theory_var v)325 void attach_th_var(enode* n, th_solver* th, theory_var v) { m_egraph.add_th_var(n, v, th->get_id()); } 326 void attach_node(euf::enode* n); 327 expr_ref mk_eq(expr* e1, expr* e2); mk_eq(euf::enode * n1,euf::enode * n2)328 expr_ref mk_eq(euf::enode* n1, euf::enode* n2) { return mk_eq(n1->get_expr(), n2->get_expr()); } mk_enode(expr * e,unsigned n,enode * const * args)329 euf::enode* mk_enode(expr* e, unsigned n, enode* const* args) { return m_egraph.mk(e, m_generation, n, args); } bool_var2expr(sat::bool_var v)330 expr* bool_var2expr(sat::bool_var v) const { return m_bool_var2expr.get(v, nullptr); } literal2expr(sat::literal lit)331 expr_ref literal2expr(sat::literal lit) const { expr* e = bool_var2expr(lit.var()); return lit.sign() ? expr_ref(m.mk_not(e), m) : expr_ref(e, m); } 332 333 sat::literal attach_lit(sat::literal lit, expr* e); 334 void unhandled_function(func_decl* f); get_rewriter()335 th_rewriter& get_rewriter() { return m_rewriter; } 336 bool is_shared(euf::enode* n) const; 337 338 // relevancy relevancy_enabled()339 bool relevancy_enabled() const { return get_config().m_relevancy_lvl > 0; } 340 void add_root(unsigned n, sat::literal const* lits); 341 void add_aux(unsigned n, sat::literal const* lits); 342 void track_relevancy(sat::bool_var v); 343 bool is_relevant(expr* e) const; 344 bool is_relevant(enode* n) const; 345 346 // model construction 347 void update_model(model_ref& mdl); 348 obj_map<expr, enode*> const& values2root(); 349 expr* node2value(enode* n) const; 350 351 // diagnostics unhandled_functions()352 func_decl_ref_vector const& unhandled_functions() { return m_unhandled_functions; } 353 354 // user propagator 355 void user_propagate_init( 356 void* ctx, 357 ::solver::push_eh_t& push_eh, 358 ::solver::pop_eh_t& pop_eh, 359 ::solver::fresh_eh_t& fresh_eh); 360 bool watches_fixed(enode* n) const; 361 void assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain); assign_fixed(enode * n,expr * val,literal_vector const & explain)362 void assign_fixed(enode* n, expr* val, literal_vector const& explain) { assign_fixed(n, val, explain.size(), explain.c_ptr()); } assign_fixed(enode * n,expr * val,literal explain)363 void assign_fixed(enode* n, expr* val, literal explain) { assign_fixed(n, val, 1, &explain); } 364 user_propagate_register_final(::solver::final_eh_t & final_eh)365 void user_propagate_register_final(::solver::final_eh_t& final_eh) { 366 check_for_user_propagator(); 367 m_user_propagator->register_final(final_eh); 368 } user_propagate_register_fixed(::solver::fixed_eh_t & fixed_eh)369 void user_propagate_register_fixed(::solver::fixed_eh_t& fixed_eh) { 370 check_for_user_propagator(); 371 m_user_propagator->register_fixed(fixed_eh); 372 } user_propagate_register_eq(::solver::eq_eh_t & eq_eh)373 void user_propagate_register_eq(::solver::eq_eh_t& eq_eh) { 374 check_for_user_propagator(); 375 m_user_propagator->register_eq(eq_eh); 376 } user_propagate_register_diseq(::solver::eq_eh_t & diseq_eh)377 void user_propagate_register_diseq(::solver::eq_eh_t& diseq_eh) { 378 check_for_user_propagator(); 379 m_user_propagator->register_diseq(diseq_eh); 380 } user_propagate_register(expr * e)381 unsigned user_propagate_register(expr* e) { 382 check_for_user_propagator(); 383 return m_user_propagator->add_expr(e); 384 } 385 386 // solver factory mk_solver()387 ::solver* mk_solver() { return m_mk_solver(); } set_mk_solver(std::function<::solver * (void)> & mk)388 void set_mk_solver(std::function<::solver*(void)>& mk) { m_mk_solver = mk; } 389 390 391 }; 392 }; 393 394 inline std::ostream& operator<<(std::ostream& out, euf::solver const& s) { 395 return s.display(out); 396 } 397