1 /*++
2 Copyright (c) 2006 Microsoft Corporation
3 
4 Module Name:
5 
6     unifier.cpp
7 
8 Abstract:
9 
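    Unification of expressions with offsets, using a union-find structure with path compression; variable bindings are recorded in a substitution.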
11 
12 Author:
13 
14     Leonardo de Moura (leonardo) 2008-01-28.
15 
16 Revision History:
17 
18 --*/
19 #include "ast/substitution/unifier.h"
20 #include "ast/ast_pp.h"
21 
reset(unsigned num_offsets)22 void unifier::reset(unsigned num_offsets) {
23     m_todo.reset();
24     m_find.reset();
25     m_size.reset();
26 }
27 
28 /**
29    \brief Find with path compression.
30 */
find(expr_offset p)31 expr_offset unifier::find(expr_offset p) {
32     buffer<expr_offset> path;
33     expr_offset next;
34     while (m_find.find(p, next)) {
35         path.push_back(p);
36         p = next;
37     }
38     buffer<expr_offset>::iterator it  = path.begin();
39     buffer<expr_offset>::iterator end = path.end();
40     for (; it != end; ++it) {
41         expr_offset & prev = *it;
42         m_find.insert(prev, p);
43     }
44     return p;
45 }
46 
save_var(expr_offset const & p,expr_offset const & t)47 void unifier::save_var(expr_offset const & p, expr_offset const & t) {
48     expr * n = p.get_expr();
49     if (is_var(n)) {
50         unsigned off = p.get_offset();
51         m_subst->insert(to_var(n)->get_idx(), off, t);
52     }
53 }
54 
55 
56 /**
57    \brief Merge the equivalence classes of n1 and n2. n2 will be the
58    root of the resultant equivalence class.
59 */
union1(expr_offset const & n1,expr_offset const & n2)60 void unifier::union1(expr_offset const & n1, expr_offset const & n2) {
61     DEBUG_CODE({
62         expr_offset f;
63         SASSERT(!m_find.find(n1, f));
64         SASSERT(!m_find.find(n2, f));
65     });
66     unsigned sz1 = 1;
67     unsigned sz2 = 1;
68     m_size.find(n1, sz1);
69     m_size.find(n2, sz2);
70     m_find.insert(n1, n2);
71     m_size.insert(n2, sz1 + sz2);
72     save_var(n1, n2);
73 }
74 
75 /**
76    \brief Merge the equivalence classes of n1 and n2. The root of the
77    resultant equivalence class is the one with more elements.
78 */
union2(expr_offset n1,expr_offset n2)79 void unifier::union2(expr_offset n1, expr_offset n2) {
80     DEBUG_CODE({
81         expr_offset f;
82         SASSERT(!m_find.find(n1, f));
83         SASSERT(!m_find.find(n2, f));
84     });
85     unsigned sz1 = 1;
86     unsigned sz2 = 1;
87     m_size.find(n1, sz1);
88     m_size.find(n2, sz2);
89     if (sz1 > sz2)
90         std::swap(n1, n2);
91     m_find.insert(n1, n2);
92     m_size.insert(n2, sz1 + sz2);
93     save_var(n1, n2);
94 }
95 
unify_core(expr_offset p1,expr_offset p2)96 bool unifier::unify_core(expr_offset p1, expr_offset p2) {
97     entry e(p1, p2);
98     m_todo.push_back(e);
99     while (!m_todo.empty()) {
100         entry const & e = m_todo.back();
101         p1 = find(e.first);
102         p2 = find(e.second);
103         m_todo.pop_back();
104         if (p1 != p2) {
105             expr * n1 = p1.get_expr();
106             expr * n2 = p2.get_expr();
107             SASSERT(!is_quantifier(n1));
108             SASSERT(!is_quantifier(n2));
109             bool v1 = is_var(n1);
110             bool v2 = is_var(n2);
111             if (v1 && v2) {
112                 union2(p1, p2);
113             }
114             else if (v1) {
115                 union1(p1, p2);
116             }
117             else if (v2) {
118                 union1(p2, p1);
119             }
120             else {
121                 app * a1 = to_app(n1);
122                 app * a2 = to_app(n2);
123 
124                 unsigned off1 = p1.get_offset();
125                 unsigned off2 = p2.get_offset();
126                 if (a1->get_decl() != a2->get_decl() || a1->get_num_args() != a2->get_num_args())
127                     return false;
128                 union2(p1, p2);
129                 unsigned j = a1->get_num_args();
130                 while (j > 0) {
131                     --j;
132                     entry new_e(expr_offset(a1->get_arg(j), off1),
133                                 expr_offset(a2->get_arg(j), off2));
134                     m_todo.push_back(new_e);
135                 }
136             }
137         }
138     }
139     return true;
140 }
141 
operator ()(unsigned num_exprs,expr ** es,substitution & s,bool use_offsets)142 bool unifier::operator()(unsigned num_exprs, expr ** es, substitution & s, bool use_offsets) {
143     SASSERT(num_exprs > 0);
144     unsigned num_offsets = use_offsets ? num_exprs : 1;
145     reset(num_offsets);
146     m_subst = &s;
147 #if 1
148     TRACE("unifier", for (unsigned i = 0; i < num_exprs; ++i) tout << mk_pp(es[i], m_manager) << "\n";);
149     for (unsigned i = s.get_num_bindings(); i > 0; ) {
150         --i;
151         std::pair<unsigned,unsigned> bound;
152         expr_offset root, child;
153         s.get_binding(i, bound, root);
154         TRACE("unifier", tout << bound.first << " |-> " << mk_pp(root.get_expr(), m_manager) << "\n";);
155         if (is_var(root.get_expr())) {
156             var* v = m_manager.mk_var(bound.first,to_var(root.get_expr())->get_sort());
157             child = expr_offset(v, bound.second);
158             unsigned sz1 = 1;
159             unsigned sz2 = 1;
160             m_size.find(child, sz1);
161             m_size.find(root, sz2);
162             m_find.insert(child, root);
163             m_size.insert(root, sz1 + sz2);
164         }
165     }
166 #endif
167     for (unsigned i = 0; i < num_exprs - 1; i++) {
168         if (!unify_core(expr_offset(es[i], use_offsets ? i : 0),
169                         expr_offset(es[i+1], use_offsets ? i + 1 : 0))) {
170             m_last_call_succeeded = false;
171             return m_last_call_succeeded;
172         }
173     }
174 
175     m_last_call_succeeded = m_subst->acyclic();
176     return m_last_call_succeeded;
177 }
178 
operator ()(expr * e1,expr * e2,substitution & s,bool use_offsets)179 bool unifier::operator()(expr * e1, expr * e2, substitution & s, bool use_offsets) {
180     expr * es[2] = { e1, e2 };
181     return operator()(2, es, s, use_offsets);
182 }
183 
184