1 /*++ 2 Copyright (c) 2020 Microsoft Corporation 3 4 Module Name: 5 6 euf_egraph.h 7 8 Abstract: 9 10 E-graph layer 11 12 Author: 13 14 Nikolaj Bjorner (nbjorner) 2020-08-23 15 16 Notes: 17 18 It relies on 19 - data structures form the (legacy) SMT solver. 20 - it still uses eager path compression. 21 22 NB. The worklist is in reality inheritied from the legacy SMT solver. 23 It is claimed to have the same effect as delayed congruence table reconstruction from egg. 24 Similar to the legacy solver, parents are partially deduplicated. 25 26 --*/ 27 28 #pragma once 29 #include "util/statistics.h" 30 #include "util/trail.h" 31 #include "util/lbool.h" 32 #include "ast/euf/euf_enode.h" 33 #include "ast/euf/euf_etable.h" 34 #include "ast/ast_ll_pp.h" 35 36 namespace euf { 37 38 /*** 39 \brief store derived theory equalities and disequalities 40 Theory 'id' is notified with the equality/disequality of theory variables v1, v2. 41 For equalities, v1 and v2 are merged into the common root of child and root (their roots may 42 have been updated since the equality was derived, but the explanation for 43 v1 == v2 is provided by explaining the equality child == root. 44 For disequalities, m_child refers to an equality atom of the form e1 == e2. 45 It is equal to false under the current context. 46 The explanation for the disequality v1 != v2 is derived from explaining the 47 equality between the expression for v1 and e1, and the expression for v2 and e2 48 and the equality of m_eq and false: the literal corresponding to m_eq is false in the 49 current assignment stack, or m_child is congruent to false in the egraph. 50 */ 51 class th_eq { 52 53 theory_id m_id; 54 theory_var m_v1; 55 theory_var m_v2; 56 union { 57 enode* m_child; 58 expr* m_eq; 59 }; 60 enode* m_root; 61 public: is_eq()62 bool is_eq() const { return m_root != nullptr; } id()63 theory_id id() const { return m_id; } v1()64 theory_var v1() const { return m_v1; } v2()65 theory_var v2() const { return m_v2; } child()66 enode* child() const { SASSERT(is_eq()); return m_child; } root()67 enode* root() const { SASSERT(is_eq()); return m_root; } eq()68 expr* eq() const { SASSERT(!is_eq()); return m_eq; } th_eq(theory_id id,theory_var v1,theory_var v2,enode * c,enode * r)69 th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) : 70 m_id(id), m_v1(v1), m_v2(v2), m_child(c), m_root(r) {} th_eq(theory_id id,theory_var v1,theory_var v2,expr * eq)71 th_eq(theory_id id, theory_var v1, theory_var v2, expr* eq) : 72 m_id(id), m_v1(v1), m_v2(v2), m_eq(eq), m_root(nullptr) {} 73 }; 74 75 class egraph { 76 77 typedef ptr_vector<trail> trail_stack; 78 79 struct to_merge { 80 enode* a, * b; 81 bool commutativity; to_mergeto_merge82 to_merge(enode* a, enode* b, bool c) : a(a), b(b), commutativity(c) {} 83 }; 84 85 struct stats { 86 unsigned m_num_merge; 87 unsigned m_num_th_eqs; 88 unsigned m_num_th_diseqs; 89 unsigned m_num_lits; 90 unsigned m_num_eqs; 91 unsigned m_num_conflicts; statsstats92 stats() { reset(); } resetstats93 void reset() { memset(this, 0, sizeof(*this)); } 94 }; 95 struct update_record { 96 struct toggle_merge {}; 97 struct add_th_var {}; 98 struct replace_th_var {}; 99 struct new_lit {}; 100 struct new_th_eq {}; 101 struct new_th_eq_qhead {}; 102 struct new_lits_qhead {}; 103 struct inconsistent {}; 104 struct value_assignment {}; 105 struct lbl_hash {}; 106 struct lbl_set {}; 107 struct update_children {}; 108 enum class tag_t { is_set_parent, is_add_node, is_toggle_merge, is_update_children, 109 is_add_th_var, is_replace_th_var, is_new_lit, is_new_th_eq, 110 is_lbl_hash, is_new_th_eq_qhead, is_new_lits_qhead, 111 is_inconsistent, is_value_assignment, is_lbl_set }; 112 tag_t tag; 113 enode* r1; 114 enode* n1; 115 union { 116 unsigned r2_num_parents; 117 struct { 118 unsigned m_th_id : 8; 119 unsigned m_old_th_var : 24; 120 }; 121 unsigned qhead; 122 bool m_inconsistent; 123 signed char m_lbl_hash; 124 unsigned long long m_lbls; 125 }; update_recordupdate_record126 update_record(enode* r1, enode* n1, unsigned r2_num_parents) : 127 tag(tag_t::is_set_parent), r1(r1), n1(n1), r2_num_parents(r2_num_parents) {} update_recordupdate_record128 update_record(enode* n) : 129 tag(tag_t::is_add_node), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} update_recordupdate_record130 update_record(enode* n, toggle_merge) : 131 tag(tag_t::is_toggle_merge), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} update_recordupdate_record132 update_record(enode* n, unsigned id, add_th_var) : 133 tag(tag_t::is_add_th_var), r1(n), n1(nullptr), r2_num_parents(id) {} update_recordupdate_record134 update_record(enode* n, theory_id id, theory_var v, replace_th_var) : 135 tag(tag_t::is_replace_th_var), r1(n), n1(nullptr), m_th_id(id), m_old_th_var(v) {} update_recordupdate_record136 update_record(new_lit) : 137 tag(tag_t::is_new_lit), r1(nullptr), n1(nullptr), r2_num_parents(0) {} update_recordupdate_record138 update_record(new_th_eq) : 139 tag(tag_t::is_new_th_eq), r1(nullptr), n1(nullptr), r2_num_parents(0) {} update_recordupdate_record140 update_record(unsigned qh, new_th_eq_qhead): 141 tag(tag_t::is_new_th_eq_qhead), r1(nullptr), n1(nullptr), qhead(qh) {} update_recordupdate_record142 update_record(unsigned qh, new_lits_qhead): 143 tag(tag_t::is_new_lits_qhead), r1(nullptr), n1(nullptr), qhead(qh) {} update_recordupdate_record144 update_record(bool inc, inconsistent) : 145 tag(tag_t::is_inconsistent), r1(nullptr), n1(nullptr), m_inconsistent(inc) {} update_recordupdate_record146 update_record(enode* n, value_assignment) : 147 tag(tag_t::is_value_assignment), r1(n), n1(nullptr), qhead(0) {} update_recordupdate_record148 update_record(enode* n, lbl_hash): 149 tag(tag_t::is_lbl_hash), r1(n), n1(nullptr), m_lbl_hash(n->m_lbl_hash) {} update_recordupdate_record150 update_record(enode* n, lbl_set): 151 tag(tag_t::is_lbl_set), r1(n), n1(nullptr), m_lbls(n->m_lbls.get()) {} update_recordupdate_record152 update_record(enode* n, update_children) : 153 tag(tag_t::is_update_children), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {} 154 }; 155 ast_manager& m; 156 svector<to_merge> m_to_merge; 157 etable m_table; 158 region m_region; 159 svector<update_record> m_updates; 160 unsigned_vector m_scopes; 161 enode_vector m_expr2enode; 162 enode* m_tmp_eq = nullptr; 163 enode* m_tmp_node = nullptr; 164 unsigned m_tmp_node_capacity = 0; 165 tmp_app m_tmp_app; 166 enode_vector m_nodes; 167 expr_ref_vector m_exprs; 168 func_decl_ref_vector m_eq_decls; 169 vector<enode_vector> m_decl2enodes; 170 enode_vector m_empty_enodes; 171 unsigned m_num_scopes = 0; 172 bool m_inconsistent = false; 173 enode *m_n1 = nullptr; 174 enode *m_n2 = nullptr; 175 justification m_justification; 176 unsigned m_new_lits_qhead = 0; 177 unsigned m_new_th_eqs_qhead = 0; 178 svector<enode_bool_pair> m_new_lits; 179 svector<th_eq> m_new_th_eqs; 180 bool_vector m_th_propagates_diseqs; 181 enode_vector m_todo; 182 stats m_stats; 183 bool m_uses_congruence = false; 184 std::function<void(enode*,enode*)> m_on_merge; 185 std::function<void(enode*)> m_on_make; 186 std::function<void(expr*,expr*,expr*)> m_used_eq; 187 std::function<void(app*,app*)> m_used_cc; 188 std::function<void(std::ostream&, void*)> m_display_justification; 189 push_eq(enode * r1,enode * n1,unsigned r2_num_parents)190 void push_eq(enode* r1, enode* n1, unsigned r2_num_parents) { 191 m_updates.push_back(update_record(r1, n1, r2_num_parents)); 192 } push_node(enode * n)193 void push_node(enode* n) { m_updates.push_back(update_record(n)); } 194 195 void add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r); 196 197 void add_th_diseqs(theory_id id, theory_var v1, enode* r); 198 bool th_propagates_diseqs(theory_id id) const; 199 void add_literal(enode* n, bool is_eq); 200 void undo_eq(enode* r1, enode* n1, unsigned r2_num_parents); 201 void undo_add_th_var(enode* n, theory_id id); 202 enode* mk_enode(expr* f, unsigned generation, unsigned num_args, enode * const* args); 203 void force_push(); 204 void set_conflict(enode* n1, enode* n2, justification j); 205 void merge(enode* n1, enode* n2, justification j); 206 void merge_th_eq(enode* n, enode* root); 207 void merge_justification(enode* n1, enode* n2, justification j); 208 void reinsert_parents(enode* r1, enode* r2); 209 void remove_parents(enode* r1, enode* r2); 210 void unmerge_justification(enode* n1); 211 void reinsert_equality(enode* p); 212 void update_children(enode* n); 213 void push_lca(enode* a, enode* b); 214 enode* find_lca(enode* a, enode* b); 215 void push_to_lca(enode* a, enode* lca); 216 void push_congruence(enode* n1, enode* n2, bool commutative); 217 void push_todo(enode* n); 218 void toggle_merge_enabled(enode* n, bool backtracking); 219 220 enode_bool_pair insert_table(enode* p); 221 void erase_from_table(enode* p); 222 223 template <typename T> explain_eq(ptr_vector<T> & justifications,enode * a,enode * b,justification const & j)224 void explain_eq(ptr_vector<T>& justifications, enode* a, enode* b, justification const& j) { 225 if (j.is_external()) 226 justifications.push_back(j.ext<T>()); 227 else if (j.is_congruence()) 228 push_congruence(a, b, j.is_commutative()); 229 } 230 template <typename T> 231 void explain_todo(ptr_vector<T>& justifications); 232 233 std::ostream& display(std::ostream& out, unsigned max_args, enode* n) const; 234 235 public: 236 egraph(ast_manager& m); 237 ~egraph(); find(expr * f)238 enode* find(expr* f) const { return m_expr2enode.get(f->get_id(), nullptr); } 239 enode* find(expr* f, unsigned n, enode* const* args); 240 enode* mk(expr* f, unsigned generation, unsigned n, enode *const* args); 241 enode_vector const& enodes_of(func_decl* f); push()242 void push() { if (!m_to_merge.empty()) propagate(); ++m_num_scopes; } 243 void pop(unsigned num_scopes); 244 245 /** 246 \brief merge nodes, all effects are deferred to the propagation step. 247 */ merge(enode * n1,enode * n2,void * reason)248 void merge(enode* n1, enode* n2, void* reason) { merge(n1, n2, justification::external(reason)); } 249 void new_diseq(enode* n); 250 251 252 /** 253 \brief propagate set of merges. 254 This call may detect an inconsistency. Then inconsistent() is true. 255 Use then explain() to extract an explanation for the conflict. 256 257 It may also infer new implied equalities, when the roots of the 258 equated nodes are merged. Use then new_eqs() to extract the vector 259 of new equalities. 260 */ 261 bool propagate(); inconsistent()262 bool inconsistent() const { return m_inconsistent; } 263 264 /** 265 * \brief check if two nodes are known to be disequal. 266 */ 267 bool are_diseq(enode* a, enode* b); 268 269 enode* get_enode_eq_to(func_decl* f, unsigned num_args, enode* const* args); 270 271 enode* tmp_eq(enode* a, enode* b); 272 273 /** 274 \brief Maintain and update cursor into propagated consequences. 275 The result of get_literal() is a pair (n, is_eq) 276 where \c n is an enode and \c is_eq indicates whether the enode 277 is an equality consequence. 278 */ 279 void add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq); has_literal()280 bool has_literal() const { return m_new_lits_qhead < m_new_lits.size(); } has_th_eq()281 bool has_th_eq() const { return m_new_th_eqs_qhead < m_new_th_eqs.size(); } get_literal()282 enode_bool_pair get_literal() const { return m_new_lits[m_new_lits_qhead]; } get_th_eq()283 th_eq get_th_eq() const { return m_new_th_eqs[m_new_th_eqs_qhead]; } next_literal()284 void next_literal() { force_push(); SASSERT(m_new_lits_qhead < m_new_lits.size()); m_new_lits_qhead++; } next_th_eq()285 void next_th_eq() { force_push(); SASSERT(m_new_th_eqs_qhead < m_new_th_eqs.size()); m_new_th_eqs_qhead++; } 286 287 void set_lbl_hash(enode* n); 288 289 290 void add_th_var(enode* n, theory_var v, theory_id id); 291 void set_th_propagates_diseqs(theory_id id); 292 void set_merge_enabled(enode* n, bool enable_merge); 293 void set_value(enode* n, lbool value); set_bool_var(enode * n,unsigned v)294 void set_bool_var(enode* n, unsigned v) { n->set_bool_var(v); } 295 set_on_merge(std::function<void (enode * root,enode * other)> & on_merge)296 void set_on_merge(std::function<void(enode* root,enode* other)>& on_merge) { m_on_merge = on_merge; } set_on_make(std::function<void (enode * n)> & on_make)297 void set_on_make(std::function<void(enode* n)>& on_make) { m_on_make = on_make; } set_used_eq(std::function<void (expr *,expr *,expr *)> & used_eq)298 void set_used_eq(std::function<void(expr*,expr*,expr*)>& used_eq) { m_used_eq = used_eq; } set_used_cc(std::function<void (app *,app *)> & used_cc)299 void set_used_cc(std::function<void(app*,app*)>& used_cc) { m_used_cc = used_cc; } set_display_justification(std::function<void (std::ostream &,void *)> & d)300 void set_display_justification(std::function<void (std::ostream&, void*)> & d) { m_display_justification = d; } 301 302 void begin_explain(); 303 void end_explain(); uses_congruence()304 bool uses_congruence() const { return m_uses_congruence; } 305 template <typename T> 306 void explain(ptr_vector<T>& justifications); 307 template <typename T> 308 void explain_eq(ptr_vector<T>& justifications, enode* a, enode* b); 309 template <typename T> 310 unsigned explain_diseq(ptr_vector<T>& justifications, enode* a, enode* b); nodes()311 enode_vector const& nodes() const { return m_nodes; } 312 get_manager()313 ast_manager& get_manager() { return m; } 314 315 void invariant(); 316 void copy_from(egraph const& src, std::function<void*(void*)>& copy_justification); 317 struct e_pp { 318 egraph const& g; 319 enode* n; e_ppe_pp320 e_pp(egraph const& g, enode* n) : g(g), n(n) {} displaye_pp321 std::ostream& display(std::ostream& out) const { return g.display(out, 0, n); } 322 }; pp(enode * n)323 e_pp pp(enode* n) const { return e_pp(*this, n); } 324 struct b_pp { 325 egraph const& g; 326 enode* n; b_ppb_pp327 b_pp(egraph const& g, enode* n) : g(g), n(n) {} displayb_pp328 std::ostream& display(std::ostream& out) const { return n ? (out << n->get_expr_id() << ": " << mk_bounded_pp(n->get_expr(), g.m)) : out << "null"; } 329 }; bpp(enode * n)330 b_pp bpp(enode* n) const { return b_pp(*this, n); } 331 std::ostream& display(std::ostream& out) const; 332 333 void collect_statistics(statistics& st) const; 334 num_scopes()335 unsigned num_scopes() const { return m_scopes.size() + m_num_scopes; } num_nodes()336 unsigned num_nodes() const { return m_nodes.size(); } 337 }; 338 339 inline std::ostream& operator<<(std::ostream& out, egraph const& g) { return g.display(out); } 340 inline std::ostream& operator<<(std::ostream& out, egraph::e_pp const& p) { return p.display(out); } 341 inline std::ostream& operator<<(std::ostream& out, egraph::b_pp const& p) { return p.display(out); } 342 } 343