1 /**++
2 Copyright (c) Arie Gurfinkel
3 
4 Module Name:
5 
6     qe_term_graph.cpp
7 
8 Abstract:
9 
10     Equivalence graph of terms
11 
12 Author:
13 
14     Arie Gurfinkel
15 
16 Notes:
17 
18 --*/
19 
20 #include "util/util.h"
21 #include "util/uint_set.h"
22 #include "util/obj_pair_hashtable.h"
23 #include "ast/ast_pp.h"
24 #include "ast/ast_util.h"
25 #include "ast/for_each_expr.h"
26 #include "ast/occurs.h"
27 #include "model/model_evaluator.h"
28 #include "qe/mbp/mbp_term_graph.h"
29 
30 namespace mbp {
31 
mk_neq(ast_manager & m,expr * e1,expr * e2)32     static expr_ref mk_neq(ast_manager &m, expr *e1, expr *e2) {
33         expr *t = nullptr;
34         // x != !x  == true
35         if ((m.is_not(e1, t) && t == e2) || (m.is_not(e2, t) && t == e1))
36             return expr_ref(m.mk_true(), m);
37         else if (m.are_distinct(e1, e2))
38             return expr_ref(m.mk_true(), m);
39         return expr_ref(m.mk_not(m.mk_eq(e1, e2)), m);
40     }
41 
42     namespace {
43         struct sort_lt_proc {
operator ()mbp::__anonf21a55300111::sort_lt_proc44             bool operator()(const expr* a, const expr *b) const {
45                 return a->get_sort()->get_id() < b->get_sort()->get_id();
46             }
47         };
48     }
49 
50     namespace is_pure_ns {
51         struct found{};
52         struct proc {
53             is_variable_proc &m_is_var;
procmbp::is_pure_ns::proc54             proc(is_variable_proc &is_var) : m_is_var(is_var) {}
operator ()mbp::is_pure_ns::proc55             void operator()(var *n) const {if (m_is_var(n)) throw found();}
operator ()mbp::is_pure_ns::proc56             void operator()(app const *n) const {if (m_is_var(n)) throw found();}
operator ()mbp::is_pure_ns::proc57             void operator()(quantifier *n) const {}
58         };
59     }
60 
is_pure(is_variable_proc & is_var,expr * e)61     bool is_pure(is_variable_proc &is_var, expr *e) {
62         try {
63             is_pure_ns::proc v(is_var);
64             quick_for_each_expr(v, e);
65         }
66         catch (const is_pure_ns::found &) {
67             return false;
68         }
69         return true;
70     }
71 
    // A node of the term graph. Each term wraps one expression and carries the
    // bookkeeping of its equivalence class: a pointer to the class root, a
    // cyclic list of class members (m_next), and the class size (maintained on
    // the root; reset to 0 on a root that has been absorbed or replaced).
    class term {
        // -- an app represented by this term
        expr_ref m_expr; // NSB: to make usable with exprs
        // -- root of the equivalence class
        term* m_root;
        // -- next element in the equivalence class (cyclic linked list)
        term* m_next;
        // -- eq class size
        unsigned m_class_size;
        // -- general purpose mark
        unsigned m_mark:1;
        // -- general purpose second mark
        unsigned m_mark2:1;
        // -- is an interpreted constant
        unsigned m_interpreted:1;

        // -- terms that contain this term as a child
        // -- (new parents are registered on the root of the child's class)
        ptr_vector<term> m_parents;

        // arguments of term.
        ptr_vector<term> m_children;

    public:
        // Create a singleton class for `v`. For an application, every argument
        // is looked up in `app2term` (so arguments must have been internalized
        // before), recorded as a child, and this term is registered as a
        // parent of the argument's current class root.
        // NOTE(review): relies on app2term[...] finding an existing entry.
        term(expr_ref const& v, u_map<term*>& app2term) :
            m_expr(v),
            m_root(this),
            m_next(this),
            m_class_size(1),
            m_mark(false),
            m_mark2(false),
            m_interpreted(false) {
            if (!is_app(m_expr)) return;
            for (expr* e : *to_app(m_expr)) {
                term* t = app2term[e->get_id()];
                t->get_root().m_parents.push_back(this);
                m_children.push_back(t);
            }
        }

        ~term() {}

        // Range adapter: iterate over the terms registered as parents of `t`.
        class parents {
            term const& t;
        public:
            parents(term const& _t):t(_t) {}
            parents(term const* _t):t(*_t) {}
            ptr_vector<term>::const_iterator begin() const { return t.m_parents.begin(); }
            ptr_vector<term>::const_iterator end() const { return t.m_parents.end(); }
        };

        // Range adapter: iterate over the argument terms of `t`, in order.
        class children {
            term const& t;
        public:
            children(term const& _t):t(_t) {}
            children(term const* _t):t(*_t) {}
            ptr_vector<term>::const_iterator begin() const { return t.m_children.begin(); }
            ptr_vector<term>::const_iterator end() const { return t.m_children.end(); }
        };

        // Congruence table hash function is based on
        // roots of children and function declaration.
        // It must stay consistent with cg_eq: congruent terms hash equally.

        unsigned get_hash() const {
            unsigned a, b, c;
            a = b = c = get_decl_id();
            for (term * ch : children(this)) {
                a = ch->get_root().get_id();
                mix(a, b, c);
            }
            return c;
        }

        // Congruence test: same declaration id, same number of children, and
        // pairwise children in the same equivalence class (same root id).
        static bool cg_eq(term const * t1, term const * t2) {
            if (t1->get_decl_id() != t2->get_decl_id()) return false;
            if (t1->m_children.size() != t2->m_children.size()) return false;
            for (unsigned i = 0, sz = t1->m_children.size(); i < sz; ++ i) {
                if (t1->m_children[i]->get_root().get_id() != t2->m_children[i]->get_root().get_id()) return false;
            }
            return true;
        }

        // Id of the underlying expression (unique per expression).
        unsigned get_id() const { return m_expr->get_id();}

        // Declaration id for applications; for non-applications the
        // expression id stands in for it.
        unsigned get_decl_id() const { return is_app(m_expr) ? to_app(m_expr)->get_decl()->get_id() : m_expr->get_id(); }

        bool is_marked() const {return m_mark;}
        void set_mark(bool v){m_mark = v;}
        bool is_marked2() const {return m_mark2;} // NSB: where is this used?
        void set_mark2(bool v){m_mark2 = v;}      // NSB: where is this used?

        bool is_interpreted() const {return m_interpreted;}
        // True for non-applications and for applications of an interpreted
        // (non-null) theory family.
        bool is_theory() const { return !is_app(m_expr) || to_app(m_expr)->get_family_id() != null_family_id; }
        void mark_as_interpreted() {m_interpreted=true;}
        expr* get_expr() const {return m_expr;}
        // Number of arguments; 0 for non-applications.
        unsigned get_num_args() const { return is_app(m_expr) ? to_app(m_expr)->get_num_args() : 0; }

        term &get_root() const {return *m_root;}
        bool is_root() const {return m_root == this;}
        void set_root(term &r) {m_root = &r;}
        term &get_next() const {return *m_next;}
        void add_parent(term* p) { m_parents.push_back(p); }

        unsigned get_class_size() const {return m_class_size;}

        // Splice the cyclic member list of `b` into this one (swapping the
        // two next pointers) and move b's size here. Root pointers are not
        // touched; the caller is responsible for re-rooting b's members.
        void merge_eq_class(term &b) {
            std::swap(this->m_next, b.m_next);
            m_class_size += b.get_class_size();
            // -- reset (useful for debugging)
            b.m_class_size = 0;
        }

        // -- make this term the root of its equivalence class
        // -- walks the cyclic member list once, re-pointing every member's
        // -- root here and taking over the class size from the previous root.
        // -- NOTE(review): m_parents is not moved from the previous root.
        void mk_root() {
            if (is_root()) return;

            term *curr = this;
            do {
                if (curr->is_root()) {
                    // found previous root
                    SASSERT(curr != this);
                    m_class_size = curr->get_class_size();
                    curr->m_class_size = 0;
                }
                curr->set_root(*this);
                curr = &curr->get_next();
            }
            while (curr != this);
        }

        // Print "id: expr [R] - " followed by the ids of the other class
        // members, terminated by a newline ("R" marks a root).
        std::ostream& display(std::ostream& out) const {
            out << get_id() << ": " << m_expr
                << (is_root() ? " R" : "") << " - ";
            term const* r = &this->get_next();
            while (r != this) {
                out << r->get_id() << " ";
                r = &r->get_next();
            }
            out << "\n";
            return out;
        }
    };
213 
operator <<(std::ostream & out,term const & t)214     static std::ostream& operator<<(std::ostream& out, term const& t) {
215         return t.display(out);
216     }
217 
operator ()(const expr * e) const218     bool term_graph::is_variable_proc::operator()(const expr * e) const {
219         if (!is_app(e)) return false;
220         const app *a = ::to_app(e);
221         TRACE("qe_verbose", tout << a->get_family_id() << " " << m_solved.contains(a->get_decl()) << " " << m_decls.contains(a->get_decl()) << "\n";);
222         return
223             a->get_family_id() == null_family_id &&
224             !m_solved.contains(a->get_decl()) &&
225             m_exclude == m_decls.contains(a->get_decl());
226     }
227 
operator ()(const term & t) const228     bool term_graph::is_variable_proc::operator()(const term &t) const {
229         return (*this)(t.get_expr());
230     }
231 
set_decls(const func_decl_ref_vector & decls,bool exclude)232     void term_graph::is_variable_proc::set_decls(const func_decl_ref_vector &decls, bool exclude) {
233         reset();
234         m_exclude = exclude;
235         for (auto *d : decls) m_decls.insert(d);
236     }
mark_solved(const expr * e)237     void term_graph::is_variable_proc::mark_solved(const expr *e) {
238         if ((*this)(e) && is_app(e))
239             m_solved.insert(::to_app(e)->get_decl());
240     }
241 
242 
operator ()(term const * t) const243     unsigned term_graph::term_hash::operator()(term const* t) const { return t->get_hash(); }
244 
operator ()(term const * a,term const * b) const245     bool term_graph::term_eq::operator()(term const* a, term const* b) const { return term::cg_eq(a, b); }
246 
term_graph(ast_manager & man)247     term_graph::term_graph(ast_manager &man) : m(man), m_lits(m), m_pinned(m), m_projector(nullptr) {
248         m_plugins.register_plugin(mbp::mk_basic_solve_plugin(m, m_is_var));
249         m_plugins.register_plugin(mbp::mk_arith_solve_plugin(m, m_is_var));
250     }
251 
~term_graph()252     term_graph::~term_graph() {
253         dealloc(m_projector);
254         reset();
255     }
256 
is_pure_def(expr * atom,expr * & v)257     bool term_graph::is_pure_def(expr *atom, expr*& v) {
258         expr *e = nullptr;
259         return m.is_eq(atom, v, e) && m_is_var(v) && is_pure(m_is_var, e);
260     }
261 
get_family_id(ast_manager & m,expr * lit)262     static family_id get_family_id(ast_manager &m, expr *lit) {
263         if (m.is_not(lit, lit))
264             return get_family_id(m, lit);
265 
266         expr *a = nullptr, *b = nullptr;
267         // deal with equality using sort of range
268         if (m.is_eq (lit, a, b)) {
269             return a->get_sort()->get_family_id();
270         }
271         // extract family_id of top level app
272         else if (is_app(lit)) {
273             return to_app(lit)->get_decl()->get_family_id();
274         }
275         else {
276             return null_family_id;
277         }
278     }
add_lit(expr * l)279     void term_graph::add_lit(expr *l) {
280         expr_ref lit(m);
281         expr_ref_vector lits(m);
282         lits.push_back(l);
283         for (unsigned i = 0; i < lits.size(); ++i) {
284             l = lits.get(i);
285             family_id fid = get_family_id(m, l);
286             mbp::solve_plugin *pin = m_plugins.get_plugin(fid);
287             lit = pin ? (*pin)(l) : l;
288             if (m.is_and(lit)) {
289                 lits.append(::to_app(lit)->get_num_args(), ::to_app(lit)->get_args());
290             }
291             else {
292                 m_lits.push_back(lit);
293                 internalize_lit(lit);
294             }
295         }
296     }
297 
is_internalized(expr * a)298     bool term_graph::is_internalized(expr *a) {
299         return m_app2term.contains(a->get_id());
300     }
301 
get_term(expr * a)302     term* term_graph::get_term(expr *a) {
303         term *res;
304         return m_app2term.find (a->get_id(), res) ? res : nullptr;
305     }
306 
mk_term(expr * a)307     term *term_graph::mk_term(expr *a) {
308         expr_ref e(a, m);
309         term * t = alloc(term, e, m_app2term);
310         if (t->get_num_args() == 0 && m.is_unique_value(a)){
311             t->mark_as_interpreted();
312         }
313 
314         m_terms.push_back(t);
315         m_app2term.insert(a->get_id(), t);
316         return t;
317     }
318 
internalize_term(expr * t)319     term* term_graph::internalize_term(expr *t) {
320         term* res = get_term(t);
321         if (res) return res;
322         ptr_buffer<expr> todo;
323         todo.push_back(t);
324         while (!todo.empty()) {
325             t = todo.back();
326             res = get_term(t);
327             if (res) {
328                 todo.pop_back();
329                 continue;
330             }
331             unsigned sz = todo.size();
332             if (is_app(t)) {
333                 for (expr * arg : *::to_app(t)) {
334                     if (!get_term(arg))
335                         todo.push_back(arg);
336                 }
337             }
338             if (sz < todo.size()) continue;
339             todo.pop_back();
340             res = mk_term(t);
341         }
342         SASSERT(res);
343         return res;
344     }
345 
internalize_eq(expr * a1,expr * a2)346     void term_graph::internalize_eq(expr *a1, expr* a2) {
347         SASSERT(m_merge.empty());
348         merge(*internalize_term(a1), *internalize_term(a2));
349         merge_flush();
350         SASSERT(m_merge.empty());
351     }
352 
internalize_lit(expr * lit)353     void term_graph::internalize_lit(expr* lit) {
354         expr *e1 = nullptr, *e2 = nullptr, *v = nullptr;
355         if (m.is_eq (lit, e1, e2)) {
356             internalize_eq (e1, e2);
357         }
358         else {
359             internalize_term(lit);
360         }
361         if (is_pure_def(lit, v)) {
362             m_is_var.mark_solved(v);
363         }
364     }
365 
merge_flush()366     void term_graph::merge_flush() {
367         while (!m_merge.empty()) {
368             term* t1 = m_merge.back().first;
369             term* t2 = m_merge.back().second;
370             m_merge.pop_back();
371             merge(*t1, *t2);
372         }
373     }
374 
merge(term & t1,term & t2)375     void term_graph::merge(term &t1, term &t2) {
376         term *a = &t1.get_root();
377         term *b = &t2.get_root();
378 
379         if (a == b) return;
380 
381         // -- merge might invalidate term2app cache
382         m_term2app.reset();
383         m_pinned.reset();
384 
385         if (a->get_class_size() > b->get_class_size()) {
386             std::swap(a, b);
387         }
388 
389         // Remove parents of b from the cg table.
390         for (term* p : term::parents(b)) {
391             if (!p->is_marked()) {
392                 p->set_mark(true);
393                 m_cg_table.erase(p);
394             }
395         }
396         // make 'a' be the root of the equivalence class of 'b'
397         b->set_root(*a);
398         for (term *it = &b->get_next(); it != b; it = &it->get_next()) {
399             it->set_root(*a);
400         }
401 
402         // merge equivalence classes
403         a->merge_eq_class(*b);
404 
405         // Insert parents of b's old equilvalence class into the cg table
406         for (term* p : term::parents(b)) {
407             if (p->is_marked()) {
408                 term* p_old = m_cg_table.insert_if_not_there(p);
409                 p->set_mark(false);
410                 a->add_parent(p);
411                 // propagate new equalities.
412                 if (p->get_root().get_id() != p_old->get_root().get_id()) {
413                     m_merge.push_back(std::make_pair(p, p_old));
414                 }
415             }
416         }
417         SASSERT(marks_are_clear());
418     }
419 
mk_app_core(expr * e)420     expr* term_graph::mk_app_core (expr *e) {
421         if (is_app(e)) {
422             expr_ref_buffer kids(m);
423             app* a = ::to_app(e);
424             for (expr * arg : *a) {
425                 kids.push_back (mk_app(arg));
426             }
427             app* res = m.mk_app(a->get_decl(), a->get_num_args(), kids.data());
428             m_pinned.push_back(res);
429             return res;
430         }
431         else {
432             return e;
433         }
434     }
435 
mk_app(term const & r)436     expr_ref term_graph::mk_app(term const &r) {
437         SASSERT(r.is_root());
438 
439         if (r.get_num_args() == 0) {
440             return expr_ref(r.get_expr(), m);
441         }
442 
443         expr* res = nullptr;
444         if (m_term2app.find(r.get_id(), res)) {
445             return expr_ref(res, m);
446         }
447 
448         res = mk_app_core (r.get_expr());
449         m_term2app.insert(r.get_id(), res);
450         return expr_ref(res, m);
451 
452     }
453 
mk_app(expr * a)454     expr_ref term_graph::mk_app(expr *a) {
455         term *t = get_term(a);
456         if (!t)
457             return expr_ref(a, m);
458         else
459             return mk_app(t->get_root());
460 
461     }
462 
mk_equalities(term const & t,expr_ref_vector & out)463     void term_graph::mk_equalities(term const &t, expr_ref_vector &out) {
464         SASSERT(t.is_root());
465         expr_ref rep(mk_app(t), m);
466         for (term *it = &t.get_next(); it != &t; it = &it->get_next()) {
467             expr* mem = mk_app_core(it->get_expr());
468             out.push_back (m.mk_eq (rep, mem));
469         }
470     }
471 
mk_all_equalities(term const & t,expr_ref_vector & out)472     void term_graph::mk_all_equalities(term const &t, expr_ref_vector &out) {
473         mk_equalities(t, out);
474 
475         for (term *it = &t.get_next(); it != &t; it = &it->get_next ()) {
476             expr* a1 = mk_app_core (it->get_expr());
477             for (term *it2 = &it->get_next(); it2 != &t; it2 = &it2->get_next()) {
478                 expr* a2 =  mk_app_core(it2->get_expr());
479                 out.push_back (m.mk_eq (a1, a2));
480             }
481         }
482     }
483 
reset_marks()484     void term_graph::reset_marks() {
485         for (term * t : m_terms) {
486             t->set_mark(false);
487         }
488     }
489 
marks_are_clear()490     bool term_graph::marks_are_clear() {
491         for (term * t : m_terms) {
492             if (t->is_marked()) return false;
493         }
494         return true;
495     }
496 
497     /// Order of preference for roots of equivalence classes
498     /// XXX This should be factored out to let clients control the preference
term_lt(term const & t1,term const & t2)499     bool term_graph::term_lt(term const &t1, term const &t2) {
500         // prefer constants over applications
501         // prefer uninterpreted constants over values
502         // prefer smaller expressions over larger ones
503         if (t1.get_num_args() == 0 || t2.get_num_args() == 0) {
504             if (t1.get_num_args() == t2.get_num_args()) {
505                 // t1.get_num_args() == t2.get_num_args() == 0
506                 if (m.is_value(t1.get_expr()) == m.is_value(t2.get_expr()))
507                     return t1.get_id() < t2.get_id();
508                 return m.is_value(t2.get_expr());
509             }
510             return t1.get_num_args() < t2.get_num_args();
511         }
512 
513         unsigned sz1 = get_num_exprs(t1.get_expr());
514         unsigned sz2 = get_num_exprs(t2.get_expr());
515         return sz1 < sz2;
516     }
517 
pick_root(term & t)518     void term_graph::pick_root (term &t) {
519         term *r = &t;
520         for (term *it = &t.get_next(); it != &t; it = &it->get_next()) {
521             it->set_mark(true);
522             if (term_lt(*it, *r)) { r = it; }
523         }
524 
525         // -- if found something better, make it the new root
526         if (r != &t) {
527             r->mk_root();
528         }
529     }
530 
531     /// Choose better roots for equivalence classes
pick_roots()532     void term_graph::pick_roots() {
533         SASSERT(marks_are_clear());
534         for (term* t : m_terms) {
535             if (!t->is_marked() && t->is_root())
536                 pick_root(*t);
537         }
538         reset_marks();
539     }
540 
display(std::ostream & out)541     void term_graph::display(std::ostream &out) {
542         for (term * t : m_terms) {
543             out << *t;
544         }
545     }
546 
to_lits(expr_ref_vector & lits,bool all_equalities)547     void term_graph::to_lits (expr_ref_vector &lits, bool all_equalities) {
548         pick_roots();
549 
550         for (expr * a : m_lits) {
551             if (is_internalized(a)) {
552                 lits.push_back (::to_app(mk_app(a)));
553             }
554         }
555 
556         for (term * t : m_terms) {
557             if (!t->is_root())
558                 continue;
559             else if (all_equalities)
560                 mk_all_equalities (*t, lits);
561             else
562                 mk_equalities(*t, lits);
563         }
564     }
565 
to_expr()566     expr_ref term_graph::to_expr() {
567         expr_ref_vector lits(m);
568         to_lits(lits);
569         return mk_and(lits);
570     }
571 
reset()572     void term_graph::reset() {
573         m_term2app.reset();
574         m_pinned.reset();
575         m_app2term.reset();
576         std::for_each(m_terms.begin(), m_terms.end(), delete_proc<term>());
577         m_terms.reset();
578         m_lits.reset();
579         m_cg_table.reset();
580     }
581 
582     class term_graph::projector {
583         term_graph &m_tg;
584         ast_manager &m;
585         u_map<expr*> m_term2app;
586         u_map<expr*> m_root2rep;
587 
588         model_ref m_model;
589         expr_ref_vector m_pinned;  // tracks expr in the maps
590 
mk_pure(term const & t)591         expr* mk_pure(term const& t) {
592             TRACE("qe", t.display(tout););
593             expr* e = nullptr;
594             if (find_term2app(t, e)) return e;
595             e = t.get_expr();
596             if (!is_app(e)) return nullptr;
597             app* a = ::to_app(e);
598             expr_ref_buffer kids(m);
599             for (term* ch : term::children(t)) {
600                 // prefer a node that resembles current child,
601                 // otherwise, pick a root representative, if present.
602                 if (find_term2app(*ch, e)) {
603                     kids.push_back(e);
604                 }
605                 else if (m_root2rep.find(ch->get_root().get_id(), e)) {
606                     kids.push_back(e);
607                 }
608                 else {
609                     return nullptr;
610                 }
611                 TRACE("qe_verbose", tout << *ch << " -> " << mk_pp(e, m) << "\n";);
612             }
613             expr* pure = m.mk_app(a->get_decl(), kids.size(), kids.data());
614             m_pinned.push_back(pure);
615             add_term2app(t, pure);
616             return pure;
617         }
618 
619 
is_better_rep(expr * t1,expr * t2)620         bool is_better_rep(expr *t1, expr *t2) {
621             if (!t2) return t1 != nullptr;
622             return m.is_unique_value(t1) && !m.is_unique_value(t2);
623         }
624 
625         struct term_depth {
operator ()mbp::term_graph::projector::term_depth626             bool operator()(term const* t1, term const* t2) const {
627                 return get_depth(t1->get_expr()) < get_depth(t2->get_expr());
628             }
629         };
630 
631 
solve_core()632         void solve_core() {
633             ptr_vector<term> worklist;
634             for (term * t : m_tg.m_terms) {
635                 // skip pure terms
636                 if (!in_term2app(*t)) {
637                     worklist.push_back(t);
638                     t->set_mark(true);
639                 }
640             }
641             term_depth td;
642             std::sort(worklist.begin(), worklist.end(), td);
643 
644             for (unsigned i = 0; i < worklist.size(); ++i) {
645                 term* t = worklist[i];
646                 t->set_mark(false);
647                 if (in_term2app(*t))
648                     continue;
649 
650                 expr* pure = mk_pure(*t);
651                 if (!pure)
652                     continue;
653 
654                 add_term2app(*t, pure);
655                 expr* rep = nullptr;
656                 // ensure that the root has a representative
657                 m_root2rep.find(t->get_root().get_id(), rep);
658 
659                 if (!rep) {
660                     m_root2rep.insert(t->get_root().get_id(), pure);
661                     for (term * p : term::parents(t->get_root())) {
662                         SASSERT(!in_term2app(*p));
663                         if (!p->is_marked()) {
664                             p->set_mark(true);
665                             worklist.push_back(p);
666                         }
667                     }
668                 }
669             }
670             m_tg.reset_marks();
671         }
672 
find_app(term & t,expr * & res)673         bool find_app(term &t, expr *&res) {
674             return
675                 find_term2app(t, res) ||
676                 m_root2rep.find(t.get_root().get_id(), res);
677         }
678 
find_app(expr * lit,expr * & res)679         bool find_app(expr *lit, expr *&res) {
680             term const* t = m_tg.get_term(lit);
681             return
682                 find_term2app(*t, res) ||
683                 m_root2rep.find(t->get_root().get_id(), res);
684         }
685 
mk_lits(expr_ref_vector & res)686         void mk_lits(expr_ref_vector &res) {
687             expr *e = nullptr;
688             for (auto *lit : m_tg.m_lits) {
689                 if (!m.is_eq(lit) && find_app(lit, e))
690                     res.push_back(e);
691             }
692             TRACE("qe", tout << "literals: " << res << "\n";);
693         }
694 
lits2pure(expr_ref_vector & res)695         void lits2pure(expr_ref_vector& res) {
696             expr *e1 = nullptr, *e2 = nullptr, *p1 = nullptr, *p2 = nullptr;
697             for (auto *lit : m_tg.m_lits) {
698                 if (m.is_eq(lit, e1, e2)) {
699                     if (find_app(e1, p1) && find_app(e2, p2)) {
700                         if (p1 != p2)
701                             res.push_back(m.mk_eq(p1, p2));
702                     }
703                     else {
704                         TRACE("qe", tout << "skipping " << mk_pp(lit, m) << "\n";);
705                     }
706                 }
707                 else if (m.is_distinct(lit)) {
708                     ptr_buffer<expr> diff;
709                     for (expr* arg : *to_app(lit)) {
710                         if (find_app(arg, p1)) {
711                             diff.push_back(p1);
712                         }
713                     }
714                     if (diff.size() > 1) {
715                         res.push_back(m.mk_distinct(diff.size(), diff.data()));
716                     }
717                     else {
718                         TRACE("qe", tout << "skipping " << mk_pp(lit, m) << "\n";);
719                     }
720                 }
721                 else if (find_app(lit, p1)) {
722                     res.push_back(p1);
723                 }
724                 else {
725                     TRACE("qe", tout << "skipping " << mk_pp(lit, m) << "\n";);
726                 }
727             }
728             remove_duplicates(res);
729             TRACE("qe", tout << "literals: " << res << "\n";);
730         }
731 
remove_duplicates(expr_ref_vector & v)732         void remove_duplicates(expr_ref_vector& v) {
733             obj_hashtable<expr> seen;
734             unsigned j = 0;
735             for (expr* e : v) {
736                 if (!seen.contains(e)) {
737                     v[j++] = e;
738                     seen.insert(e);
739                 }
740             }
741             v.shrink(j);
742         }
743 
744         vector<ptr_vector<term>> m_decl2terms; // terms that use function f
745         ptr_vector<func_decl>    m_decls;
746 
collect_decl2terms()747         void collect_decl2terms() {
748             // Collect the projected function symbols.
749             m_decl2terms.reset();
750             m_decls.reset();
751             for (term *t : m_tg.m_terms) {
752                 expr* e = t->get_expr();
753                 if (!is_app(e)) continue;
754                 if (!is_projected(*t)) continue;
755                 app* a = to_app(e);
756                 func_decl* d = a->get_decl();
757                 if (d->get_arity() == 0) continue;
758                 unsigned id = d->get_decl_id();
759                 m_decl2terms.reserve(id+1);
760                 if (m_decl2terms[id].empty()) m_decls.push_back(d);
761                 m_decl2terms[id].push_back(t);
762             }
763         }
764 
args_are_distinct(expr_ref_vector & res)765         void args_are_distinct(expr_ref_vector& res) {
766             //
767             // for each projected function that occurs
768             // (may occur) in multiple congruence classes,
769             // produce assertions that non-congruent arguments
770             // are distinct.
771             //
772             for (func_decl* d : m_decls) {
773                 unsigned id = d->get_decl_id();
774                 ptr_vector<term> const& terms = m_decl2terms[id];
775                 if (terms.size() <= 1) continue;
776                 unsigned arity = d->get_arity();
777                 for (unsigned i = 0; i < arity; ++i) {
778                     obj_hashtable<expr> roots, root_vals;
779                     expr_ref_vector pinned(m);
780                     for (term* t : terms) {
781                         expr* arg = to_app(t->get_expr())->get_arg(i);
782                         term const& root = m_tg.get_term(arg)->get_root();
783                         expr* r = root.get_expr();
784                         // if a model is given, then use the equivalence class induced
785                         // by the model. Otherwise, use the congruence class.
786                         if (m_model) {
787                             expr_ref tmp(m);
788                             tmp = (*m_model)(r);
789                             if (!root_vals.contains(tmp)) {
790                                 root_vals.insert(tmp);
791                                 roots.insert(r);
792                                 pinned.push_back(tmp);
793                             }
794                         }
795                         else {
796                             roots.insert(r);
797                         }
798                     }
799                     if (roots.size() > 1) {
800                         ptr_buffer<expr> args;
801                         for (expr* r : roots) {
802                             args.push_back(r);
803                         }
804                         TRACE("qe", tout << "function: " << d->get_name() << "\n";);
805                         res.push_back(m.mk_distinct(args.size(), args.data()));
806                     }
807                 }
808             }
809         }
810 
mk_distinct(expr_ref_vector & res)811         void mk_distinct(expr_ref_vector& res) {
812             collect_decl2terms();
813             args_are_distinct(res);
814             TRACE("qe", tout << res << "\n";);
815         }
816 
mk_pure_equalities(const term & t,expr_ref_vector & res)817         void mk_pure_equalities(const term &t, expr_ref_vector &res) {
818             SASSERT(t.is_root());
819             expr *rep = nullptr;
820             if (!m_root2rep.find(t.get_id(), rep)) return;
821             obj_hashtable<expr> members;
822             members.insert(rep);
823             term const * r = &t;
824             do {
825                 expr* member = nullptr;
826                 if (find_term2app(*r, member) && !members.contains(member)) {
827                     res.push_back (m.mk_eq (rep, member));
828                     members.insert(member);
829                 }
830                 r = &r->get_next();
831             }
832             while (r != &t);
833         }
834 
is_projected(const term & t)835         bool is_projected(const term &t) {
836             return m_tg.m_is_var(t);
837         }
838 
mk_unpure_equalities(const term & t,expr_ref_vector & res)839         void mk_unpure_equalities(const term &t, expr_ref_vector &res) {
840             expr *rep = nullptr;
841             if (!m_root2rep.find(t.get_id(), rep)) return;
842             obj_hashtable<expr> members;
843             members.insert(rep);
844             term const * r = &t;
845             do {
846                 expr* member = mk_pure(*r);
847                 SASSERT(member);
848                 if (!members.contains(member) &&
849                     (!is_projected(*r) || !is_solved_eq(rep, member))) {
850                     res.push_back(m.mk_eq(rep, member));
851                     members.insert(member);
852                 }
853                 r = &r->get_next();
854             }
855             while (r != &t);
856         }
857 
858         template<bool pure>
mk_equalities(expr_ref_vector & res)859         void mk_equalities(expr_ref_vector &res) {
860             for (term *t : m_tg.m_terms) {
861                 if (!t->is_root()) continue;
862                 if (!m_root2rep.contains(t->get_id())) continue;
863                 if (pure)
864                     mk_pure_equalities(*t, res);
865                 else
866                     mk_unpure_equalities(*t, res);
867             }
868             TRACE("qe", tout << "literals: " << res << "\n";);
869         }
870 
mk_pure_equalities(expr_ref_vector & res)871         void mk_pure_equalities(expr_ref_vector &res) {
872             mk_equalities<true>(res);
873         }
874 
mk_unpure_equalities(expr_ref_vector & res)875         void mk_unpure_equalities(expr_ref_vector &res) {
876             mk_equalities<false>(res);
877         }
878 
879         // TBD: generalize for also the case of a (:var n)
is_solved_eq(expr * lhs,expr * rhs)880         bool is_solved_eq(expr *lhs, expr* rhs) {
881             return is_uninterp_const(rhs) && !occurs(rhs, lhs);
882         }
883 
884         /// Add equalities and disequalities for all pure representatives
885         /// based on their equivalence in the model
model_complete(expr_ref_vector & res)886         void model_complete(expr_ref_vector &res) {
887             if (!m_model) return;
888             obj_map<expr,expr*> val2rep;
889             model_evaluator mev(*m_model);
890             for (auto &kv : m_root2rep) {
891                 expr *rep = kv.m_value;
892                 expr_ref val(m);
893                 expr *u = nullptr;
894                 if (!mev.eval(rep, val)) continue;
895                 if (val2rep.find(val, u)) {
896                     res.push_back(m.mk_eq(u, rep));
897                 }
898                 else {
899                     val2rep.insert(val, rep);
900                 }
901             }
902 
903             // TBD: optimize further based on implied values (e.g.,
904             // some literals are forced to be true/false) and based on
905             // unique_values (e.g., (x=1 & y=1) does not require
906             // (x!=y) to be added
907             ptr_buffer<expr> reps;
908             for (auto &kv : val2rep) {
909                 expr *rep = kv.m_value;
910                 if (!m.is_unique_value(rep))
911                 reps.push_back(kv.m_value);
912             }
913 
914             if (reps.size() <= 1) return;
915 
916             // -- sort representatives, call mk_distinct on any range
917             // -- of the same sort longer than 1
918             std::sort(reps.data(), reps.data() + reps.size(), sort_lt_proc());
919             unsigned i = 0;
920             unsigned sz = reps.size();
921             while (i < sz) {
922                 sort* last_sort = res.get(i)->get_sort();
923                 unsigned j = i + 1;
924                 while (j < sz && last_sort == reps.get(j)->get_sort()) {++j;}
925                 if (j - i == 2) {
926                     expr_ref d(m);
927                     d = mk_neq(m, reps.get(i), reps.get(i+1));
928                     if (!m.is_true(d)) res.push_back(d);
929                 }
930                 else if (j - i > 2)
931                     res.push_back(m.mk_distinct(j - i, reps.data() + i));
932                 i = j;
933             }
934             TRACE("qe", tout << "after distinct: " << res << "\n";);
935         }
936 
display(std::ostream & out) const937         std::ostream& display(std::ostream& out) const {
938             m_tg.display(out);
939             out << "term2app:\n";
940             for (auto const& kv : m_term2app) {
941                 out << kv.m_key << " |-> " << mk_pp(kv.m_value, m) << "\n";
942             }
943             out << "root2rep:\n";
944             for (auto const& kv : m_root2rep) {
945                 out << kv.m_key << " |-> " << mk_pp(kv.m_value, m) << "\n";
946             }
947             return out;
948         }
949 
950     public:
projector(term_graph & tg)951         projector(term_graph &tg) : m_tg(tg), m(m_tg.m), m_pinned(m) {}
952 
add_term2app(term const & t,expr * a)953         void add_term2app(term const& t, expr* a) {
954             m_term2app.insert(t.get_id(), a);
955         }
956 
del_term2app(term const & t)957         void del_term2app(term const& t) {
958             m_term2app.remove(t.get_id());
959         }
960 
find_term2app(term const & t,expr * & r)961         bool find_term2app(term const& t, expr*& r) {
962             return m_term2app.find(t.get_id(), r);
963         }
964 
find_term2app(term const & t)965         expr* find_term2app(term const& t) {
966             expr* r = nullptr;
967             find_term2app(t, r);
968             return r;
969         }
970 
in_term2app(term const & t)971         bool in_term2app(term const& t) {
972             return m_term2app.contains(t.get_id());
973         }
974 
set_model(model & mdl)975         void set_model(model &mdl) { m_model = &mdl; }
976 
reset()977         void reset() {
978             m_tg.reset_marks();
979             m_term2app.reset();
980             m_root2rep.reset();
981             m_pinned.reset();
982             m_model.reset();
983         }
984 
project()985         expr_ref_vector project() {
986             expr_ref_vector res(m);
987             purify();
988             lits2pure(res);
989             mk_distinct(res);
990             reset();
991             return res;
992         }
993 
get_ackerman_disequalities()994         expr_ref_vector get_ackerman_disequalities() {
995             expr_ref_vector res(m);
996             purify();
997             lits2pure(res);
998             unsigned sz = res.size();
999             mk_distinct(res);
1000             reset();
1001             unsigned j = 0;
1002             for (unsigned i = sz; i < res.size(); ++i) {
1003                 res[j++] = res.get(i);
1004             }
1005             res.shrink(j);
1006             return res;
1007         }
1008 
solve()1009         expr_ref_vector solve() {
1010             expr_ref_vector res(m);
1011             purify();
1012             solve_core();
1013             mk_lits(res);
1014             mk_unpure_equalities(res);
1015             reset();
1016             return res;
1017         }
1018 
get_partition(model & mdl,bool include_bool)1019         vector<expr_ref_vector> get_partition(model& mdl, bool include_bool) {
1020             vector<expr_ref_vector> result;
1021             expr_ref_vector pinned(m);
1022             obj_map<expr, unsigned> pid;
1023             model::scoped_model_completion _smc(mdl, true);
1024             for (term *t : m_tg.m_terms) {
1025                 expr* a = t->get_expr();
1026                 if (!is_app(a)) continue;
1027                 if (m.is_bool(a) && !include_bool) continue;
1028                 expr_ref val = mdl(a);
1029                 unsigned p = 0;
1030                 // NB. works for simple domains Integers, Rationals,
1031                 // but not for algebraic numerals.
1032                 if (!pid.find(val, p)) {
1033                     p = pid.size();
1034                     pid.insert(val, p);
1035                     pinned.push_back(val);
1036                     result.push_back(expr_ref_vector(m));
1037                 }
1038                 result[p].push_back(a);
1039             }
1040             return result;
1041         }
1042 
shared_occurrences(family_id fid)1043         expr_ref_vector shared_occurrences(family_id fid) {
1044             expr_ref_vector result(m);
1045             for (term *t : m_tg.m_terms) {
1046                 expr* e = t->get_expr();
1047                 if (e->get_sort()->get_family_id() != fid) continue;
1048                 for (term * p : term::parents(t->get_root())) {
1049                     expr* pe = p->get_expr();
1050                     if (!is_app(pe)) continue;
1051                     if (to_app(pe)->get_family_id() == fid) continue;
1052                     if (to_app(pe)->get_family_id() == m.get_basic_family_id()) continue;
1053                     result.push_back(e);
1054                     break;
1055                 }
1056             }
1057             return result;
1058         }
1059 
purify()1060         void purify() {
1061             // - propagate representatives up over parents.
1062             //   use work-list + marking to propagate.
1063             // - produce equalities over represented classes.
1064             // - produce other literals over represented classes
1065             //   (walk disequalities in m_lits and represent
1066             //   lhs/rhs over decls or excluding decls)
1067 
1068             ptr_vector<term> worklist;
1069             for (term * t : m_tg.m_terms) {
1070                 worklist.push_back(t);
1071                 t->set_mark(true);
1072             }
1073             // traverse worklist in order of depth.
1074             term_depth td;
1075             std::sort(worklist.begin(), worklist.end(), td);
1076 
1077             for (unsigned i = 0; i < worklist.size(); ++i) {
1078                 term* t = worklist[i];
1079                 t->set_mark(false);
1080                 if (in_term2app(*t))
1081                     continue;
1082                 if (!t->is_theory() && is_projected(*t))
1083                     continue;
1084 
1085                 expr* pure = mk_pure(*t);
1086                 if (!pure) continue;
1087 
1088                 add_term2app(*t, pure);
1089                 TRACE("qe_verbose", tout << "purified " << *t << " " << mk_pp(pure, m) << "\n";);
1090                 expr* rep = nullptr;                // ensure that the root has a representative
1091                 m_root2rep.find(t->get_root().get_id(), rep);
1092 
1093                 // update rep with pure if it is better
1094                 if (pure != rep && is_better_rep(pure, rep)) {
1095                     m_root2rep.insert(t->get_root().get_id(), pure);
1096                     for (term * p : term::parents(t->get_root())) {
1097                         del_term2app(*p);
1098                         if (!p->is_marked()) {
1099                             p->set_mark(true);
1100                             worklist.push_back(p);
1101                         }
1102                     }
1103                 }
1104             }
1105 
1106             // Here we could also walk equivalence classes that
1107             // contain interpreted values by sort and extract
1108             // disequalities between non-unique value
1109             // representatives.  these disequalities are implied
1110             // and can be mined using other means, such as theory
1111             // aware core minimization
1112             m_tg.reset_marks();
1113             TRACE("qe", display(tout << "after purify\n"););
1114         }
1115 
1116     };
1117 
set_vars(func_decl_ref_vector const & decls,bool exclude)1118     void term_graph::set_vars(func_decl_ref_vector const& decls, bool exclude) {
1119         m_is_var.set_decls(decls, exclude);
1120     }
1121 
project()1122     expr_ref_vector term_graph::project() {
1123         // reset solved vars so that they are not considered pure by projector
1124         m_is_var.reset_solved();
1125         term_graph::projector p(*this);
1126         return p.project();
1127     }
1128 
project(model & mdl)1129     expr_ref_vector term_graph::project(model &mdl) {
1130         m_is_var.reset_solved();
1131         term_graph::projector p(*this);
1132         p.set_model(mdl);
1133         return p.project();
1134     }
1135 
solve()1136     expr_ref_vector term_graph::solve() {
1137         // reset solved vars so that they are not considered pure by projector
1138         m_is_var.reset_solved();
1139         term_graph::projector p(*this);
1140         return p.solve();
1141     }
1142 
get_ackerman_disequalities()1143     expr_ref_vector term_graph::get_ackerman_disequalities() {
1144         m_is_var.reset_solved();
1145         dealloc(m_projector);
1146         m_projector = alloc(term_graph::projector, *this);
1147         return m_projector->get_ackerman_disequalities();
1148     }
1149 
get_partition(model & mdl)1150     vector<expr_ref_vector> term_graph::get_partition(model& mdl) {
1151         dealloc(m_projector);
1152         m_projector = alloc(term_graph::projector, *this);
1153         return m_projector->get_partition(mdl, false);
1154     }
1155 
shared_occurrences(family_id fid)1156     expr_ref_vector term_graph::shared_occurrences(family_id fid) {
1157         term_graph::projector p(*this);
1158         return p.shared_occurrences(fid);
1159     }
1160 
add_model_based_terms(model & mdl,expr_ref_vector const & terms)1161     void term_graph::add_model_based_terms(model& mdl, expr_ref_vector const& terms) {
1162         for (expr* t : terms) {
1163             internalize_term(t);
1164         }
1165         m_is_var.reset_solved();
1166 
1167         SASSERT(!m_projector);
1168         m_projector = alloc(term_graph::projector, *this);
1169 
1170         // retrieve partition of terms
1171         vector<expr_ref_vector> equivs = m_projector->get_partition(mdl, true);
1172 
1173         // merge term graph on equal terms.
1174         for (auto const& cs : equivs) {
1175             term* t0 = get_term(cs[0]);
1176             for (unsigned i = 1; i < cs.size(); ++i) {
1177                 merge(*t0, *get_term(cs[i]));
1178             }
1179         }
1180         TRACE("qe",
1181               for (auto & es : equivs) {
1182                   tout << "equiv: ";
1183                   for (expr* t : es) tout << expr_ref(t, m) << " ";
1184                   tout << "\n";
1185               }
1186               display(tout););
1187         // create representatives for shared/projected variables.
1188         m_projector->set_model(mdl);
1189         m_projector->purify();
1190 
1191     }
1192 
rep_of(expr * e)1193     expr* term_graph::rep_of(expr* e) {
1194         SASSERT(m_projector);
1195         term* t = get_term(e);
1196         SASSERT(t && "only get representatives");
1197         return m_projector->find_term2app(*t);
1198     }
1199 
dcert(model & mdl,expr_ref_vector const & lits)1200     expr_ref_vector term_graph::dcert(model& mdl, expr_ref_vector const& lits) {
1201         TRACE("qe", tout << "dcert " << lits << "\n";);
1202         struct pair_t {
1203             expr* a, *b;
1204             pair_t(): a(nullptr), b(nullptr) {}
1205             pair_t(expr* _a, expr* _b):a(_a), b(_b) {
1206                 if (a->get_id() > b->get_id()) std::swap(a, b);
1207             }
1208             struct hash {
1209                 unsigned operator()(pair_t const& p) const { return mk_mix(p.a ? p.a->hash() : 0, p.b ? p.b->hash() : 0, 1); }
1210             };
1211             struct eq {
1212                 bool operator()(pair_t const& a, pair_t const& b) const { return a.a == b.a && a.b == b.b; }
1213             };
1214         };
1215         hashtable<pair_t, pair_t::hash, pair_t::eq> diseqs;
1216         expr_ref_vector result(m);
1217         add_lits(lits);
1218         svector<pair_t> todo;
1219 
1220         for (expr* e : lits) {
1221             expr* ne, *a, *b;
1222             if (m.is_not(e, ne) && m.is_eq(ne, a, b) && (is_uninterp(a) || is_uninterp(b))) {
1223                 diseqs.insert(pair_t(a, b));
1224             }
1225             else if (is_uninterp(e)) {
1226                 diseqs.insert(pair_t(e, m.mk_false()));
1227             }
1228             else if (m.is_not(e, ne) && is_uninterp(ne)) {
1229                 diseqs.insert(pair_t(ne, m.mk_true()));
1230             }
1231         }
1232         for (auto& p : diseqs) todo.push_back(p);
1233 
1234         auto const partitions = get_partition(mdl);
1235         obj_map<expr, unsigned> term2pid;
1236         unsigned id = 0;
1237         for (auto const& vec : partitions) {
1238             for (expr* e : vec) term2pid.insert(e, id);
1239             ++id;
1240         }
1241         auto partition_of = [&](expr* e) { return partitions[term2pid[e]]; };
1242         auto in_table = [&](expr* a, expr* b) {
1243             return diseqs.contains(pair_t(a, b));
1244         };
1245         auto same_function = [](expr* a, expr* b) {
1246             return is_app(a) && is_app(b) &&
1247             to_app(a)->get_decl() == to_app(b)->get_decl() && to_app(a)->get_family_id() == null_family_id;
1248         };
1249 
1250         // make sure that diseqs is closed under function applications
1251         // of uninterpreted functions.
1252         for (unsigned idx = 0; idx < todo.size(); ++idx) {
1253             auto p = todo[idx];
1254             for (expr* t1 : partition_of(p.a)) {
1255                 for (expr* t2 : partition_of(p.b)) {
1256                     if (same_function(t1, t2)) {
1257                         unsigned sz = to_app(t1)->get_num_args();
1258                         bool found = false;
1259                         pair_t q(t1, t2);
1260                         for (unsigned i = 0; i < sz; ++i) {
1261                             expr* arg1 = to_app(t1)->get_arg(i);
1262                             expr* arg2 = to_app(t2)->get_arg(i);
1263                             if (mdl(arg1) == mdl(t2)) {
1264                                 continue;
1265                             }
1266                             if (in_table(arg1, arg2)) {
1267                                 found = true;
1268                                 break;
1269                             }
1270                             q = pair_t(arg1, arg2);
1271                         }
1272                         if (!found) {
1273                             diseqs.insert(q);
1274                             todo.push_back(q);
1275                             result.push_back(m.mk_not(m.mk_eq(q.a, q.b)));
1276                         }
1277                     }
1278                 }
1279             }
1280         }
1281         for (auto const& terms : partitions) {
1282             expr* a = nullptr;
1283             for (expr* b : terms) {
1284                 if (is_uninterp(b)) {
1285                     if (a)
1286                         result.push_back(m.mk_eq(a, b));
1287                     else
1288                         a = b;
1289                 }
1290             }
1291         }
1292         TRACE("qe", tout << result << "\n";);
1293         return result;
1294     }
1295 
1296 }
1297