1 /*++
2 Copyright (c) 2020 Microsoft Corporation
3 
4 Module Name:
5 
6     euf_egraph.h
7 
8 Abstract:
9 
10     E-graph layer
11 
12 Author:
13 
14     Nikolaj Bjorner (nbjorner) 2020-08-23
15 
16 Notes:
17 
18     It relies on
    - data structures from the (legacy) SMT solver.
20       - it still uses eager path compression.
21 
    NB. The worklist is in reality inherited from the legacy SMT solver.
23     It is claimed to have the same effect as delayed congruence table reconstruction from egg.
24     Similar to the legacy solver, parents are partially deduplicated.
25 
26 --*/
27 
28 #pragma once
29 #include "util/statistics.h"
30 #include "util/trail.h"
31 #include "util/lbool.h"
32 #include "ast/euf/euf_enode.h"
33 #include "ast/euf/euf_etable.h"
34 #include "ast/ast_ll_pp.h"
35 
36 namespace euf {
37 
38     /***
39       \brief store derived theory equalities and disequalities
40       Theory 'id' is notified with the equality/disequality of theory variables v1, v2.
41       For equalities, v1 and v2 are merged into the common root of child and root (their roots may
42       have been updated since the equality was derived, but the explanation for
43       v1 == v2 is provided by explaining the equality child == root.
44       For disequalities, m_child refers to an equality atom of the form e1 == e2.
45       It is equal to false under the current context.
46       The explanation for the disequality v1 != v2 is derived from explaining the
47       equality between the expression for v1 and e1, and the expression for v2 and e2
48       and the equality of m_eq and false: the literal corresponding to m_eq is false in the
49       current assignment stack, or m_child is congruent to false in the egraph.
50     */
51     class th_eq {
52 
53         theory_id  m_id;
54         theory_var m_v1;
55         theory_var m_v2;
56         union {
57             enode* m_child;
58             expr* m_eq;
59         };
60         enode* m_root;
61     public:
is_eq()62         bool is_eq() const { return m_root != nullptr; }
id()63         theory_id id() const { return m_id; }
v1()64         theory_var v1() const { return m_v1; }
v2()65         theory_var v2() const { return m_v2; }
child()66         enode* child() const { SASSERT(is_eq()); return m_child; }
root()67         enode* root() const { SASSERT(is_eq()); return m_root; }
eq()68         expr* eq() const { SASSERT(!is_eq()); return m_eq; }
th_eq(theory_id id,theory_var v1,theory_var v2,enode * c,enode * r)69         th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) :
70             m_id(id), m_v1(v1), m_v2(v2), m_child(c), m_root(r) {}
th_eq(theory_id id,theory_var v1,theory_var v2,expr * eq)71         th_eq(theory_id id, theory_var v1, theory_var v2, expr* eq) :
72             m_id(id), m_v1(v1), m_v2(v2), m_eq(eq), m_root(nullptr) {}
73     };
74 
75     class egraph {
76 
77         typedef ptr_vector<trail> trail_stack;
78 
79         struct to_merge {
80             enode* a, * b;
81             bool commutativity;
to_mergeto_merge82             to_merge(enode* a, enode* b, bool c) : a(a), b(b), commutativity(c) {}
83         };
84 
85         struct stats {
86             unsigned m_num_merge;
87             unsigned m_num_th_eqs;
88             unsigned m_num_th_diseqs;
89             unsigned m_num_lits;
90             unsigned m_num_eqs;
91             unsigned m_num_conflicts;
statsstats92             stats() { reset(); }
resetstats93             void reset() { memset(this, 0, sizeof(*this)); }
94         };
95         struct update_record {
96             struct toggle_merge {};
97             struct add_th_var {};
98             struct replace_th_var {};
99             struct new_lit {};
100             struct new_th_eq {};
101             struct new_th_eq_qhead {};
102             struct new_lits_qhead {};
103             struct inconsistent {};
104             struct value_assignment {};
105             struct lbl_hash {};
106             struct lbl_set {};
107             struct update_children {};
108             enum class tag_t { is_set_parent, is_add_node, is_toggle_merge, is_update_children,
109                     is_add_th_var, is_replace_th_var, is_new_lit, is_new_th_eq,
110                     is_lbl_hash, is_new_th_eq_qhead, is_new_lits_qhead,
111                     is_inconsistent, is_value_assignment, is_lbl_set };
112             tag_t  tag;
113             enode* r1;
114             enode* n1;
115             union {
116                 unsigned r2_num_parents;
117                 struct {
118                     unsigned   m_th_id : 8;
119                     unsigned   m_old_th_var : 24;
120                 };
121                 unsigned qhead;
122                 bool     m_inconsistent;
123                 signed char m_lbl_hash;
124                 unsigned long long m_lbls;
125             };
update_recordupdate_record126             update_record(enode* r1, enode* n1, unsigned r2_num_parents) :
127                 tag(tag_t::is_set_parent), r1(r1), n1(n1), r2_num_parents(r2_num_parents) {}
update_recordupdate_record128             update_record(enode* n) :
129                 tag(tag_t::is_add_node), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {}
update_recordupdate_record130             update_record(enode* n, toggle_merge) :
131                 tag(tag_t::is_toggle_merge), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {}
update_recordupdate_record132             update_record(enode* n, unsigned id, add_th_var) :
133                 tag(tag_t::is_add_th_var), r1(n), n1(nullptr), r2_num_parents(id) {}
update_recordupdate_record134             update_record(enode* n, theory_id id, theory_var v, replace_th_var) :
135                 tag(tag_t::is_replace_th_var), r1(n), n1(nullptr), m_th_id(id), m_old_th_var(v) {}
update_recordupdate_record136             update_record(new_lit) :
137                 tag(tag_t::is_new_lit), r1(nullptr), n1(nullptr), r2_num_parents(0) {}
update_recordupdate_record138             update_record(new_th_eq) :
139                 tag(tag_t::is_new_th_eq), r1(nullptr), n1(nullptr), r2_num_parents(0) {}
update_recordupdate_record140             update_record(unsigned qh, new_th_eq_qhead):
141                 tag(tag_t::is_new_th_eq_qhead), r1(nullptr), n1(nullptr), qhead(qh) {}
update_recordupdate_record142             update_record(unsigned qh, new_lits_qhead):
143                 tag(tag_t::is_new_lits_qhead), r1(nullptr), n1(nullptr), qhead(qh) {}
update_recordupdate_record144             update_record(bool inc, inconsistent) :
145                 tag(tag_t::is_inconsistent), r1(nullptr), n1(nullptr), m_inconsistent(inc) {}
update_recordupdate_record146             update_record(enode* n, value_assignment) :
147                 tag(tag_t::is_value_assignment), r1(n), n1(nullptr), qhead(0) {}
update_recordupdate_record148             update_record(enode* n, lbl_hash):
149                 tag(tag_t::is_lbl_hash), r1(n), n1(nullptr), m_lbl_hash(n->m_lbl_hash) {}
update_recordupdate_record150             update_record(enode* n, lbl_set):
151                 tag(tag_t::is_lbl_set), r1(n), n1(nullptr), m_lbls(n->m_lbls.get()) {}
update_recordupdate_record152             update_record(enode* n, update_children) :
153                 tag(tag_t::is_update_children), r1(n), n1(nullptr), r2_num_parents(UINT_MAX) {}
154         };
155         ast_manager&           m;
156         svector<to_merge>      m_to_merge;
157         etable                 m_table;
158         region                 m_region;
159         svector<update_record> m_updates;
160         unsigned_vector        m_scopes;
161         enode_vector           m_expr2enode;
162         enode*                 m_tmp_eq = nullptr;
163         enode*                 m_tmp_node = nullptr;
164         unsigned               m_tmp_node_capacity = 0;
165         tmp_app                m_tmp_app;
166         enode_vector           m_nodes;
167         expr_ref_vector        m_exprs;
168         func_decl_ref_vector   m_eq_decls;
169         vector<enode_vector>   m_decl2enodes;
170         enode_vector           m_empty_enodes;
171         unsigned               m_num_scopes = 0;
172         bool                   m_inconsistent = false;
173         enode                  *m_n1 = nullptr;
174         enode                  *m_n2 = nullptr;
175         justification          m_justification;
176         unsigned               m_new_lits_qhead = 0;
177         unsigned               m_new_th_eqs_qhead = 0;
178         svector<enode_bool_pair>  m_new_lits;
179         svector<th_eq>         m_new_th_eqs;
180         bool_vector            m_th_propagates_diseqs;
181         enode_vector           m_todo;
182         stats                  m_stats;
183         bool                   m_uses_congruence = false;
184         std::function<void(enode*,enode*)>     m_on_merge;
185         std::function<void(enode*)>            m_on_make;
186         std::function<void(expr*,expr*,expr*)> m_used_eq;
187         std::function<void(app*,app*)>         m_used_cc;
188         std::function<void(std::ostream&, void*)>   m_display_justification;
189 
push_eq(enode * r1,enode * n1,unsigned r2_num_parents)190         void push_eq(enode* r1, enode* n1, unsigned r2_num_parents) {
191             m_updates.push_back(update_record(r1, n1, r2_num_parents));
192         }
push_node(enode * n)193         void push_node(enode* n) { m_updates.push_back(update_record(n)); }
194 
195         void add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r);
196 
197         void add_th_diseqs(theory_id id, theory_var v1, enode* r);
198         bool th_propagates_diseqs(theory_id id) const;
199         void add_literal(enode* n, bool is_eq);
200         void undo_eq(enode* r1, enode* n1, unsigned r2_num_parents);
201         void undo_add_th_var(enode* n, theory_id id);
202         enode* mk_enode(expr* f, unsigned generation, unsigned num_args, enode * const* args);
203         void force_push();
204         void set_conflict(enode* n1, enode* n2, justification j);
205         void merge(enode* n1, enode* n2, justification j);
206         void merge_th_eq(enode* n, enode* root);
207         void merge_justification(enode* n1, enode* n2, justification j);
208         void reinsert_parents(enode* r1, enode* r2);
209         void remove_parents(enode* r1, enode* r2);
210         void unmerge_justification(enode* n1);
211         void reinsert_equality(enode* p);
212         void update_children(enode* n);
213         void push_lca(enode* a, enode* b);
214         enode* find_lca(enode* a, enode* b);
215         void push_to_lca(enode* a, enode* lca);
216         void push_congruence(enode* n1, enode* n2, bool commutative);
217         void push_todo(enode* n);
218         void toggle_merge_enabled(enode* n, bool backtracking);
219 
220         enode_bool_pair insert_table(enode* p);
221         void erase_from_table(enode* p);
222 
223         template <typename T>
explain_eq(ptr_vector<T> & justifications,enode * a,enode * b,justification const & j)224         void explain_eq(ptr_vector<T>& justifications, enode* a, enode* b, justification const& j) {
225             if (j.is_external())
226                 justifications.push_back(j.ext<T>());
227             else if (j.is_congruence())
228                 push_congruence(a, b, j.is_commutative());
229         }
230         template <typename T>
231         void explain_todo(ptr_vector<T>& justifications);
232 
233         std::ostream& display(std::ostream& out, unsigned max_args, enode* n) const;
234 
235     public:
236         egraph(ast_manager& m);
237         ~egraph();
find(expr * f)238         enode* find(expr* f) const { return m_expr2enode.get(f->get_id(), nullptr); }
239         enode* find(expr* f, unsigned n, enode* const* args);
240         enode* mk(expr* f, unsigned generation, unsigned n, enode *const* args);
241         enode_vector const& enodes_of(func_decl* f);
push()242         void push() { if (!m_to_merge.empty()) propagate(); ++m_num_scopes; }
243         void pop(unsigned num_scopes);
244 
245         /**
246            \brief merge nodes, all effects are deferred to the propagation step.
247          */
merge(enode * n1,enode * n2,void * reason)248         void merge(enode* n1, enode* n2, void* reason) { merge(n1, n2, justification::external(reason)); }
249         void new_diseq(enode* n);
250 
251 
252         /**
253            \brief propagate set of merges.
254            This call may detect an inconsistency. Then inconsistent() is true.
255            Use then explain() to extract an explanation for the conflict.
256 
257            It may also infer new implied equalities, when the roots of the
258            equated nodes are merged. Use then new_eqs() to extract the vector
259            of new equalities.
260         */
261         bool propagate();
inconsistent()262         bool inconsistent() const { return m_inconsistent; }
263 
264         /**
265         * \brief check if two nodes are known to be disequal.
266         */
267         bool are_diseq(enode* a, enode* b);
268 
269         enode* get_enode_eq_to(func_decl* f, unsigned num_args, enode* const* args);
270 
271         enode* tmp_eq(enode* a, enode* b);
272 
273         /**
274            \brief Maintain and update cursor into propagated consequences.
275            The result of get_literal() is a pair (n, is_eq)
276            where \c n is an enode and \c is_eq indicates whether the enode
277            is an equality consequence.
278          */
279         void       add_th_diseq(theory_id id, theory_var v1, theory_var v2, expr* eq);
has_literal()280         bool       has_literal() const { return m_new_lits_qhead < m_new_lits.size(); }
has_th_eq()281         bool       has_th_eq() const { return m_new_th_eqs_qhead < m_new_th_eqs.size(); }
get_literal()282         enode_bool_pair get_literal() const { return m_new_lits[m_new_lits_qhead]; }
get_th_eq()283         th_eq      get_th_eq() const { return m_new_th_eqs[m_new_th_eqs_qhead]; }
next_literal()284         void       next_literal() { force_push();  SASSERT(m_new_lits_qhead < m_new_lits.size()); m_new_lits_qhead++; }
next_th_eq()285         void       next_th_eq() { force_push(); SASSERT(m_new_th_eqs_qhead < m_new_th_eqs.size()); m_new_th_eqs_qhead++; }
286 
287         void set_lbl_hash(enode* n);
288 
289 
290         void add_th_var(enode* n, theory_var v, theory_id id);
291         void set_th_propagates_diseqs(theory_id id);
292         void set_merge_enabled(enode* n, bool enable_merge);
293         void set_value(enode* n, lbool value);
set_bool_var(enode * n,unsigned v)294         void set_bool_var(enode* n, unsigned v) { n->set_bool_var(v); }
295 
set_on_merge(std::function<void (enode * root,enode * other)> & on_merge)296         void set_on_merge(std::function<void(enode* root,enode* other)>& on_merge) { m_on_merge = on_merge; }
set_on_make(std::function<void (enode * n)> & on_make)297         void set_on_make(std::function<void(enode* n)>& on_make) { m_on_make = on_make; }
set_used_eq(std::function<void (expr *,expr *,expr *)> & used_eq)298         void set_used_eq(std::function<void(expr*,expr*,expr*)>& used_eq) { m_used_eq = used_eq; }
set_used_cc(std::function<void (app *,app *)> & used_cc)299         void set_used_cc(std::function<void(app*,app*)>& used_cc) { m_used_cc = used_cc; }
set_display_justification(std::function<void (std::ostream &,void *)> & d)300         void set_display_justification(std::function<void (std::ostream&, void*)> & d) { m_display_justification = d; }
301 
302         void begin_explain();
303         void end_explain();
uses_congruence()304         bool uses_congruence() const { return m_uses_congruence; }
305         template <typename T>
306         void explain(ptr_vector<T>& justifications);
307         template <typename T>
308         void explain_eq(ptr_vector<T>& justifications, enode* a, enode* b);
309         template <typename T>
310         unsigned explain_diseq(ptr_vector<T>& justifications, enode* a, enode* b);
nodes()311         enode_vector const& nodes() const { return m_nodes; }
312 
get_manager()313         ast_manager& get_manager() { return m; }
314 
315         void invariant();
316         void copy_from(egraph const& src, std::function<void*(void*)>& copy_justification);
317         struct e_pp {
318             egraph const& g;
319             enode*  n;
e_ppe_pp320             e_pp(egraph const& g, enode* n) : g(g), n(n) {}
displaye_pp321             std::ostream& display(std::ostream& out) const { return g.display(out, 0, n); }
322         };
pp(enode * n)323         e_pp pp(enode* n) const { return e_pp(*this, n); }
324         struct b_pp {
325             egraph const& g;
326             enode* n;
b_ppb_pp327             b_pp(egraph const& g, enode* n) : g(g), n(n) {}
displayb_pp328             std::ostream& display(std::ostream& out) const { return n ? (out << n->get_expr_id() << ": " << mk_bounded_pp(n->get_expr(), g.m)) : out << "null"; }
329         };
bpp(enode * n)330         b_pp bpp(enode* n) const { return b_pp(*this, n); }
331         std::ostream& display(std::ostream& out) const;
332 
333         void collect_statistics(statistics& st) const;
334 
num_scopes()335         unsigned num_scopes() const { return m_scopes.size() + m_num_scopes; }
num_nodes()336         unsigned num_nodes() const { return m_nodes.size(); }
337     };
338 
339     inline std::ostream& operator<<(std::ostream& out, egraph const& g) { return g.display(out); }
340     inline std::ostream& operator<<(std::ostream& out, egraph::e_pp const& p) { return p.display(out); }
341     inline std::ostream& operator<<(std::ostream& out, egraph::b_pp const& p) { return p.display(out); }
342 }
343