1 /*++
2 Copyright (c) 2006 Microsoft Corporation
3 
4 Module Name:
5 
6     der.cpp
7 
8 Abstract:
9 
10     <abstract>
11 
12 Author:
13 
14     Leonardo de Moura (leonardo) 2008-01-27.
15 
16 Revision History:
17 
18     Christoph Wintersteiger, 2010-03-30: Added Destr. Multi-Equality Resolution
19 
20 --*/
21 #include "ast/rewriter/der.h"
22 #include "ast/occurs.h"
23 #include "ast/for_each_expr.h"
24 #include "ast/rewriter/rewriter_def.h"
25 #include "ast/ast_util.h"
26 #include "ast/ast_pp.h"
27 #include "ast/ast_ll_pp.h"
28 #include "ast/ast_smt2_pp.h"
29 
is_var(expr * e,unsigned num_decls)30 static bool is_var(expr * e, unsigned num_decls) {
31     return is_var(e) && to_var(e)->get_idx() < num_decls;
32 }
33 
is_neg_var(ast_manager & m,expr * e,var * & v,unsigned num_decls)34 static bool is_neg_var(ast_manager & m, expr * e, var*& v, unsigned num_decls) {
35     expr* n = nullptr;
36     return m.is_not(e, n) && is_var(n) && (v = to_var(n), v->get_idx() < num_decls);
37 }
38 
39 /**
40    \brief Return true if \c e is of the form (not (= VAR t)) or (not (iff VAR t)) or (iff VAR t) or (iff (not VAR) t) or (VAR IDX) or (not (VAR IDX)).
41    The last case can be viewed
42 
43    Remark: Occurs check is not necessary here... the top-sort procedure will check for cycles...
44 
45 */
is_var_diseq(expr * e,unsigned num_decls,var * & v,expr_ref & t)46 bool der::is_var_diseq(expr * e, unsigned num_decls, var * & v, expr_ref & t) {
47     expr *eq, * lhs, *rhs;
48     auto set_result = [&](var *w, expr* s) {
49         v = w;
50         t = s;
51         TRACE("der", tout << mk_pp(e, m) << "\n";);
52         return true;
53     };
54 
55     // (not (= VAR t))
56     if (m.is_not(e, eq) && m.is_eq(eq, lhs, rhs)) {
57         if (!is_var(lhs, num_decls))
58             std::swap(lhs, rhs);
59         if (!is_var(lhs, num_decls))
60             return false;
61         return set_result(to_var(lhs), rhs);
62     }
63 
64     // (= VAR t)
65     if (m.is_eq(e, lhs, rhs) && m.is_bool(lhs)) {
66         // (iff VAR t) case
67         if (!is_var(lhs, num_decls))
68             std::swap(lhs, rhs);
69         if (is_var(lhs, num_decls)) {
70             rhs = mk_not(m, rhs);
71             m_new_exprs.push_back(rhs);
72             return set_result(to_var(lhs), rhs);
73         }
74         // (iff (not VAR) t) case
75         if (!is_neg_var(m, lhs, v, num_decls))
76             std::swap(lhs, rhs);
77         if (is_neg_var(m, lhs, v, num_decls)) {
78             return set_result(v, rhs);
79         }
80         return false;
81     }
82 
83     // VAR
84     if (is_var(e, num_decls)) {
85         return set_result(to_var(e), m.mk_false());
86     }
87 
88     // (not VAR)
89     if (is_neg_var(m, e, v, num_decls)) {
90         return set_result(v, m.mk_true());
91     }
92     return false;
93 }
94 
operator ()(quantifier * q,expr_ref & r,proof_ref & pr)95 void der::operator()(quantifier * q, expr_ref & r, proof_ref & pr) {
96     bool reduced = false;
97     pr = nullptr;
98     r  = q;
99 
100     TRACE("der", tout << mk_pp(q, m) << "\n";);
101 
102     // Keep applying it until r doesn't change anymore
103     do {
104         proof_ref curr_pr(m);
105         q  = to_quantifier(r);
106         reduce1(q, r, curr_pr);
107         if (q != r)
108             reduced = true;
109         if (m.proofs_enabled()) {
110             pr = m.mk_transitivity(pr, curr_pr);
111         }
112     }
113     while (q != r && is_quantifier(r));
114 
115     // Eliminate variables that have become unused
116     if (reduced && is_forall(r)) {
117         quantifier * q = to_quantifier(r);
118         r = elim_unused_vars(m, q, params_ref());
119         if (m.proofs_enabled()) {
120             proof * p1 = m.mk_elim_unused_vars(q, r);
121             pr = m.mk_transitivity(pr, p1);
122         }
123     }
124     m_new_exprs.reset();
125 }
126 
reduce1(quantifier * q,expr_ref & r,proof_ref & pr)127 void der::reduce1(quantifier * q, expr_ref & r, proof_ref & pr) {
128     if (!is_forall(q)) {
129         pr = nullptr;
130         r  = q;
131         return;
132     }
133 
134     expr * e = q->get_expr();
135     unsigned num_decls = q->get_num_decls();
136     var * v = nullptr;
137     expr_ref t(m);
138 
139     if (is_var_diseq(e, num_decls, v, t) && !occurs(v, t))
140         r = m.mk_false();
141     else {
142         expr_ref_vector ors(m);
143         flatten_or(e, ors);
144         unsigned num_args = ors.size();
145         unsigned diseq_count = 0;
146         unsigned largest_vinx = 0;
147 
148         m_map.reset();
149         m_pos2var.reset();
150         m_inx2var.reset();
151 
152         m_pos2var.reserve(num_args, -1);
153 
154         // Find all disequalities
155         for (unsigned i = 0; i < num_args; i++) {
156             if (is_var_diseq(ors.get(i), num_decls, v, t)) {
157                 unsigned idx = v->get_idx();
158                 if (m_map.get(idx, nullptr) == nullptr) {
159                     m_map.reserve(idx + 1);
160                     m_inx2var.reserve(idx + 1, 0);
161 
162                     m_map[idx] = t;
163                     m_inx2var[idx] = v;
164                     m_pos2var[i] = idx;
165                     diseq_count++;
166                     largest_vinx = (idx>largest_vinx) ? idx : largest_vinx;
167                 }
168             }
169         }
170 
171         if (diseq_count > 0) {
172             get_elimination_order();
173             SASSERT(m_order.size() <= diseq_count); // some might be missing because of cycles
174 
175             if (!m_order.empty()) {
176                 create_substitution(largest_vinx + 1);
177                 apply_substitution(q, ors, r);
178             }
179         }
180         else {
181             TRACE("der_bug", tout << "Did not find any diseq\n" << mk_pp(q, m) << "\n";);
182             r = q;
183         }
184     }
185     // Remark: get_elimination_order/top-sort checks for cycles, but it is not invoked for unit clauses.
186     // So, we must perform a occurs check here.
187 
188     if (m.proofs_enabled()) {
189         pr = r == q ? nullptr : m.mk_der(q, r);
190     }
191 }
192 
der_sort_vars(ptr_vector<var> & vars,expr_ref_vector & definitions,unsigned_vector & order)193 static void der_sort_vars(ptr_vector<var> & vars, expr_ref_vector & definitions, unsigned_vector & order) {
194     order.reset();
195 
196     // eliminate self loops, and definitions containing quantifiers.
197     bool found = false;
198     for (unsigned i = 0; i < definitions.size(); i++) {
199         var * v  = vars[i];
200         expr * t = definitions.get(i);
201         if (t == nullptr || has_quantifiers(t) || occurs(v, t))
202             definitions[i] = nullptr;
203         else
204             found = true; // found at least one candidate
205     }
206 
207     if (!found)
208         return;
209 
210     typedef std::pair<expr *, unsigned> frame;
211     svector<frame> todo;
212 
213     expr_fast_mark1 visiting;
214     expr_fast_mark2 done;
215 
216     unsigned vidx, num;
217 
218     for (unsigned i = 0; i < definitions.size(); i++) {
219         if (!definitions.get(i))
220             continue;
221         var * v = vars[i];
222         SASSERT(v->get_idx() == i);
223         SASSERT(todo.empty());
224         todo.push_back(frame(v, 0));
225         while (!todo.empty()) {
226         start:
227             frame & fr = todo.back();
228             expr * t   = fr.first;
229             if (done.is_marked(t)) {
230                 todo.pop_back();
231                 continue;
232             }
233             switch (t->get_kind()) {
234             case AST_VAR:
235                 vidx = to_var(t)->get_idx();
236                 if (fr.second == 0) {
237                     CTRACE("der_bug", vidx >= definitions.size(), tout << "vidx: " << vidx << "\n";);
238                     // Remark: The size of definitions may be smaller than the number of variables occurring in the quantified formula.
239                     if (definitions.get(vidx, nullptr) != nullptr) {
240                         if (visiting.is_marked(t)) {
241                             // cycle detected: remove t
242                             visiting.reset_mark(t);
243                             definitions[vidx] = nullptr;
244                         }
245                         else {
246                             visiting.mark(t);
247                             fr.second = 1;
248                             todo.push_back(frame(definitions.get(vidx), 0));
249                             goto start;
250                         }
251                     }
252                 }
253                 else {
254                     SASSERT(fr.second == 1);
255                     if (definitions.get(vidx, nullptr) != nullptr) {
256                         visiting.reset_mark(t);
257                         order.push_back(vidx);
258                     }
259                     else {
260                         // var was removed from the list of candidate vars to elim cycle
261                         // do nothing
262                     }
263                 }
264                 done.mark(t);
265                 todo.pop_back();
266                 break;
267             case AST_QUANTIFIER:
268                 UNREACHABLE();
269                 todo.pop_back();
270                 break;
271             case AST_APP:
272                 num = to_app(t)->get_num_args();
273                 while (fr.second < num) {
274                     expr * arg = to_app(t)->get_arg(fr.second);
275                     fr.second++;
276                     if (done.is_marked(arg))
277                         continue;
278                     todo.push_back(frame(arg, 0));
279                     goto start;
280                 }
281                 done.mark(t);
282                 todo.pop_back();
283                 break;
284             default:
285                 UNREACHABLE();
286                 todo.pop_back();
287                 break;
288             }
289         }
290     }
291 }
292 
get_elimination_order()293 void der::get_elimination_order() {
294     m_order.reset();
295 
296     TRACE("top_sort",
297           tout << "DEFINITIONS: " << std::endl;
298           unsigned i = 0;
299           for (expr* e : m_map) {
300               if (e) tout << "VAR " << i << " = " << mk_pp(e, m) << std::endl;
301               ++i;
302           }
303       );
304 
305     // der::top_sort ts(m);
306     der_sort_vars(m_inx2var, m_map, m_order);
307 
308     TRACE("der",
309           tout << "Elimination m_order:" << "\n";
310           tout << m_order << "\n";);
311 }
312 
create_substitution(unsigned sz)313 void der::create_substitution(unsigned sz) {
314     m_subst_map.reset();
315     m_subst_map.resize(sz, nullptr);
316 
317     for(unsigned i = 0; i < m_order.size(); i++) {
318         expr_ref cur(m_map.get(m_order[i]), m);
319 
320         // do all the previous substitutions before inserting
321         expr_ref r = m_subst(cur, m_subst_map.size(), m_subst_map.data());
322 
323         unsigned inx = sz - m_order[i]- 1;
324         SASSERT(m_subst_map[inx]==0);
325         m_subst_map[inx] = r;
326     }
327 }
328 
apply_substitution(quantifier * q,expr_ref_vector & ors,expr_ref & r)329 void der::apply_substitution(quantifier * q, expr_ref_vector& ors, expr_ref & r) {
330     unsigned num_args = ors.size();
331 
332     // get a new expression
333     m_new_args.reset();
334     for(unsigned i = 0; i < num_args; i++) {
335         int x = m_pos2var[i];
336         if (x != -1 && m_map.get(x) != nullptr)
337             continue; // this is a disequality with definition (vanishes)
338 
339         m_new_args.push_back(ors.get(i));
340     }
341 
342     expr_ref t(mk_or(m, m_new_args.size(), m_new_args.data()), m);
343     expr_ref new_e = m_subst(t, m_subst_map);
344 
345     // don't forget to update the quantifier patterns
346     expr_ref_buffer  new_patterns(m);
347     expr_ref_buffer  new_no_patterns(m);
348     for (unsigned j = 0; j < q->get_num_patterns(); j++) {
349         new_patterns.push_back(m_subst(q->get_pattern(j), m_subst_map.size(), m_subst_map.data()));
350     }
351 
352     for (unsigned j = 0; j < q->get_num_no_patterns(); j++) {
353         new_no_patterns.push_back(m_subst(q->get_no_pattern(j), m_subst_map.size(), m_subst_map.data()));
354     }
355 
356     r = m.update_quantifier(q, new_patterns.size(), new_patterns.data(),
357                             new_no_patterns.size(), new_no_patterns.data(), new_e);
358 }
359 
360 
361 struct der_rewriter_cfg : public default_rewriter_cfg {
362     ast_manager& m;
363     der   m_der;
364 
der_rewriter_cfgder_rewriter_cfg365     der_rewriter_cfg(ast_manager & m): m(m), m_der(m) {}
366 
reduce_quantifierder_rewriter_cfg367     bool reduce_quantifier(quantifier * old_q,
368                            expr * new_body,
369                            expr * const * new_patterns,
370                            expr * const * new_no_patterns,
371                            expr_ref & result,
372                            proof_ref & result_pr) {
373         quantifier_ref q1(m);
374         q1 = m.update_quantifier(old_q, old_q->get_num_patterns(), new_patterns,
375                                  old_q->get_num_no_patterns(), new_no_patterns, new_body);
376         m_der(q1, result, result_pr);
377         return true;
378     }
379 };
380 
381 template class rewriter_tpl<der_rewriter_cfg>;
382 
383 struct der_rewriter::imp : public rewriter_tpl<der_rewriter_cfg> {
384     der_rewriter_cfg m_cfg;
impder_rewriter::imp385     imp(ast_manager & m):
386         rewriter_tpl<der_rewriter_cfg>(m, m.proofs_enabled(), m_cfg),
387         m_cfg(m) {
388     }
389 };
390 
der_rewriter(ast_manager & m)391 der_rewriter::der_rewriter(ast_manager & m) {
392     m_imp = alloc(imp, m);
393 }
394 
~der_rewriter()395 der_rewriter::~der_rewriter() {
396     dealloc(m_imp);
397 }
398 
operator ()(expr * t,expr_ref & result,proof_ref & result_pr)399 void der_rewriter::operator()(expr * t, expr_ref & result, proof_ref & result_pr) {
400     m_imp->operator()(t, result, result_pr);
401 }
402 
cleanup()403 void der_rewriter::cleanup() {
404     ast_manager & m = m_imp->m_cfg.m;
405     dealloc(m_imp);
406     m_imp = alloc(imp, m);
407 }
408 
reset()409 void der_rewriter::reset() {
410     m_imp->reset();
411 }
412 
413 
414