1 /*++
2 Copyright (c) 2011 Microsoft Corporation
3 
4 Module Name:
5 
6     solve_eqs_tactic.cpp
7 
8 Abstract:
9 
10     Tactic for solving equations and performing gaussian elimination.
11 
12 Author:
13 
14     Leonardo de Moura (leonardo) 2011-12-29.
15 
16 Revision History:
17 
18 --*/
19 #include "ast/rewriter/expr_replacer.h"
20 #include "ast/occurs.h"
21 #include "ast/ast_util.h"
22 #include "ast/ast_pp.h"
23 #include "ast/pb_decl_plugin.h"
24 #include "ast/rewriter/th_rewriter.h"
25 #include "ast/rewriter/rewriter_def.h"
26 #include "ast/rewriter/hoist_rewriter.h"
27 #include "tactic/goal_shared_occs.h"
28 #include "tactic/tactical.h"
29 #include "tactic/generic_model_converter.h"
30 #include "tactic/tactic_params.hpp"
31 
32 class solve_eqs_tactic : public tactic {
33     struct imp {
34         typedef generic_model_converter gmc;
35 
36         ast_manager &                 m_manager;
37         expr_replacer *               m_r;
38         bool                          m_r_owner;
39         arith_util                    m_a_util;
40         obj_map<expr, unsigned>       m_num_occs;
41         unsigned                      m_num_steps;
42         unsigned                      m_num_eliminated_vars;
43         bool                          m_theory_solver;
44         bool                          m_ite_solver;
45         unsigned                      m_max_occs;
46         bool                          m_context_solve;
47         scoped_ptr<expr_substitution> m_subst;
48         scoped_ptr<expr_substitution> m_norm_subst;
49         expr_sparse_mark              m_candidate_vars;
50         expr_sparse_mark              m_candidate_set;
51         ptr_vector<expr>              m_candidates;
52         expr_ref_vector               m_marked_candidates;
53         ptr_vector<app>               m_vars;
54         expr_sparse_mark              m_nonzero;
55         ptr_vector<app>               m_ordered_vars;
56         bool                          m_produce_proofs;
57         bool                          m_produce_unsat_cores;
58         bool                          m_produce_models;
59 
impsolve_eqs_tactic::imp60         imp(ast_manager & m, params_ref const & p, expr_replacer * r, bool owner):
61             m_manager(m),
62             m_r(r),
63             m_r_owner(r == nullptr || owner),
64             m_a_util(m),
65             m_num_steps(0),
66             m_num_eliminated_vars(0),
67             m_marked_candidates(m) {
68             updt_params(p);
69             if (m_r == nullptr)
70                 m_r = mk_default_expr_replacer(m, true);
71         }
72 
~impsolve_eqs_tactic::imp73         ~imp() {
74             if (m_r_owner)
75                 dealloc(m_r);
76         }
77 
msolve_eqs_tactic::imp78         ast_manager & m() const { return m_manager; }
79 
updt_paramssolve_eqs_tactic::imp80         void updt_params(params_ref const & p) {
81             tactic_params tp(p);
82             m_ite_solver     = p.get_bool("ite_solver", tp.solve_eqs_ite_solver());
83             m_theory_solver  = p.get_bool("theory_solver", tp.solve_eqs_theory_solver());
84             m_max_occs       = p.get_uint("solve_eqs_max_occs", tp.solve_eqs_max_occs());
85             m_context_solve  = p.get_bool("context_solve", tp.solve_eqs_context_solve());
86         }
87 
checkpointsolve_eqs_tactic::imp88         void checkpoint() {
89             tactic::checkpoint(m());
90         }
91 
        // Check whether the number of occurrences of t is at most the threshold :solve-eqs-max-occs (UINT_MAX disables the check)
check_occssolve_eqs_tactic::imp93         bool check_occs(expr * t) const {
94             if (m_max_occs == UINT_MAX)
95                 return true;
96             unsigned num = 0;
97             m_num_occs.find(t, num);
98             TRACE("solve_eqs_check_occs", tout << mk_ismt2_pp(t, m_manager) << " num_occs: " << num << " max: " << m_max_occs << "\n";);
99             return num <= m_max_occs;
100         }
101 
        // Solve equations of the form (= x def) or (= def x), where x is an uninterpreted constant.
103 
trivial_solve1solve_eqs_tactic::imp104         bool trivial_solve1(expr * lhs, expr * rhs, app_ref & var, expr_ref & def, proof_ref & pr) {
105 
106             if (is_uninterp_const(lhs) && !m_candidate_vars.is_marked(lhs) && !occurs(lhs, rhs) && check_occs(lhs)) {
107                 var = to_app(lhs);
108                 def = rhs;
109                 pr  = nullptr;
110                 return true;
111             }
112             else {
113                 return false;
114             }
115         }
trivial_solvesolve_eqs_tactic::imp116         bool trivial_solve(expr * lhs, expr * rhs, app_ref & var, expr_ref & def, proof_ref & pr) {
117             if (trivial_solve1(lhs, rhs, var, def, pr))
118                 return true;
119             if (trivial_solve1(rhs, lhs, var, def, pr)) {
120                 if (m_produce_proofs) {
121                     pr = m().mk_commutativity(m().mk_eq(lhs, rhs));
122                 }
123                 return true;
124             }
125             return false;
126         }
127 
128         // (ite c (= x t1) (= x t2)) --> (= x (ite c t1 t2))
solve_ite_coresolve_eqs_tactic::imp129         bool solve_ite_core(app * ite, expr * lhs1, expr * rhs1, expr * lhs2, expr * rhs2, app_ref & var, expr_ref & def, proof_ref & pr) {
130             if (lhs1 != lhs2)
131                 return false;
132             if (!is_uninterp_const(lhs1) || m_candidate_vars.is_marked(lhs1))
133                 return false;
134             if (occurs(lhs1, ite->get_arg(0)) || occurs(lhs1, rhs1) || occurs(lhs1, rhs2))
135                 return false;
136             if (!check_occs(lhs1))
137                 return false;
138             var = to_app(lhs1);
139             def = m().mk_ite(ite->get_arg(0), rhs1, rhs2);
140 
141             if (m_produce_proofs)
142                 pr = m().mk_rewrite(ite, m().mk_eq(var, def));
143             return true;
144         }
145 
146         // (ite c (= x t1) (= x t2)) --> (= x (ite c t1 t2))
solve_itesolve_eqs_tactic::imp147         bool solve_ite(app * ite, app_ref & var, expr_ref & def, proof_ref & pr) {
148             expr * t = ite->get_arg(1);
149             expr * e = ite->get_arg(2);
150 
151             if (!m().is_eq(t) || !m().is_eq(e))
152                 return false;
153 
154             expr * lhs1 = to_app(t)->get_arg(0);
155             expr * rhs1 = to_app(t)->get_arg(1);
156             expr * lhs2 = to_app(e)->get_arg(0);
157             expr * rhs2 = to_app(e)->get_arg(1);
158 
159             return
160                 solve_ite_core(ite, lhs1, rhs1, lhs2, rhs2, var, def, pr) ||
161                 solve_ite_core(ite, rhs1, lhs1, lhs2, rhs2, var, def, pr) ||
162                 solve_ite_core(ite, lhs1, rhs1, rhs2, lhs2, var, def, pr) ||
163                 solve_ite_core(ite, rhs1, lhs1, rhs2, lhs2, var, def, pr);
164         }
165 
is_pos_literalsolve_eqs_tactic::imp166         bool is_pos_literal(expr * n) {
167             return is_app(n) && to_app(n)->get_num_args() == 0 && to_app(n)->get_family_id() == null_family_id;
168         }
169 
is_neg_literalsolve_eqs_tactic::imp170         bool is_neg_literal(expr * n) {
171             if (m_manager.is_not(n))
172                 return is_pos_literal(to_app(n)->get_arg(0));
173             return false;
174         }
175 
176 
177         /**
178            \brief Given t of the form (f s_0 ... s_n),
179            return true if x occurs in some s_j for j != i
180         */
occurs_exceptsolve_eqs_tactic::imp181         bool occurs_except(expr * x, app * t, unsigned i) {
182             unsigned num = t->get_num_args();
183             for (unsigned j = 0; j < num; j++) {
184                 if (i != j && occurs(x, t->get_arg(j)))
185                     return true;
186             }
187             return false;
188         }
189 
add_possolve_eqs_tactic::imp190         void add_pos(expr* f) {
191             expr* lhs = nullptr, *rhs = nullptr;
192             rational val;
193             if (m_a_util.is_le(f, lhs, rhs) && m_a_util.is_numeral(rhs, val) && val.is_neg()) {
194                 m_nonzero.mark(lhs);
195             }
196             else if (m_a_util.is_ge(f, lhs, rhs) && m_a_util.is_numeral(rhs, val) && val.is_pos()) {
197                 m_nonzero.mark(lhs);
198             }
199             else if (m().is_not(f, f)) {
200                 if (m_a_util.is_le(f, lhs, rhs) && m_a_util.is_numeral(rhs, val) && !val.is_neg()) {
201                     m_nonzero.mark(lhs);
202                 }
203                 else if (m_a_util.is_ge(f, lhs, rhs) && m_a_util.is_numeral(rhs, val) && !val.is_pos()) {
204                     m_nonzero.mark(lhs);
205                 }
206                 else if (m().is_eq(f, lhs, rhs) && m_a_util.is_numeral(rhs, val) && val.is_zero()) {
207                     m_nonzero.mark(lhs);
208                 }
209             }
210         }
211 
is_nonzerosolve_eqs_tactic::imp212         bool is_nonzero(expr* e) {
213             return m_nonzero.is_marked(e);
214         }
215 
isolate_varsolve_eqs_tactic::imp216         bool isolate_var(app* arg, app_ref& var, expr_ref& div, unsigned i, app* lhs, expr* rhs) {
217             if (!m_a_util.is_mul(arg)) return false;
218             unsigned n = arg->get_num_args();
219             for (unsigned j = 0; j < n; ++j) {
220                 expr* e = arg->get_arg(j);
221                 bool ok = is_uninterp_const(e) && check_occs(e) && !occurs(e, rhs) && !occurs_except(e, lhs, i);
222                 if (!ok) continue;
223                 var = to_app(e);
224                 for (unsigned k = 0; ok && k < n; ++k) {
225                     expr* arg_k = arg->get_arg(k);
226                     ok = k == j || (!occurs(var, arg_k) && is_nonzero(arg_k));
227                 }
228                 if (!ok) continue;
229                 ptr_vector<expr> args;
230                 for (unsigned k = 0; k < n; ++k) {
231                     if (k != j) args.push_back(arg->get_arg(k));
232                 }
233                 div = m_a_util.mk_mul(args.size(), args.c_ptr());
234                 return true;
235             }
236             return false;
237         }
238 
solve_nlsolve_eqs_tactic::imp239         bool solve_nl(app * lhs, expr * rhs, expr* eq, app_ref& var, expr_ref & def, proof_ref & pr) {
240             SASSERT(m_a_util.is_add(lhs));
241             if (m_a_util.is_int(lhs)) return false;
242             unsigned num = lhs->get_num_args();
243             expr_ref div(m());
244             for (unsigned i = 0; i < num; i++) {
245                 expr * arg = lhs->get_arg(i);
246                 if (is_app(arg) && isolate_var(to_app(arg), var, div, i, lhs, rhs)) {
247                     ptr_vector<expr> args;
248                     for (unsigned k = 0; k < num; ++k) {
249                         if (k != i) args.push_back(lhs->get_arg(k));
250                     }
251                     def = m_a_util.mk_sub(rhs, m_a_util.mk_add(args.size(), args.c_ptr()));
252                     def = m_a_util.mk_div(def, div);
253                     if (m_produce_proofs)
254                         pr = m().mk_rewrite(eq, m().mk_eq(var, def));
255                     return true;
256                 }
257             }
258             return false;
259         }
260 
        // Solve the linear equation eq, read as (= lhs rhs) with lhs = (+ t_1 ... t_n), for one
        // summand variable x. A summand t_i qualifies when:
        //  - it is x itself (coefficient 1), or (* a x) with a nonzero numeral a, where for
        //    integer sums only a = -1 is accepted (so that 1/a is again an integer); and
        //  - x is an uninterpreted constant that is not yet a candidate, passes check_occs,
        //    and occurs neither in rhs nor in any other summand.
        // On success: var := x, def := (1/a)*rhs - sum_{j != i} (1/a)*t_j, and true is returned.
        // With proofs, pr is a rewrite step from eq to (= var def).
        bool solve_arith_core(app * lhs, expr * rhs, expr * eq, app_ref & var, expr_ref & def, proof_ref & pr) {
            SASSERT(m_a_util.is_add(lhs));
            bool is_int  = m_a_util.is_int(lhs);
            expr * a = nullptr;
            expr * v = nullptr;
            rational a_val;
            unsigned num = lhs->get_num_args();
            unsigned i;
            // Find the first qualifying summand. Failed match attempts may leave stale values in
            // a, v and a_val; they are only meaningful when the loop exits through a break.
            for (i = 0; i < num; i++) {
                expr * arg = lhs->get_arg(i);
                if (is_uninterp_const(arg) && !m_candidate_vars.is_marked(arg) && check_occs(arg) && !occurs(arg, rhs) && !occurs_except(arg, lhs, i)) {
                    a_val = rational(1);
                    v     = arg;
                    break;
                }
                else if (m_a_util.is_mul(arg, a, v) &&
                         is_uninterp_const(v) &&
                         !m_candidate_vars.is_marked(v) &&
                         m_a_util.is_numeral(a, a_val) &&
                         !a_val.is_zero() &&
                         (!is_int || a_val.is_minus_one()) &&
                         check_occs(v) &&
                         !occurs(v, rhs) &&
                         !occurs_except(v, lhs, i)) {
                    break;
                }
            }
            // No summand can be solved for.
            if (i == num)
                return false;
            var = to_app(v);
            // Divide by the coefficient: scale rhs (and, below, every other summand) by 1/a.
            expr_ref inv_a(m());
            if (!a_val.is_one()) {
                inv_a = m_a_util.mk_numeral(rational(1)/a_val, is_int);
                rhs   = m_a_util.mk_mul(inv_a, rhs);
            }

            ptr_buffer<expr> other_args;
            for (unsigned j = 0; j < num; j++) {
                if (i != j) {
                    if (inv_a)
                        other_args.push_back(m_a_util.mk_mul(inv_a, lhs->get_arg(j)));
                    else
                        other_args.push_back(lhs->get_arg(j));
                }
            }
            // def := rhs - (sum of the other, scaled summands), avoiding 0- and 1-ary sums.
            switch (other_args.size()) {
            case 0:
                def = rhs;
                break;
            case 1:
                def = m_a_util.mk_sub(rhs, other_args[0]);
                break;
            default:
                def = m_a_util.mk_sub(rhs, m_a_util.mk_add(other_args.size(), other_args.c_ptr()));
                break;
            }
            if (m_produce_proofs)
                pr = m().mk_rewrite(eq, m().mk_eq(var, def));
            return true;
        }
321 
solve_modsolve_eqs_tactic::imp322         bool solve_mod(expr * lhs, expr * rhs, expr * eq, app_ref & var, expr_ref & def, proof_ref & pr) {
323             rational r1, r2;
324             expr* arg1;
325             if (m_produce_proofs)
326                 return false;
327 
328             auto fresh = [&]() { return m().mk_fresh_const("mod", m_a_util.mk_int()); };
329             auto mk_int = [&](rational const& r) { return m_a_util.mk_int(r); };
330             auto add = [&](expr* a, expr* b) { return m_a_util.mk_add(a, b); };
331             auto mul = [&](expr* a, expr* b) { return m_a_util.mk_mul(a, b); };
332 
333             VERIFY(m_a_util.is_mod(lhs, lhs, arg1));
334             if (!m_a_util.is_numeral(arg1, r1) || !r1.is_pos()) {
335                 return false;
336             }
337             //
338             // solve lhs mod r1 = r2
339             // as lhs = r1*mod!1 + r2
340             //
341             if (m_a_util.is_numeral(rhs, r2) && !r2.is_neg() && r2 < r1) {
342                 expr_ref def0(m());
343                 def0 = add(mk_int(r2), mul(fresh(), mk_int(r1)));
344                 return solve_eq(lhs, def0, eq, var, def, pr);
345             }
346             return false;
347         }
348 
solve_arithsolve_eqs_tactic::imp349         bool solve_arith(expr * lhs, expr * rhs, expr * eq, app_ref & var, expr_ref & def, proof_ref & pr) {
350             return
351                 (m_a_util.is_add(lhs) && solve_arith_core(to_app(lhs), rhs, eq, var, def, pr)) ||
352                 (m_a_util.is_add(rhs) && solve_arith_core(to_app(rhs), lhs, eq, var, def, pr)) ||
353                 (m_a_util.is_mod(lhs) && solve_mod(lhs, rhs, eq, var, def, pr)) ||
354                 (m_a_util.is_mod(rhs) && solve_mod(rhs, lhs, eq, var, def, pr));
355         }
356 
357 
solve_eqsolve_eqs_tactic::imp358         bool solve_eq(expr* arg1, expr* arg2, expr* eq, app_ref& var, expr_ref & def, proof_ref& pr) {
359             if (trivial_solve(arg1, arg2, var, def, pr))
360                 return true;
361             if (m_theory_solver) {
362                 if (solve_arith(arg1, arg2, eq, var, def, pr))
363                     return true;
364             }
365             return false;
366         }
367 
solvesolve_eqs_tactic::imp368         bool solve(expr * f, app_ref & var, expr_ref & def, proof_ref & pr) {
369             expr* arg1 = nullptr, *arg2 = nullptr;
370             if (m().is_eq(f, arg1, arg2)) {
371                 return solve_eq(arg1, arg2, f, var, def, pr);
372             }
373 
374             if (m_ite_solver && m().is_ite(f))
375                 return solve_ite(to_app(f), var, def, pr);
376 
377             if (is_pos_literal(f)) {
378                 if (m_candidate_vars.is_marked(f))
379                     return false;
380                 var = to_app(f);
381                 def = m().mk_true();
382                 if (m_produce_proofs) {
383                     // [rewrite]: (iff (iff l true) l)
384                     // [symmetry T1]: (iff l (iff l true))
385                     pr = m().mk_rewrite(m().mk_eq(var, def), var);
386                     pr = m().mk_symmetry(pr);
387                 }
388                 TRACE("solve_eqs_bug2", tout << "eliminating: " << mk_ismt2_pp(f, m()) << "\n";);
389                 return true;
390             }
391 
392             if (is_neg_literal(f)) {
393                 var = to_app(to_app(f)->get_arg(0));
394                 if (m_candidate_vars.is_marked(var))
395                     return false;
396                 def = m().mk_false();
397                 if (m_produce_proofs) {
398                     // [rewrite]: (iff (iff l false) ~l)
399                     // [symmetry T1]: (iff ~l (iff l false))
400                     pr = m().mk_rewrite(m().mk_eq(var, def), f);
401                     pr = m().mk_symmetry(pr);
402                 }
403                 return true;
404             }
405 
406             return false;
407         }
408 
insert_solutionsolve_eqs_tactic::imp409         void insert_solution(goal const& g, unsigned idx, expr* f, app* var, expr* def, proof* pr) {
410             m_vars.push_back(var);
411             m_candidates.push_back(f);
412             m_candidate_set.mark(f);
413             m_candidate_vars.mark(var);
414             m_marked_candidates.push_back(f);
415             if (m_produce_proofs) {
416                 if (!pr)
417                     pr = g.pr(idx);
418                 else
419                     pr = m().mk_modus_ponens(g.pr(idx), pr);
420             }
421             m_subst->insert(var, def, pr, g.dep(idx));
422         }
423 
        /**
           \brief Collect candidate solutions (var := def) from the formulas of the goal.
        */
collectsolve_eqs_tactic::imp427         void collect(goal const & g) {
428             m_subst->reset();
429             m_norm_subst->reset();
430             m_r->set_substitution(nullptr);
431             m_candidate_vars.reset();
432             m_candidate_set.reset();
433             m_candidates.reset();
434             m_marked_candidates.reset();
435             m_vars.reset();
436             m_nonzero.reset();
437             app_ref  var(m());
438             expr_ref  def(m());
439             proof_ref pr(m());
440             unsigned size = g.size();
441             for (unsigned idx = 0; idx < size; idx++) {
442                 add_pos(g.form(idx));
443             }
444             for (unsigned idx = 0; idx < size; idx++) {
445                 checkpoint();
446                 expr * f = g.form(idx);
447                 pr = nullptr;
448                 if (solve(f, var, def, pr)) {
449                     insert_solution(g, idx, f, var, def, pr);
450                 }
451                 m_num_steps++;
452             }
453 
454             TRACE("solve_eqs",
455                   tout << "candidate vars:\n";
456                   for (app* v : m_vars) {
457                       tout << mk_ismt2_pp(v, m()) << " ";
458                   }
459                   tout << "\n";);
460         }
461 
462         struct nnf_context {
463             bool m_is_and;
464             expr_ref_vector m_args;
465             unsigned m_index;
nnf_contextsolve_eqs_tactic::imp::nnf_context466             nnf_context(bool is_and, expr_ref_vector const& args, unsigned idx):
467                 m_is_and(is_and),
468                 m_args(args),
469                 m_index(idx)
470             {}
471         };
472 
473         ptr_vector<expr> m_todo;
        // Mark in occ every sub-term of the formulas of g that contains v (v is marked first).
        // Iterative post-order walk, using the member m_todo as its stack and a local visited mark:
        //  - an application is finished once all its arguments are visited; it is marked in
        //    occ iff one of its arguments is marked;
        //  - a quantifier is marked iff its body is marked;
        //  - anything else (bound variables) is only recorded as visited.
        void mark_occurs(expr_mark& occ, goal const& g, expr* v) {
            expr_fast_mark2 visited;
            occ.mark(v, true);
            visited.mark(v, true);
            for (unsigned j = 0; j < g.size(); ++j) {
                m_todo.push_back(g.form(j));
            }
            while (!m_todo.empty()) {
                expr* e = m_todo.back();
                // Already finished (a term may be pushed several times): drop it.
                if (visited.is_marked(e)) {
                    m_todo.pop_back();
                    continue;
                }
                if (is_app(e)) {
                    bool does_occur = false;
                    bool all_visited = true;
                    // Push unvisited arguments; e stays on the stack and is revisited later.
                    for (expr* arg : *to_app(e)) {
                        if (!visited.is_marked(arg)) {
                            m_todo.push_back(arg);
                            all_visited = false;
                        }
                        else {
                            does_occur |= occ.is_marked(arg);
                        }
                    }
                    if (all_visited) {
                        occ.mark(e, does_occur);
                        visited.mark(e, true);
                        m_todo.pop_back();
                    }
                }
                else if (is_quantifier(e)) {
                    expr* body = to_quantifier(e)->get_expr();
                    if (visited.is_marked(body)) {
                        visited.mark(e, true);
                        occ.mark(e, occ.is_marked(body));
                        m_todo.pop_back();
                    }
                    else {
                        m_todo.push_back(body);
                    }
                }
                else {
                    visited.mark(e, true);
                    m_todo.pop_back();
                }
            }
        }
522 
is_compatiblesolve_eqs_tactic::imp523         bool is_compatible(goal const& g, unsigned idx, vector<nnf_context> const & path, expr* v, expr* eq) {
524             expr_mark occ;
525             svector<lbool> cache;
526             mark_occurs(occ, g, v);
527             return is_goal_compatible(g, occ, cache, idx, v, eq) && is_path_compatible(occ, cache, path, v, eq);
528         }
529 
is_goal_compatiblesolve_eqs_tactic::imp530         bool is_goal_compatible(goal const& g, expr_mark& occ, svector<lbool>& cache, unsigned idx, expr* v, expr* eq) {
531             bool all_e = false;
532             for (unsigned j = 0; j < g.size(); ++j) {
533                 if (j != idx && !check_eq_compat_rec(occ, cache, g.form(j), v, eq, all_e)) {
534                     TRACE("solve_eqs", tout << "occurs goal " << mk_pp(eq, m()) << "\n";);
535                     return false;
536                 }
537             }
538             return true;
539         }
540 
541         //
542         // all_e := all disjunctions contain eq
543         //
544         // or, all_e -> skip if all disjunctions contain eq
545         // or, all_e -> fail if some disjunction contains v but not eq
546         // or, all_e -> all_e := false if some disjunction does not contain v
547         // and, all_e -> all_e
548         //
549 
        // Check eq against the enclosing and/or contexts recorded in path, walking from the
        // innermost (last) entry outwards:
        //  - or-context: every other disjunct that mentions v must pass check_eq_compat_rec;
        //    a disjunct that does not mention v clears all_e;
        //  - and-context: once all_e is false, no other conjunct may be marked in occ
        //    (directly or under one negation).
        // Returns false on the first violation.
        bool is_path_compatible(expr_mark& occ, svector<lbool>& cache, vector<nnf_context> const & path, expr* v, expr* eq) {
            bool all_e = true;
            // occ is not defined on negations created by flattening, so also look under one negation.
            auto is_marked = [&](expr* e) {
                if (occ.is_marked(e))
                    return true;
                if (m().is_not(e, e) && occ.is_marked(e))
                    return true;
                return false;
            };
            for (unsigned i = path.size(); i-- > 0; ) {
                auto const& p = path[i];
                auto const& args = p.m_args;
                if (p.m_is_and && !all_e) {
                    for (unsigned j = 0; j < args.size(); ++j) {
                        if (j != p.m_index && is_marked(args[j])) {
                            TRACE("solve_eqs", tout << "occurs and " << mk_pp(eq, m()) << " " << mk_pp(args[j], m()) << "\n";);
                            return false;
                        }
                    }
                }
                else if (!p.m_is_and) {
                    for (unsigned j = 0; j < args.size(); ++j) {
                        if (j != p.m_index) {
                            // args may be fresh (flattened) terms, so use occurs rather than occ here.
                            if (occurs(v, args[j])) {
                                if (!check_eq_compat_rec(occ, cache, args[j], v, eq, all_e)) {
                                    TRACE("solve_eqs", tout << "occurs or " << mk_pp(eq, m()) << " " << mk_pp(args[j], m()) << "\n";);
                                    return false;
                                }
                            }
                            else {
                                all_e = false;
                            }
                        }
                    }
                }
            }
            return true;
        }
588 
check_eq_compat_recsolve_eqs_tactic::imp589         bool check_eq_compat_rec(expr_mark& occ, svector<lbool>& cache, expr* f, expr* v, expr* eq, bool& all) {
590             expr_ref_vector args(m());
591             expr* f1 = nullptr;
592             // flattening may introduce fresh negations,
593             // occ is not defined on these negations
594             if (!m().is_not(f) && !occ.is_marked(f)) {
595                 all = false;
596                 return true;
597             }
598             unsigned idx = f->get_id();
599             if (cache.size() > idx && cache[idx] != l_undef) {
600                 return cache[idx] == l_true;
601             }
602             if (m().is_not(f, f1) && m().is_or(f1)) {
603                 flatten_and(f, args);
604                 for (expr* arg : args) {
605                     if (arg == eq) {
606                         cache.reserve(idx+1, l_undef);
607                         cache[idx] = l_true;
608                         return true;
609                     }
610                 }
611             }
612             else if (m().is_or(f)) {
613                 flatten_or(f, args);
614             }
615             else {
616                 return false;
617             }
618 
619             for (expr* arg : args) {
620                 if (!check_eq_compat_rec(occ, cache, arg, v, eq, all)) {
621                     cache.reserve(idx+1, l_undef);
622                     cache[idx] = l_false;
623                     return false;
624                 }
625             }
626             cache.reserve(idx+1, l_undef);
627             cache[idx] = l_true;
628             return true;
629         }
630 
        // Search f, the idx-th formula of g (or one of its sub-formulas), for nested equalities
        // that can serve as solutions:
        //  - (not (or ...)) is flattened into conjuncts. A conjunct (= lhs rhs) is inserted as a
        //    solution when trivial_solve1 succeeds in either orientation and is_compatible
        //    accepts it. Any other conjunct is searched recursively.
        //  - (or ...) is flattened into disjuncts, each searched recursively.
        // path holds the enclosing and/or contexts. mark makes every sub-formula be examined at
        // most once. Formulas at depth greater than 3 are not examined.
        // NOTE(review): insert_solution is called here with pr == nullptr, so with proofs enabled it
        // would record g.pr(idx) (the proof of the whole formula) as the proof of var = def, and
        // the rhs := lhs orientation gets no commutativity step. Presumably this path only runs
        // without proofs; confirm at the caller.
        void hoist_nnf(goal const& g, expr* f, vector<nnf_context> & path, unsigned idx, unsigned depth, ast_mark& mark) {
            if (depth > 3 || mark.is_marked(f)) {
                return;
            }
            mark.mark(f, true);
            checkpoint();
            app_ref var(m());
            expr_ref def(m());
            proof_ref pr(m());
            expr_ref_vector args(m());
            expr* f1 = nullptr;

            if (m().is_not(f, f1) && m().is_or(f1)) {
                flatten_and(f, args);
                for (unsigned i = 0; i < args.size(); ++i) {
                    pr = nullptr;
                    expr* arg = args.get(i), *lhs = nullptr, *rhs = nullptr;
                    if (m().is_eq(arg, lhs, rhs)) {
                        // Try lhs := rhs first, then rhs := lhs.
                        if (trivial_solve1(lhs, rhs, var, def, pr) && is_compatible(g, idx, path, var, arg)) {
                            insert_solution(g, idx, arg, var, def, pr);
                        }
                        else if (trivial_solve1(rhs, lhs, var, def, pr) && is_compatible(g, idx, path, var, arg)) {
                            insert_solution(g, idx, arg, var, def, pr);
                        }
                        else {
                            IF_VERBOSE(10000,
                                       verbose_stream() << "eq not solved " << mk_pp(arg, m()) << "\n";
                                       verbose_stream() << is_uninterp_const(lhs) << " " << !m_candidate_vars.is_marked(lhs) << " "
                                       << !occurs(lhs, rhs) << " " << check_occs(lhs) << "\n";);
                        }
                    }
                    else {
                        // Descend into a non-equality conjunct, recording the and-context.
                        path.push_back(nnf_context(true, args, i));
                        hoist_nnf(g, arg, path, idx, depth + 1, mark);
                        path.pop_back();
                    }
                }
            }
            else if (m().is_or(f)) {
                flatten_or(f, args);
                // Descend into every disjunct, recording the or-context.
                for (unsigned i = 0; i < args.size(); ++i) {
                    path.push_back(nnf_context(false, args, i));
                    hoist_nnf(g, args.get(i), path, idx, depth + 1, mark);
                    path.pop_back();
                }
            }
        }
678 
collect_hoistsolve_eqs_tactic::imp679         void collect_hoist(goal const& g) {
680             unsigned size = g.size();
681             ast_mark mark;
682             vector<nnf_context> path;
683             for (unsigned idx = 0; idx < size; idx++) {
684                 checkpoint();
685                 hoist_nnf(g, g.form(idx), path, idx, 0, mark);
686             }
687         }
688 
distribute_and_orsolve_eqs_tactic::imp689         void distribute_and_or(goal & g) {
690             if (m_produce_proofs)
691                 return;
692             unsigned size = g.size();
693             hoist_rewriter_star rw(m());
694             th_rewriter thrw(m());
695             expr_ref tmp(m()), tmp2(m());
696 
697             TRACE("solve_eqs", g.display(tout););
698             for (unsigned idx = 0; !g.inconsistent() && idx < size; idx++) {
699                 checkpoint();
700                 if (g.is_decided_unsat()) break;
701                 expr* f = g.form(idx);
702                 proof_ref pr1(m()), pr2(m());
703                 thrw(f, tmp, pr1);
704                 rw(tmp, tmp2, pr2);
705                 TRACE("solve_eqs", tout << mk_pp(f, m()) << "\n->\n" << tmp << "\n->\n" << tmp2
706                       << "\n" << pr1 << "\n" << pr2 << "\n" << mk_pp(g.pr(idx), m()) << "\n";);
707                 pr1 = m().mk_transitivity(pr1, pr2);
708                 if (!pr1) pr1 = g.pr(idx); else pr1 = m().mk_modus_ponens(g.pr(idx), pr1);
709                 g.update(idx, tmp2, pr1, g.dep(idx));
710             }
711 
712         }
713 
        // Order the candidate variables recorded in m_subst.
        //
        // An iterative depth-first search (explicit stack `todo`, resumed via
        // `goto start`) starts from each candidate in m_vars and descends into
        // the candidate's definition found in m_subst. A candidate is appended
        // to m_ordered_vars only after its definition has been fully traversed
        // (its frame is revisited with second == 1). Every variable therefore
        // appears after the candidate variables its definition mentions.
        //
        // A candidate reached again while still marked `visiting` (i.e. on the
        // current DFS path) closes a cycle. It is unmarked in m_candidate_vars
        // and erased from m_subst, so it is never added to m_ordered_vars.
        //
        // Afterwards, for every dropped candidate the matching entry of
        // m_candidates is unmarked in m_candidate_set (so substitute() will not
        // replace it by true). That entry is presumably the equation defining v
        // (TODO confirm against collect()). The entry and v are also kept alive in
        // m_marked_candidates. Finally m_candidate_vars is reset.
        // m_num_steps is incremented once per processed stack frame.
        void sort_vars() {
            SASSERT(m_candidates.size() == m_vars.size());
            TRACE("solve_eqs_bug", tout << "sorting vars...\n";);
            m_ordered_vars.reset();


            // The variables (and their definitions) in m_subst must remain alive until the end of this procedure.
            // Reason: they are scheduled for unmarking in visiting/done.
            // They should remain alive while they are on the stack.
            // To make sure this is the case, whenever a variable (and its definition) is removed from m_subst,
            // I add them to the saved vector.

            expr_ref_vector saved(m());

            expr_fast_mark1 visiting;   // candidate variables currently on the DFS path
            expr_fast_mark2 done;       // shared terms (ref count > 1) that were fully processed

            // frame = (term, index of the next child to visit). For a candidate
            // variable, second == 1 means its definition has already been pushed.
            typedef std::pair<expr *, unsigned> frame;
            svector<frame> todo;
            unsigned num = 0;
            for (app* v : m_vars) {
                checkpoint();
                if (!m_candidate_vars.is_marked(v))
                    continue;
                todo.push_back(frame(v, 0));
                while (!todo.empty()) {
                start:
                    // Re-fetch the top frame: a push_back may have invalidated any
                    // previously held reference into `todo`.
                    frame & fr = todo.back();
                    expr * t   = fr.first;
                    m_num_steps++;
                    TRACE("solve_eqs_bug", tout << "processing:\n" << mk_ismt2_pp(t, m()) << "\n";);
                    // Only shared terms are ever marked `done`; skip them if already processed.
                    if (t->get_ref_count() > 1 && done.is_marked(t)) {
                        todo.pop_back();
                        continue;
                    }
                    switch (t->get_kind()) {
                    case AST_VAR:
                        todo.pop_back();
                        break;
                    case AST_QUANTIFIER:
                        // Visit the quantifier's children one at a time, resuming at fr.second.
                        num = to_quantifier(t)->get_num_children();
                        while (fr.second < num) {
                            expr * c = to_quantifier(t)->get_child(fr.second);
                            fr.second++;
                            if (c->get_ref_count() > 1 && done.is_marked(c))
                                continue;
                            todo.push_back(frame(c, 0));
                            goto start;
                        }
                        if (t->get_ref_count() > 1)
                            done.mark(t);
                        todo.pop_back();
                        break;
                    case AST_APP:
                        num = to_app(t)->get_num_args();
                        if (num == 0) {
                            // Constant: may be a candidate variable whose definition must be visited first.
                            if (fr.second == 0) {
                                if (m_candidate_vars.is_marked(t)) {
                                    if (visiting.is_marked(t)) {
                                        // cycle detected: remove t
                                        visiting.reset_mark(t);
                                        m_candidate_vars.mark(t, false);
                                        SASSERT(!m_candidate_vars.is_marked(t));

                                        // Must save t and its definition.
                                        // See comment in the beginning of the function
                                        expr * def = nullptr;
                                        proof * pr;
                                        expr_dependency * dep;
                                        m_subst->find(to_app(t), def, pr, dep);
                                        SASSERT(def != 0);
                                        saved.push_back(t);
                                        saved.push_back(def);
                                        //

                                        m_subst->erase(t);
                                    }
                                    else {
                                        // First visit: mark as on-path, then traverse the definition;
                                        // this frame is resumed with second == 1 afterwards.
                                        visiting.mark(t);
                                        fr.second = 1;
                                        expr * def = nullptr;
                                        proof * pr;
                                        expr_dependency * dep;
                                        m_subst->find(to_app(t), def, pr, dep);
                                        SASSERT(def != 0);
                                        todo.push_back(frame(def, 0));
                                        goto start;
                                    }
                                }
                            }
                            else {
                                // Definition fully traversed: emit t in post-order, unless it was
                                // dropped meanwhile because it closed a cycle.
                                SASSERT(fr.second == 1);
                                if (m_candidate_vars.is_marked(t)) {
                                    visiting.reset_mark(t);
                                    m_ordered_vars.push_back(to_app(t));
                                }
                                else {
                                    // var was removed from the list of candidate vars to elim cycle
                                    // do nothing
                                }
                            }
                        }
                        else {
                            // Visit the arguments one at a time, resuming at fr.second.
                            while (fr.second < num) {
                                expr * arg = to_app(t)->get_arg(fr.second);
                                fr.second++;
                                if (arg->get_ref_count() > 1 && done.is_marked(arg))
                                    continue;
                                todo.push_back(frame(arg, 0));
                                goto start;
                            }
                        }
                        if (t->get_ref_count() > 1)
                            done.mark(t);
                        todo.pop_back();
                        break;
                    default:
                        UNREACHABLE();
                        todo.pop_back();
                        break;
                    }
                }
            }

            // cleanup
            // For candidates dropped to break cycles, unmark their m_candidates entry
            // in m_candidate_set and keep both terms alive in m_marked_candidates.
            unsigned idx = 0;
            for (expr* v : m_vars) {
                if (!m_candidate_vars.is_marked(v)) {
                    m_candidate_set.mark(m_candidates[idx], false);
                    m_marked_candidates.push_back(m_candidates[idx]);
                    m_marked_candidates.push_back(v);
                }
                ++idx;
            }

            IF_VERBOSE(10000,
                       verbose_stream() << "ordered vars: ";
                       for (app* v : m_ordered_vars) verbose_stream() << mk_pp(v, m()) << " ";
                       verbose_stream() << "\n";);
            TRACE("solve_eqs",
                  tout << "ordered vars:\n";
                  for (app* v : m_ordered_vars) {
                      SASSERT(m_candidate_vars.is_marked(v));
                      tout << mk_ismt2_pp(v, m()) << " ";
                  }
                  tout << "\n";);
            m_candidate_vars.reset();
        }
862 
normalizesolve_eqs_tactic::imp863         void normalize() {
864             m_norm_subst->reset();
865             m_r->set_substitution(m_norm_subst.get());
866 
867 
868             expr_dependency_ref new_dep(m());
869             for (app * v : m_ordered_vars) {
870                 checkpoint();
871                 expr_ref new_def(m());
872                 proof_ref new_pr(m());
873                 expr * def = nullptr;
874                 proof * pr = nullptr;
875                 expr_dependency * dep = nullptr;
876                 m_subst->find(v, def, pr, dep);
877                 SASSERT(def);
878                 m_r->operator()(def, new_def, new_pr, new_dep);
879                 m_num_steps += m_r->get_num_steps() + 1;
880                 if (m_produce_proofs)
881                     new_pr = m().mk_transitivity(pr, new_pr);
882                 new_dep = m().mk_join(dep, new_dep);
883                 m_norm_subst->insert(v, new_def, new_pr, new_dep);
884                 // we updated the substituting, but we don't need to reset m_r
885                 // because all cached values there do not depend on v.
886             }
887             m_subst->reset();
888             TRACE("solve_eqs",
889                   tout << "after normalizing variables\n";
890                   for (expr * v : m_ordered_vars) {
891                       expr * def = 0;
892                       proof * pr = 0;
893                       expr_dependency * dep = 0;
894                       m_norm_subst->find(v, def, pr, dep);
895                       tout << mk_ismt2_pp(v, m()) << "\n----->\n" << mk_ismt2_pp(def, m()) << "\n\n";
896                   });
897         }
898 
substitutesolve_eqs_tactic::imp899         void substitute(goal & g) {
900             // force the cache of m_r to be reset.
901             m_r->set_substitution(m_norm_subst.get());
902 
903             expr_ref new_f(m());
904             proof_ref new_pr(m());
905             expr_dependency_ref new_dep(m());
906             unsigned size = g.size();
907             for (unsigned idx = 0; idx < size; idx++) {
908                 checkpoint();
909                 expr * f = g.form(idx);
910                 TRACE("gaussian_leak", tout << "processing:\n" << mk_ismt2_pp(f, m()) << "\n";);
911                 if (m_candidate_set.is_marked(f)) {
912                     m_marked_candidates.push_back(f);
913                     // f may be deleted after the following update.
914                     // so, we must remove the mark before doing the update
915                     m_candidate_set.mark(f, false);
916                     SASSERT(!m_candidate_set.is_marked(f));
917                     g.update(idx, m().mk_true(), m().mk_true_proof(), nullptr);
918                     m_num_steps ++;
919                     continue;
920                 }
921 
922                 m_r->operator()(f, new_f, new_pr, new_dep);
923 
924                 TRACE("solve_eqs_subst", tout << mk_ismt2_pp(f, m()) << "\n--->\n" << mk_ismt2_pp(new_f, m()) << "\n";);
925                 m_num_steps += m_r->get_num_steps() + 1;
926                 if (m_produce_proofs)
927                     new_pr = m().mk_modus_ponens(g.pr(idx), new_pr);
928                 if (m_produce_unsat_cores)
929                     new_dep = m().mk_join(g.dep(idx), new_dep);
930 
931                 g.update(idx, new_f, new_pr, new_dep);
932                 if (g.inconsistent())
933                     return;
934             }
935             g.elim_true();
936             TRACE("solve_eqs", g.display(tout << "after applying substitution\n"););
937 #if 0
938             DEBUG_CODE({
939                     for (expr* v : m_ordered_vars) {
940                         for (unsigned j = 0; j < g.size(); j++) {
941                             CASSERT("solve_eqs_bug", !occurs(v, g.form(j)));
942                         }
943                     }});
944 #endif
945         }
946 
save_elim_varssolve_eqs_tactic::imp947         void save_elim_vars(model_converter_ref & mc) {
948             IF_VERBOSE(100, if (!m_ordered_vars.empty()) verbose_stream() << "num. eliminated vars: " << m_ordered_vars.size() << "\n";);
949             m_num_eliminated_vars += m_ordered_vars.size();
950             if (m_produce_models) {
951                 if (!mc.get())
952                     mc = alloc(gmc, m(), "solve-eqs");
953                 for (app* v : m_ordered_vars) {
954                     expr * def = nullptr;
955                     proof * pr;
956                     expr_dependency * dep = nullptr;
957                     m_norm_subst->find(v, def, pr, dep);
958                     SASSERT(def);
959                     static_cast<gmc*>(mc.get())->add(v, def);
960                 }
961             }
962         }
963 
collect_num_occssolve_eqs_tactic::imp964         void collect_num_occs(expr * t, expr_fast_mark1 & visited) {
965             ptr_buffer<app, 128> stack;
966 
967             auto visit = [&](expr* arg) {
968                 if (is_uninterp_const(arg)) {
969                     m_num_occs.insert_if_not_there(arg, 0)++;
970                 }
971                 if (!visited.is_marked(arg) && is_app(arg)) {
972                     visited.mark(arg, true);
973                     stack.push_back(to_app(arg));
974                 }
975             };
976 
977             visit(t);
978 
979             while (!stack.empty()) {
980                 app * t = stack.back();
981                 stack.pop_back();
982                 for (expr* arg : *t)
983                     visit(arg);
984             }
985         }
986 
collect_num_occssolve_eqs_tactic::imp987         void collect_num_occs(goal const & g) {
988             if (m_max_occs == UINT_MAX)
989                 return; // no need to compute num occs
990             m_num_occs.reset();
991             expr_fast_mark1 visited;
992             unsigned sz = g.size();
993             for (unsigned i = 0; i < sz; i++)
994                 collect_num_occs(g.form(i), visited);
995         }
996 
get_num_stepssolve_eqs_tactic::imp997         unsigned get_num_steps() const {
998             return m_num_steps;
999         }
1000 
get_num_eliminated_varssolve_eqs_tactic::imp1001         unsigned get_num_eliminated_vars() const {
1002             return m_num_eliminated_vars;
1003         }
1004 
1005         //
1006         // TBD: rewrite the tactic to first apply a topological sorting that
1007         // approximates the dependencies between variables. Then apply
1008         // simplification on top of this sorting, so that it can apply sub-quadratic
1009         // equality and unit propagation.
1010         //
operator ()solve_eqs_tactic::imp1011         void operator()(goal_ref const & g, goal_ref_buffer & result) {
1012             model_converter_ref mc;
1013             tactic_report report("solve_eqs", *g);
1014             TRACE("goal", g->display(tout););
1015             m_produce_models = g->models_enabled();
1016             m_produce_proofs = g->proofs_enabled();
1017             m_produce_unsat_cores = g->unsat_core_enabled();
1018 
1019             if (!g->inconsistent()) {
1020                 m_subst      = alloc(expr_substitution, m(), m_produce_unsat_cores, m_produce_proofs);
1021                 m_norm_subst = alloc(expr_substitution, m(), m_produce_unsat_cores, m_produce_proofs);
1022                 unsigned rounds = 0;
1023                 while (rounds < 20) {
1024                     ++rounds;
1025                     if (!m_produce_proofs && m_context_solve && rounds < 3) {
1026                         distribute_and_or(*(g.get()));
1027                     }
1028                     collect_num_occs(*g);
1029                     collect(*g);
1030                     if (!m_produce_proofs && m_context_solve && rounds < 3) {
1031                         collect_hoist(*g);
1032                     }
1033                     if (m_subst->empty()) {
1034                         break;
1035                     }
1036                     sort_vars();
1037                     if (m_ordered_vars.empty()) {
1038                         break;
1039                     }
1040                     normalize();
1041                     substitute(*(g.get()));
1042                     if (g->inconsistent()) {
1043                         break;
1044                     }
1045                     save_elim_vars(mc);
1046                     TRACE("solve_eqs_round", g->display(tout); if (mc) mc->display(tout););
1047                     if (rounds > 10 && m_ordered_vars.size() == 1)
1048                         break;
1049                 }
1050             }
1051             g->inc_depth();
1052             g->add(mc.get());
1053             result.push_back(g.get());
1054         }
1055     };
1056 
1057     imp *      m_imp;
1058     params_ref m_params;
1059 public:
solve_eqs_tactic(ast_manager & m,params_ref const & p,expr_replacer * r,bool owner)1060     solve_eqs_tactic(ast_manager & m, params_ref const & p, expr_replacer * r, bool owner):
1061         m_params(p) {
1062         m_imp = alloc(imp, m, p, r, owner);
1063     }
1064 
translate(ast_manager & m)1065     tactic * translate(ast_manager & m) override {
1066         return alloc(solve_eqs_tactic, m, m_params, mk_expr_simp_replacer(m, m_params), true);
1067     }
1068 
~solve_eqs_tactic()1069     ~solve_eqs_tactic() override {
1070         dealloc(m_imp);
1071     }
1072 
updt_params(params_ref const & p)1073     void updt_params(params_ref const & p) override {
1074         m_params = p;
1075         m_imp->updt_params(p);
1076     }
1077 
collect_param_descrs(param_descrs & r)1078     void collect_param_descrs(param_descrs & r) override {
1079         r.insert("solve_eqs_max_occs", CPK_UINT, "(default: infty) maximum number of occurrences for considering a variable for gaussian eliminations.");
1080         r.insert("theory_solver", CPK_BOOL, "(default: true) use theory solvers.");
1081         r.insert("ite_solver", CPK_BOOL, "(default: true) use if-then-else solver.");
1082         r.insert("context_solve", CPK_BOOL, "(default: false) solve equalities under disjunctions.");
1083     }
1084 
operator ()(goal_ref const & in,goal_ref_buffer & result)1085     void operator()(goal_ref const & in,
1086                     goal_ref_buffer & result) override {
1087         (*m_imp)(in, result);
1088         report_tactic_progress(":num-elim-vars", m_imp->get_num_eliminated_vars());
1089     }
1090 
cleanup()1091     void cleanup() override {
1092         unsigned num_elim_vars = m_imp->m_num_eliminated_vars;
1093         ast_manager & m = m_imp->m();
1094         expr_replacer * r = m_imp->m_r;
1095         if (r)
1096             r->set_substitution(nullptr);
1097         bool owner = m_imp->m_r_owner;
1098         m_imp->m_r_owner  = false; // stole replacer
1099 
1100         imp * d = alloc(imp, m, m_params, r, owner);
1101         d->m_num_eliminated_vars = num_elim_vars;
1102         std::swap(d, m_imp);
1103         dealloc(d);
1104     }
1105 
collect_statistics(statistics & st) const1106     void collect_statistics(statistics & st) const override {
1107         st.update("eliminated vars", m_imp->get_num_eliminated_vars());
1108     }
1109 
reset_statistics()1110     void reset_statistics() override {
1111         m_imp->m_num_eliminated_vars = 0;
1112     }
1113 
1114 };
1115 
mk_solve_eqs_tactic(ast_manager & m,params_ref const & p,expr_replacer * r)1116 tactic * mk_solve_eqs_tactic(ast_manager & m, params_ref const & p, expr_replacer * r) {
1117     if (r == nullptr)
1118         return clean(alloc(solve_eqs_tactic, m, p, mk_expr_simp_replacer(m, p), true));
1119     else
1120         return clean(alloc(solve_eqs_tactic, m, p, r, false));
1121 }
1122