1 /*++ 2 Copyright (c) 2020 Microsoft Corporation 3 4 Module Name: 5 6 euf_solver.h 7 8 Abstract: 9 10 Solver plugin for EUF 11 12 Author: 13 14 Nikolaj Bjorner (nbjorner) 2020-08-25 15 16 --*/ 17 #pragma once 18 19 #include "util/scoped_ptr_vector.h" 20 #include "util/trail.h" 21 #include "ast/ast_translation.h" 22 #include "ast/euf/euf_egraph.h" 23 #include "ast/rewriter/th_rewriter.h" 24 #include "tactic/model_converter.h" 25 #include "sat/sat_extension.h" 26 #include "sat/smt/atom2bool_var.h" 27 #include "sat/smt/sat_th.h" 28 #include "sat/smt/sat_dual_solver.h" 29 #include "sat/smt/euf_ackerman.h" 30 #include "sat/smt/user_solver.h" 31 #include "smt/params/smt_params.h" 32 33 namespace euf { 34 typedef sat::literal literal; 35 typedef sat::ext_constraint_idx ext_constraint_idx; 36 typedef sat::ext_justification_idx ext_justification_idx; 37 typedef sat::literal_vector literal_vector; 38 typedef sat::bool_var bool_var; 39 40 class constraint { 41 public: 42 enum class kind_t { conflict, eq, lit }; 43 private: 44 kind_t m_kind; 45 public: constraint(kind_t k)46 constraint(kind_t k) : m_kind(k) {} kind()47 kind_t kind() const { return m_kind; } from_idx(size_t z)48 static constraint& from_idx(size_t z) { 49 return *reinterpret_cast<constraint*>(sat::constraint_base::idx2mem(z)); 50 } to_index()51 size_t to_index() const { return sat::constraint_base::mem2base(this); } 52 }; 53 54 class clause_pp { 55 solver& s; 56 sat::literal_vector const& lits; 57 public: clause_pp(solver & s,sat::literal_vector const & lits)58 clause_pp(solver& s, sat::literal_vector const& lits):s(s), lits(lits) {} 59 std::ostream& display(std::ostream& out) const; 60 }; 61 62 class solver : public sat::extension, public th_internalizer, public th_decompile { 63 typedef top_sort<euf::enode> deps_t; 64 friend class ackerman; 65 class user_sort; 66 // friend class sat::ba_solver; 67 struct stats { 68 unsigned m_ackerman; 69 unsigned m_final_checks; statsstats70 stats() { reset(); } resetstats71 void reset() { memset(this, 0, sizeof(*this)); } 72 }; 73 struct scope { 74 unsigned m_var_lim; 75 }; 76 77 to_ptr(sat::literal l)78 size_t* to_ptr(sat::literal l) { return TAG(size_t*, reinterpret_cast<size_t*>((size_t)(l.index() << 4)), 1); } to_ptr(size_t jst)79 size_t* to_ptr(size_t jst) { return TAG(size_t*, reinterpret_cast<size_t*>(jst), 2); } is_literal(size_t * p)80 bool is_literal(size_t* p) const { return GET_TAG(p) == 1; } is_justification(size_t * p)81 bool is_justification(size_t* p) const { return GET_TAG(p) == 2; } get_literal(size_t * p)82 sat::literal get_literal(size_t* p) const { 83 unsigned idx = static_cast<unsigned>(reinterpret_cast<size_t>(UNTAG(size_t*, p))); 84 return sat::to_literal(idx >> 4); 85 } get_justification(size_t * p)86 size_t get_justification(size_t* p) const { 87 return reinterpret_cast<size_t>(UNTAG(size_t*, p)); 88 } 89 90 std::function<::solver*(void)> m_mk_solver; 91 ast_manager& m; 92 sat::sat_internalizer& si; 93 smt_params m_config; 94 euf::egraph m_egraph; 95 trail_stack m_trail; 96 stats m_stats; 97 th_rewriter m_rewriter; 98 func_decl_ref_vector m_unhandled_functions; 99 sat::lookahead* m_lookahead = nullptr; 100 ast_manager* m_to_m; 101 sat::sat_internalizer* m_to_si; 102 scoped_ptr<euf::ackerman> m_ackerman; 103 user_solver::solver* m_user_propagator = nullptr; 104 th_solver* m_qsolver = nullptr; 105 unsigned m_generation = 0; 106 mutable ptr_vector<expr> m_todo; 107 108 ptr_vector<expr> m_bool_var2expr; 109 ptr_vector<size_t> m_explain; 110 unsigned m_num_scopes = 0; 111 unsigned_vector m_var_trail; 112 svector<scope> m_scopes; 113 scoped_ptr_vector<th_solver> m_solvers; 114 ptr_vector<th_solver> m_id2solver; 115 116 constraint* m_conflict = nullptr; 117 constraint* m_eq = nullptr; 118 constraint* m_lit = nullptr; 119 120 // internalization 121 bool visit(expr* e) override; 122 bool visited(expr* e) override; 123 bool post_visit(expr* e, bool sign, bool root) override; 124 125 void add_distinct_axiom(app* e, euf::enode* const* args); 126 void add_not_distinct_axiom(app* e, euf::enode* const* args); 127 void axiomatize_basic(enode* n); 128 bool internalize_root(app* e, bool sign, ptr_vector<enode> const& args); 129 void ensure_merged_tf(euf::enode* n); 130 euf::enode* mk_true(); 131 euf::enode* mk_false(); 132 133 // replay 134 typedef std::tuple<expr_ref, unsigned, sat::bool_var> reinit_t; 135 vector<reinit_t> m_reinit; 136 137 void start_reinit(unsigned num_scopes); 138 void finish_reinit(); 139 void relevancy_reinit(expr* e); 140 141 // extensions 142 th_solver* get_solver(family_id fid, func_decl* f); sort2solver(sort * s)143 th_solver* sort2solver(sort* s) { return get_solver(s->get_family_id(), nullptr); } func_decl2solver(func_decl * f)144 th_solver* func_decl2solver(func_decl* f) { return get_solver(f->get_family_id(), f); } 145 th_solver* quantifier2solver(); 146 th_solver* expr2solver(expr* e); 147 th_solver* bool_var2solver(sat::bool_var v); 148 void add_solver(th_solver* th); 149 void init_ackerman(); 150 151 // model building 152 expr_ref_vector m_values; 153 obj_map<expr, enode*> m_values2root; 154 bool include_func_interp(func_decl* f); 155 void register_macros(model& mdl); 156 void dependencies2values(user_sort& us, deps_t& deps, model_ref& mdl); 157 void collect_dependencies(user_sort& us, deps_t& deps); 158 void values2model(deps_t const& deps, model_ref& mdl); 159 void validate_model(model& mdl); 160 void display_validation_failure(std::ostream& out, model& mdl, enode* n); 161 162 // solving 163 void propagate_literals(); 164 void propagate_th_eqs(); 165 bool is_self_propagated(th_eq const& e); 166 void get_antecedents(literal l, constraint& j, literal_vector& r, bool probing); 167 void new_diseq(enode* a, enode* b, literal lit); 168 bool merge_shared_bools(); 169 170 // proofs 171 void log_antecedents(std::ostream& out, literal l, literal_vector const& r); 172 void log_antecedents(literal l, literal_vector const& r); 173 void log_justification(literal l, th_explain const& jst); 174 void drat_log_decl(func_decl* f); 175 void drat_log_expr(expr* n); 176 void drat_log_expr1(expr* n); 177 ptr_vector<expr> m_drat_todo; 178 obj_hashtable<ast> m_drat_asts; 179 bool m_drat_initialized{ false }; 180 void init_drat(); 181 182 // relevancy 183 bool_vector m_relevant_expr_ids; 184 bool_vector m_relevant_visited; 185 ptr_vector<expr> m_relevant_todo; 186 void ensure_dual_solver(); 187 bool init_relevancy(); 188 189 190 // invariant 191 void check_eqc_bool_assignment() const; 192 void check_missing_bool_enode_propagation() const; 193 void check_missing_eq_propagation() const; 194 195 // diagnosis 196 std::ostream& display_justification_ptr(std::ostream& out, size_t* j) const; 197 198 // constraints 199 constraint& mk_constraint(constraint*& c, constraint::kind_t k); conflict_constraint()200 constraint& conflict_constraint() { return mk_constraint(m_conflict, constraint::kind_t::conflict); } eq_constraint()201 constraint& eq_constraint() { return mk_constraint(m_eq, constraint::kind_t::eq); } lit_constraint()202 constraint& lit_constraint() { return mk_constraint(m_lit, constraint::kind_t::lit); } 203 204 // user propagator check_for_user_propagator()205 void check_for_user_propagator() { 206 if (!m_user_propagator) 207 throw default_exception("user propagator must be initialized"); 208 } 209 210 public: 211 solver(ast_manager& m, sat::sat_internalizer& si, params_ref const& p = params_ref()); 212 ~solver()213 ~solver() override { 214 if (m_conflict) dealloc(sat::constraint_base::mem2base_ptr(m_conflict)); 215 if (m_eq) dealloc(sat::constraint_base::mem2base_ptr(m_eq)); 216 if (m_lit) dealloc(sat::constraint_base::mem2base_ptr(m_lit)); 217 m_trail.reset(); 218 } 219 220 struct scoped_set_translate { 221 solver& s; scoped_set_translatescoped_set_translate222 scoped_set_translate(solver& s, ast_manager& m, sat::sat_internalizer& si) : 223 s(s) { 224 s.m_to_m = &m; 225 s.m_to_si = &si; 226 } ~scoped_set_translatescoped_set_translate227 ~scoped_set_translate() { 228 s.m_to_m = &s.m; 229 s.m_to_si = &s.si; 230 } 231 }; 232 233 struct scoped_generation { 234 solver& s; 235 unsigned m_g; scoped_generationscoped_generation236 scoped_generation(solver& s, unsigned g): 237 s(s), 238 m_g(s.m_generation) { 239 s.m_generation = g; 240 } ~scoped_generationscoped_generation241 ~scoped_generation() { 242 s.m_generation = m_g; 243 } 244 }; 245 unsigned get_max_generation(expr* e) const; 246 247 // accessors 248 get_si()249 sat::sat_internalizer& get_si() { return si; } get_manager()250 ast_manager& get_manager() { return m; } get_enode(expr * e)251 enode* get_enode(expr* e) const { return m_egraph.find(e); } expr2literal(expr * e)252 sat::literal expr2literal(expr* e) const { return enode2literal(get_enode(e)); } enode2literal(enode * n)253 sat::literal enode2literal(enode* n) const { return sat::literal(n->bool_var(), false); } value(enode * n)254 lbool value(enode* n) const { return s().value(enode2literal(n)); } get_config()255 smt_params const& get_config() const { return m_config; } get_region()256 region& get_region() { return m_trail.get_region(); } get_egraph()257 egraph& get_egraph() { return m_egraph; } fid2solver(family_id fid)258 th_solver* fid2solver(family_id fid) const { return m_id2solver.get(fid, nullptr); } 259 260 template <typename C> push(C const & c)261 void push(C const& c) { m_trail.push(c); } 262 template <typename V> push_vec(ptr_vector<V> & vec,V * val)263 void push_vec(ptr_vector<V>& vec, V* val) { 264 vec.push_back(val); 265 push(push_back_trail< V*, false>(vec)); 266 } 267 template <typename V> push_vec(svector<V> & vec,V val)268 void push_vec(svector<V>& vec, V val) { 269 vec.push_back(val); 270 push(push_back_trail< V, false>(vec)); 271 } get_trail_stack()272 trail_stack& get_trail_stack() { return m_trail; } 273 274 void updt_params(params_ref const& p); set_lookahead(sat::lookahead * s)275 void set_lookahead(sat::lookahead* s) override { m_lookahead = s; } 276 void init_search() override; 277 double get_reward(literal l, ext_constraint_idx idx, sat::literal_occs_fun& occs) const override; 278 bool is_extended_binary(ext_justification_idx idx, literal_vector& r) override; 279 bool is_external(bool_var v) override; 280 bool propagated(literal l, ext_constraint_idx idx) override; 281 bool unit_propagate() override; 282 bool should_research(sat::literal_vector const& core) override; 283 void add_assumptions(sat::literal_set& assumptions) override; 284 bool tracking_assumptions() override; 285 286 void propagate(literal lit, ext_justification_idx idx); 287 bool propagate(enode* a, enode* b, ext_justification_idx idx); 288 void set_conflict(ext_justification_idx idx); 289 propagate(literal lit,th_explain * p)290 void propagate(literal lit, th_explain* p) { propagate(lit, p->to_index()); } propagate(enode * a,enode * b,th_explain * p)291 bool propagate(enode* a, enode* b, th_explain* p) { return propagate(a, b, p->to_index()); } to_justification(sat::literal l)292 size_t* to_justification(sat::literal l) { return to_ptr(l); } set_conflict(th_explain * p)293 void set_conflict(th_explain* p) { set_conflict(p->to_index()); } 294 295 bool set_root(literal l, literal r) override; 296 void flush_roots() override; 297 298 void get_antecedents(literal l, ext_justification_idx idx, literal_vector& r, bool probing) override; 299 void get_antecedents(literal l, th_explain& jst, literal_vector& r, bool probing); 300 void add_antecedent(enode* a, enode* b); 301 void add_diseq_antecedent(enode* a, enode* b); 302 void set_eliminated(bool_var v) override; 303 void asserted(literal l) override; 304 sat::check_result check() override; 305 void push() override; 306 void pop(unsigned n) override; 307 void user_push() override; 308 void user_pop(unsigned n) override; 309 void pre_simplify() override; 310 void simplify() override; 311 // have a way to replace l by r in all constraints 312 void clauses_modifed() override; 313 lbool get_phase(bool_var v) override; 314 std::ostream& display(std::ostream& out) const override; 315 std::ostream& display_justification(std::ostream& out, ext_justification_idx idx) const override; 316 std::ostream& display_constraint(std::ostream& out, ext_constraint_idx idx) const override; bpp(enode * n)317 euf::egraph::b_pp bpp(enode* n) const { return m_egraph.bpp(n); } pp(literal_vector const & lits)318 clause_pp pp(literal_vector const& lits) { return clause_pp(*this, lits); } 319 void collect_statistics(statistics& st) const override; 320 extension* copy(sat::solver* s) override; 321 enode* copy(solver& dst_ctx, enode* src_n); 322 void find_mutexes(literal_vector& lits, vector<literal_vector>& mutexes) override; 323 void gc() override; 324 void pop_reinit() override; 325 bool validate() override; 326 void init_use_list(sat::ext_use_list& ul) override; 327 bool is_blocked(literal l, ext_constraint_idx) override; 328 bool check_model(sat::model const& m) const override; 329 void gc_vars(unsigned num_vars) override; resource_limits_exceeded()330 bool resource_limits_exceeded() const { return false; } // TODO 331 332 333 // proof use_drat()334 bool use_drat() { return s().get_config().m_drat && (init_drat(), true); } get_drat()335 sat::drat& get_drat() { return s().get_drat(); } 336 void drat_bool_def(sat::bool_var v, expr* n); 337 void drat_eq_def(sat::literal lit, expr* eq); 338 339 // decompile 340 bool extract_pb(std::function<void(unsigned sz, literal const* c, unsigned k)>& card, 341 std::function<void(unsigned sz, literal const* c, unsigned const* coeffs, unsigned k)>& pb) override; 342 343 bool to_formulas(std::function<expr_ref(sat::literal)>& l2e, expr_ref_vector& fmls) override; 344 345 // internalize 346 sat::literal internalize(expr* e, bool sign, bool root, bool learned) override; 347 void internalize(expr* e, bool learned) override; 348 sat::literal mk_literal(expr* e); attach_th_var(enode * n,th_solver * th,theory_var v)349 void attach_th_var(enode* n, th_solver* th, theory_var v) { m_egraph.add_th_var(n, v, th->get_id()); } 350 void attach_node(euf::enode* n); 351 expr_ref mk_eq(expr* e1, expr* e2); mk_eq(euf::enode * n1,euf::enode * n2)352 expr_ref mk_eq(euf::enode* n1, euf::enode* n2) { return mk_eq(n1->get_expr(), n2->get_expr()); } 353 euf::enode* e_internalize(expr* e); 354 euf::enode* mk_enode(expr* e, unsigned n, enode* const* args); bool_var2expr(sat::bool_var v)355 expr* bool_var2expr(sat::bool_var v) const { return m_bool_var2expr.get(v, nullptr); } literal2expr(sat::literal lit)356 expr_ref literal2expr(sat::literal lit) const { expr* e = bool_var2expr(lit.var()); return (e && lit.sign()) ? expr_ref(m.mk_not(e), m) : expr_ref(e, m); } generation()357 unsigned generation() const { return m_generation; } 358 359 sat::literal attach_lit(sat::literal lit, expr* e); 360 void unhandled_function(func_decl* f); get_rewriter()361 th_rewriter& get_rewriter() { return m_rewriter; } rewrite(expr_ref & e)362 void rewrite(expr_ref& e) { m_rewriter(e); } 363 bool is_shared(euf::enode* n) const; 364 365 // relevancy 366 bool m_relevancy = true; 367 scoped_ptr<sat::dual_solver> m_dual_solver; 368 ptr_vector<expr> m_auto_relevant; 369 unsigned_vector m_auto_relevant_lim; 370 unsigned m_auto_relevant_scopes = 0; 371 relevancy_enabled()372 bool relevancy_enabled() const { return m_relevancy && get_config().m_relevancy_lvl > 0; } disable_relevancy(expr * e)373 void disable_relevancy(expr* e) { IF_VERBOSE(0, verbose_stream() << "disabling relevancy " << mk_pp(e, m) << "\n"); m_relevancy = false; } 374 void add_root(unsigned n, sat::literal const* lits); add_root(sat::literal_vector const & lits)375 void add_root(sat::literal_vector const& lits) { add_root(lits.size(), lits.data()); } add_root(sat::literal lit)376 void add_root(sat::literal lit) { add_root(1, &lit); } add_root(sat::literal a,sat::literal b)377 void add_root(sat::literal a, sat::literal b) { sat::literal lits[2] = {a, b}; add_root(2, lits); } add_aux(sat::literal_vector const & lits)378 void add_aux(sat::literal_vector const& lits) { add_aux(lits.size(), lits.data()); } 379 void add_aux(unsigned n, sat::literal const* lits); add_aux(sat::literal a)380 void add_aux(sat::literal a) { sat::literal lits[1] = { a }; add_aux(1, lits); } add_aux(sat::literal a,sat::literal b)381 void add_aux(sat::literal a, sat::literal b) { sat::literal lits[2] = {a, b}; add_aux(2, lits); } add_aux(sat::literal a,sat::literal b,sat::literal c)382 void add_aux(sat::literal a, sat::literal b, sat::literal c) { sat::literal lits[3] = { a, b, c }; add_aux(3, lits); } 383 void track_relevancy(sat::bool_var v); 384 bool is_relevant(expr* e) const; 385 bool is_relevant(enode* n) const; 386 void add_auto_relevant(expr* e); 387 void pop_relevant(unsigned n); 388 void push_relevant(); 389 390 391 // model construction 392 void update_model(model_ref& mdl); 393 obj_map<expr, enode*> const& values2root(); 394 void model_updated(model_ref& mdl); 395 expr* node2value(enode* n) const; 396 397 // diagnostics unhandled_functions()398 func_decl_ref_vector const& unhandled_functions() { return m_unhandled_functions; } 399 400 // user propagator 401 void user_propagate_init( 402 void* ctx, 403 ::solver::push_eh_t& push_eh, 404 ::solver::pop_eh_t& pop_eh, 405 ::solver::fresh_eh_t& fresh_eh); 406 bool watches_fixed(enode* n) const; 407 void assign_fixed(enode* n, expr* val, unsigned sz, literal const* explain); assign_fixed(enode * n,expr * val,literal_vector const & explain)408 void assign_fixed(enode* n, expr* val, literal_vector const& explain) { assign_fixed(n, val, explain.size(), explain.data()); } assign_fixed(enode * n,expr * val,literal explain)409 void assign_fixed(enode* n, expr* val, literal explain) { assign_fixed(n, val, 1, &explain); } 410 user_propagate_register_final(::solver::final_eh_t & final_eh)411 void user_propagate_register_final(::solver::final_eh_t& final_eh) { 412 check_for_user_propagator(); 413 m_user_propagator->register_final(final_eh); 414 } user_propagate_register_fixed(::solver::fixed_eh_t & fixed_eh)415 void user_propagate_register_fixed(::solver::fixed_eh_t& fixed_eh) { 416 check_for_user_propagator(); 417 m_user_propagator->register_fixed(fixed_eh); 418 } user_propagate_register_eq(::solver::eq_eh_t & eq_eh)419 void user_propagate_register_eq(::solver::eq_eh_t& eq_eh) { 420 check_for_user_propagator(); 421 m_user_propagator->register_eq(eq_eh); 422 } user_propagate_register_diseq(::solver::eq_eh_t & diseq_eh)423 void user_propagate_register_diseq(::solver::eq_eh_t& diseq_eh) { 424 check_for_user_propagator(); 425 m_user_propagator->register_diseq(diseq_eh); 426 } user_propagate_register(expr * e)427 unsigned user_propagate_register(expr* e) { 428 check_for_user_propagator(); 429 return m_user_propagator->add_expr(e); 430 } 431 432 // solver factory mk_solver()433 ::solver* mk_solver() { return m_mk_solver(); } set_mk_solver(std::function<::solver * (void)> & mk)434 void set_mk_solver(std::function<::solver*(void)>& mk) { m_mk_solver = mk; } 435 436 437 }; 438 439 inline std::ostream& operator<<(std::ostream& out, clause_pp const& p) { 440 return p.display(out); 441 } 442 443 }; 444 445 inline std::ostream& operator<<(std::ostream& out, euf::solver const& s) { 446 return s.display(out); 447 } 448 449