1 /*++
2 Copyright (c) 2011 Microsoft Corporation
3 
4 Module Name:
5 
6     solve_eqs_tactic.cpp
7 
8 Abstract:
9 
10     Tactic for solving equations and performing gaussian elimination.
11 
12 Author:
13 
14     Leonardo de Moura (leonardo) 2011-12-29.
15 
16 Revision History:
17 
18 --*/
19 #include "ast/rewriter/expr_replacer.h"
20 #include "ast/occurs.h"
21 #include "ast/ast_util.h"
22 #include "ast/ast_pp.h"
23 #include "ast/pb_decl_plugin.h"
24 #include "ast/recfun_decl_plugin.h"
25 #include "ast/rewriter/th_rewriter.h"
26 #include "ast/rewriter/rewriter_def.h"
27 #include "ast/rewriter/hoist_rewriter.h"
28 #include "tactic/goal_shared_occs.h"
29 #include "tactic/tactical.h"
30 #include "tactic/generic_model_converter.h"
31 #include "tactic/tactic_params.hpp"
32 
33 class solve_eqs_tactic : public tactic {
34     struct imp {
35         typedef generic_model_converter gmc;
36 
37         ast_manager &                 m_manager;
38         expr_replacer *               m_r;
39         bool                          m_r_owner;
40         arith_util                    m_a_util;
41         obj_map<expr, unsigned>       m_num_occs;
42         unsigned                      m_num_steps;
43         unsigned                      m_num_eliminated_vars;
44         bool                          m_theory_solver;
45         bool                          m_ite_solver;
46         unsigned                      m_max_occs;
47         bool                          m_context_solve;
48         scoped_ptr<expr_substitution> m_subst;
49         scoped_ptr<expr_substitution> m_norm_subst;
50         expr_sparse_mark              m_candidate_vars;
51         expr_sparse_mark              m_candidate_set;
52         ptr_vector<expr>              m_candidates;
53         expr_ref_vector               m_marked_candidates;
54         ptr_vector<app>               m_vars;
55         expr_sparse_mark              m_nonzero;
56         ptr_vector<app>               m_ordered_vars;
57         bool                          m_produce_proofs;
58         bool                          m_produce_unsat_cores;
59         bool                          m_produce_models;
60 
impsolve_eqs_tactic::imp61         imp(ast_manager & m, params_ref const & p, expr_replacer * r, bool owner):
62             m_manager(m),
63             m_r(r),
64             m_r_owner(r == nullptr || owner),
65             m_a_util(m),
66             m_num_steps(0),
67             m_num_eliminated_vars(0),
68             m_marked_candidates(m) {
69             updt_params(p);
70             if (m_r == nullptr)
71                 m_r = mk_default_expr_replacer(m, true);
72         }
73 
~impsolve_eqs_tactic::imp74         ~imp() {
75             if (m_r_owner)
76                 dealloc(m_r);
77         }
78 
msolve_eqs_tactic::imp79         ast_manager & m() const { return m_manager; }
80 
updt_paramssolve_eqs_tactic::imp81         void updt_params(params_ref const & p) {
82             tactic_params tp(p);
83             m_ite_solver     = p.get_bool("ite_solver", tp.solve_eqs_ite_solver());
84             m_theory_solver  = p.get_bool("theory_solver", tp.solve_eqs_theory_solver());
85             m_max_occs       = p.get_uint("solve_eqs_max_occs", tp.solve_eqs_max_occs());
86             m_context_solve  = p.get_bool("context_solve", tp.solve_eqs_context_solve());
87         }
88 
checkpointsolve_eqs_tactic::imp89         void checkpoint() {
90             tactic::checkpoint(m());
91         }
92 
93         // Check if the number of occurrences of t is below the specified threshold :solve-eqs-max-occs
check_occssolve_eqs_tactic::imp94         bool check_occs(expr * t) const {
95             if (m_max_occs == UINT_MAX)
96                 return true;
97             unsigned num = 0;
98             m_num_occs.find(t, num);
99             TRACE("solve_eqs_check_occs", tout << mk_ismt2_pp(t, m_manager) << " num_occs: " << num << " max: " << m_max_occs << "\n";);
100             return num <= m_max_occs;
101         }
102 
103         // Use: (= x def) and (= def x)
104 
trivial_solve1solve_eqs_tactic::imp105         bool trivial_solve1(expr * lhs, expr * rhs, app_ref & var, expr_ref & def, proof_ref & pr) {
106 
107             if (is_uninterp_const(lhs) && !m_candidate_vars.is_marked(lhs) && !occurs(lhs, rhs) && check_occs(lhs)) {
108                 var = to_app(lhs);
109                 def = rhs;
110                 pr  = nullptr;
111                 return true;
112             }
113             else {
114                 return false;
115             }
116         }
trivial_solvesolve_eqs_tactic::imp117         bool trivial_solve(expr * lhs, expr * rhs, app_ref & var, expr_ref & def, proof_ref & pr) {
118             if (trivial_solve1(lhs, rhs, var, def, pr))
119                 return true;
120             if (trivial_solve1(rhs, lhs, var, def, pr)) {
121                 if (m_produce_proofs) {
122                     pr = m().mk_commutativity(m().mk_eq(lhs, rhs));
123                 }
124                 return true;
125             }
126             return false;
127         }
128 
129         // (ite c (= x t1) (= x t2)) --> (= x (ite c t1 t2))
solve_ite_coresolve_eqs_tactic::imp130         bool solve_ite_core(app * ite, expr * lhs1, expr * rhs1, expr * lhs2, expr * rhs2, app_ref & var, expr_ref & def, proof_ref & pr) {
131             if (lhs1 != lhs2)
132                 return false;
133             if (!is_uninterp_const(lhs1) || m_candidate_vars.is_marked(lhs1))
134                 return false;
135             if (occurs(lhs1, ite->get_arg(0)) || occurs(lhs1, rhs1) || occurs(lhs1, rhs2))
136                 return false;
137             if (!check_occs(lhs1))
138                 return false;
139             var = to_app(lhs1);
140             def = m().mk_ite(ite->get_arg(0), rhs1, rhs2);
141 
142             if (m_produce_proofs)
143                 pr = m().mk_rewrite(ite, m().mk_eq(var, def));
144             return true;
145         }
146 
147         // (ite c (= x t1) (= x t2)) --> (= x (ite c t1 t2))
solve_itesolve_eqs_tactic::imp148         bool solve_ite(app * ite, app_ref & var, expr_ref & def, proof_ref & pr) {
149             expr * t = ite->get_arg(1);
150             expr * e = ite->get_arg(2);
151 
152             if (!m().is_eq(t) || !m().is_eq(e))
153                 return false;
154 
155             expr * lhs1 = to_app(t)->get_arg(0);
156             expr * rhs1 = to_app(t)->get_arg(1);
157             expr * lhs2 = to_app(e)->get_arg(0);
158             expr * rhs2 = to_app(e)->get_arg(1);
159 
160             return
161                 solve_ite_core(ite, lhs1, rhs1, lhs2, rhs2, var, def, pr) ||
162                 solve_ite_core(ite, rhs1, lhs1, lhs2, rhs2, var, def, pr) ||
163                 solve_ite_core(ite, lhs1, rhs1, rhs2, lhs2, var, def, pr) ||
164                 solve_ite_core(ite, rhs1, lhs1, rhs2, lhs2, var, def, pr);
165         }
166 
is_pos_literalsolve_eqs_tactic::imp167         bool is_pos_literal(expr * n) {
168             return is_app(n) && to_app(n)->get_num_args() == 0 && to_app(n)->get_family_id() == null_family_id;
169         }
170 
is_neg_literalsolve_eqs_tactic::imp171         bool is_neg_literal(expr * n) {
172             if (m_manager.is_not(n))
173                 return is_pos_literal(to_app(n)->get_arg(0));
174             return false;
175         }
176 
177 
178         /**
179            \brief Given t of the form (f s_0 ... s_n),
180            return true if x occurs in some s_j for j != i
181         */
occurs_exceptsolve_eqs_tactic::imp182         bool occurs_except(expr * x, app * t, unsigned i) {
183             unsigned num = t->get_num_args();
184             for (unsigned j = 0; j < num; j++) {
185                 if (i != j && occurs(x, t->get_arg(j)))
186                     return true;
187             }
188             return false;
189         }
190 
add_possolve_eqs_tactic::imp191         void add_pos(expr* f) {
192             expr* lhs = nullptr, *rhs = nullptr;
193             rational val;
194             if (m_a_util.is_le(f, lhs, rhs) && m_a_util.is_numeral(rhs, val) && val.is_neg()) {
195                 m_nonzero.mark(lhs);
196             }
197             else if (m_a_util.is_ge(f, lhs, rhs) && m_a_util.is_numeral(rhs, val) && val.is_pos()) {
198                 m_nonzero.mark(lhs);
199             }
200             else if (m().is_not(f, f)) {
201                 if (m_a_util.is_le(f, lhs, rhs) && m_a_util.is_numeral(rhs, val) && !val.is_neg()) {
202                     m_nonzero.mark(lhs);
203                 }
204                 else if (m_a_util.is_ge(f, lhs, rhs) && m_a_util.is_numeral(rhs, val) && !val.is_pos()) {
205                     m_nonzero.mark(lhs);
206                 }
207                 else if (m().is_eq(f, lhs, rhs) && m_a_util.is_numeral(rhs, val) && val.is_zero()) {
208                     m_nonzero.mark(lhs);
209                 }
210             }
211         }
212 
is_nonzerosolve_eqs_tactic::imp213         bool is_nonzero(expr* e) {
214             return m_nonzero.is_marked(e);
215         }
216 
isolate_varsolve_eqs_tactic::imp217         bool isolate_var(app* arg, app_ref& var, expr_ref& div, unsigned i, app* lhs, expr* rhs) {
218             if (!m_a_util.is_mul(arg)) return false;
219             unsigned n = arg->get_num_args();
220             for (unsigned j = 0; j < n; ++j) {
221                 expr* e = arg->get_arg(j);
222                 bool ok = is_uninterp_const(e) && check_occs(e) && !occurs(e, rhs) && !occurs_except(e, lhs, i);
223                 if (!ok) continue;
224                 var = to_app(e);
225                 for (unsigned k = 0; ok && k < n; ++k) {
226                     expr* arg_k = arg->get_arg(k);
227                     ok = k == j || (!occurs(var, arg_k) && is_nonzero(arg_k));
228                 }
229                 if (!ok) continue;
230                 ptr_vector<expr> args;
231                 for (unsigned k = 0; k < n; ++k) {
232                     if (k != j) args.push_back(arg->get_arg(k));
233                 }
234                 div = m_a_util.mk_mul(args.size(), args.data());
235                 return true;
236             }
237             return false;
238         }
239 
solve_nlsolve_eqs_tactic::imp240         bool solve_nl(app * lhs, expr * rhs, expr* eq, app_ref& var, expr_ref & def, proof_ref & pr) {
241             SASSERT(m_a_util.is_add(lhs));
242             if (m_a_util.is_int(lhs)) return false;
243             unsigned num = lhs->get_num_args();
244             expr_ref div(m());
245             for (unsigned i = 0; i < num; i++) {
246                 expr * arg = lhs->get_arg(i);
247                 if (is_app(arg) && isolate_var(to_app(arg), var, div, i, lhs, rhs)) {
248                     ptr_vector<expr> args;
249                     for (unsigned k = 0; k < num; ++k) {
250                         if (k != i) args.push_back(lhs->get_arg(k));
251                     }
252                     def = m_a_util.mk_sub(rhs, m_a_util.mk_add(args.size(), args.data()));
253                     def = m_a_util.mk_div(def, div);
254                     if (m_produce_proofs)
255                         pr = m().mk_rewrite(eq, m().mk_eq(var, def));
256                     return true;
257                 }
258             }
259             return false;
260         }
261 
solve_arith_coresolve_eqs_tactic::imp262         bool solve_arith_core(app * lhs, expr * rhs, expr * eq, app_ref & var, expr_ref & def, proof_ref & pr) {
263             SASSERT(m_a_util.is_add(lhs));
264             bool is_int  = m_a_util.is_int(lhs);
265             expr * a = nullptr;
266             expr * v = nullptr;
267             rational a_val;
268             unsigned num = lhs->get_num_args();
269             unsigned i;
270             for (i = 0; i < num; i++) {
271                 expr * arg = lhs->get_arg(i);
272                 if (is_uninterp_const(arg) && !m_candidate_vars.is_marked(arg) && check_occs(arg) && !occurs(arg, rhs) && !occurs_except(arg, lhs, i)) {
273                     a_val = rational(1);
274                     v     = arg;
275                     break;
276                 }
277                 else if (m_a_util.is_mul(arg, a, v) &&
278                          is_uninterp_const(v) &&
279                          !m_candidate_vars.is_marked(v) &&
280                          m_a_util.is_numeral(a, a_val) &&
281                          !a_val.is_zero() &&
282                          (!is_int || a_val.is_minus_one()) &&
283                          check_occs(v) &&
284                          !occurs(v, rhs) &&
285                          !occurs_except(v, lhs, i)) {
286                     break;
287                 }
288             }
289             if (i == num)
290                 return false;
291             var = to_app(v);
292             expr_ref inv_a(m());
293             if (!a_val.is_one()) {
294                 inv_a = m_a_util.mk_numeral(rational(1)/a_val, is_int);
295                 rhs   = m_a_util.mk_mul(inv_a, rhs);
296             }
297 
298             ptr_buffer<expr> other_args;
299             for (unsigned j = 0; j < num; j++) {
300                 if (i != j) {
301                     if (inv_a)
302                         other_args.push_back(m_a_util.mk_mul(inv_a, lhs->get_arg(j)));
303                     else
304                         other_args.push_back(lhs->get_arg(j));
305                 }
306             }
307             switch (other_args.size()) {
308             case 0:
309                 def = rhs;
310                 break;
311             case 1:
312                 def = m_a_util.mk_sub(rhs, other_args[0]);
313                 break;
314             default:
315                 def = m_a_util.mk_sub(rhs, m_a_util.mk_add(other_args.size(), other_args.data()));
316                 break;
317             }
318             if (m_produce_proofs)
319                 pr = m().mk_rewrite(eq, m().mk_eq(var, def));
320             return true;
321         }
322 
solve_modsolve_eqs_tactic::imp323         bool solve_mod(expr * lhs, expr * rhs, expr * eq, app_ref & var, expr_ref & def, proof_ref & pr) {
324             rational r1, r2;
325             expr* arg1;
326             if (m_produce_proofs)
327                 return false;
328 
329             auto fresh = [&]() { return m().mk_fresh_const("mod", m_a_util.mk_int()); };
330             auto mk_int = [&](rational const& r) { return m_a_util.mk_int(r); };
331             auto add = [&](expr* a, expr* b) { return m_a_util.mk_add(a, b); };
332             auto mul = [&](expr* a, expr* b) { return m_a_util.mk_mul(a, b); };
333 
334             VERIFY(m_a_util.is_mod(lhs, lhs, arg1));
335             if (!m_a_util.is_numeral(arg1, r1) || !r1.is_pos()) {
336                 return false;
337             }
338             //
339             // solve lhs mod r1 = r2
340             // as lhs = r1*mod!1 + r2
341             //
342             if (m_a_util.is_numeral(rhs, r2) && !r2.is_neg() && r2 < r1) {
343                 expr_ref def0(m());
344                 def0 = add(mk_int(r2), mul(fresh(), mk_int(r1)));
345                 return solve_eq(lhs, def0, eq, var, def, pr);
346             }
347             return false;
348         }
349 
solve_arithsolve_eqs_tactic::imp350         bool solve_arith(expr * lhs, expr * rhs, expr * eq, app_ref & var, expr_ref & def, proof_ref & pr) {
351             return
352                 (m_a_util.is_add(lhs) && solve_arith_core(to_app(lhs), rhs, eq, var, def, pr)) ||
353                 (m_a_util.is_add(rhs) && solve_arith_core(to_app(rhs), lhs, eq, var, def, pr)) ||
354                 (m_a_util.is_mod(lhs) && solve_mod(lhs, rhs, eq, var, def, pr)) ||
355                 (m_a_util.is_mod(rhs) && solve_mod(rhs, lhs, eq, var, def, pr));
356         }
357 
358 
solve_eqsolve_eqs_tactic::imp359         bool solve_eq(expr* arg1, expr* arg2, expr* eq, app_ref& var, expr_ref & def, proof_ref& pr) {
360             if (trivial_solve(arg1, arg2, var, def, pr))
361                 return true;
362             if (m_theory_solver) {
363                 if (solve_arith(arg1, arg2, eq, var, def, pr))
364                     return true;
365             }
366             return false;
367         }
368 
solvesolve_eqs_tactic::imp369         bool solve(expr * f, app_ref & var, expr_ref & def, proof_ref & pr) {
370             expr* arg1 = nullptr, *arg2 = nullptr;
371             if (m().is_eq(f, arg1, arg2)) {
372                 return solve_eq(arg1, arg2, f, var, def, pr);
373             }
374 
375             if (m_ite_solver && m().is_ite(f))
376                 return solve_ite(to_app(f), var, def, pr);
377 
378             if (is_pos_literal(f)) {
379                 if (m_candidate_vars.is_marked(f))
380                     return false;
381                 var = to_app(f);
382                 def = m().mk_true();
383                 if (m_produce_proofs) {
384                     // [rewrite]: (iff (iff l true) l)
385                     // [symmetry T1]: (iff l (iff l true))
386                     pr = m().mk_rewrite(m().mk_eq(var, def), var);
387                     pr = m().mk_symmetry(pr);
388                 }
389                 TRACE("solve_eqs_bug2", tout << "eliminating: " << mk_ismt2_pp(f, m()) << "\n";);
390                 return true;
391             }
392 
393             if (is_neg_literal(f)) {
394                 var = to_app(to_app(f)->get_arg(0));
395                 if (m_candidate_vars.is_marked(var))
396                     return false;
397                 def = m().mk_false();
398                 if (m_produce_proofs) {
399                     // [rewrite]: (iff (iff l false) ~l)
400                     // [symmetry T1]: (iff ~l (iff l false))
401                     pr = m().mk_rewrite(m().mk_eq(var, def), f);
402                     pr = m().mk_symmetry(pr);
403                 }
404                 return true;
405             }
406 
407             return false;
408         }
409 
insert_solutionsolve_eqs_tactic::imp410         void insert_solution(goal const& g, unsigned idx, expr* f, app* var, expr* def, proof* pr) {
411 
412             if (!is_safe(var))
413                 return;
414             m_vars.push_back(var);
415             m_candidates.push_back(f);
416             m_candidate_set.mark(f);
417             m_candidate_vars.mark(var);
418             m_marked_candidates.push_back(f);
419             if (m_produce_proofs) {
420                 if (!pr)
421                     pr = g.pr(idx);
422                 else
423                     pr = m().mk_modus_ponens(g.pr(idx), pr);
424             }
425             m_subst->insert(var, def, pr, g.dep(idx));
426         }
427 
428         /**
429            \brief Start collecting candidates
430         */
collectsolve_eqs_tactic::imp431         void collect(goal const & g) {
432             m_subst->reset();
433             m_norm_subst->reset();
434             m_r->set_substitution(nullptr);
435             m_candidate_vars.reset();
436             m_candidate_set.reset();
437             m_candidates.reset();
438             m_marked_candidates.reset();
439             m_vars.reset();
440             m_nonzero.reset();
441             app_ref  var(m());
442             expr_ref  def(m());
443             proof_ref pr(m());
444             unsigned size = g.size();
445             for (unsigned idx = 0; idx < size; idx++) {
446                 add_pos(g.form(idx));
447             }
448             for (unsigned idx = 0; idx < size; idx++) {
449                 checkpoint();
450                 expr * f = g.form(idx);
451                 pr = nullptr;
452                 if (solve(f, var, def, pr)) {
453                     insert_solution(g, idx, f, var, def, pr);
454                 }
455                 m_num_steps++;
456             }
457 
458             TRACE("solve_eqs",
459                   tout << "candidate vars:\n";
460                   for (app* v : m_vars) {
461                       tout << mk_ismt2_pp(v, m()) << " ";
462                   }
463                   tout << "\n";);
464         }
465 
466         struct nnf_context {
467             bool m_is_and;
468             expr_ref_vector m_args;
469             unsigned m_index;
nnf_contextsolve_eqs_tactic::imp::nnf_context470             nnf_context(bool is_and, expr_ref_vector const& args, unsigned idx):
471                 m_is_and(is_and),
472                 m_args(args),
473                 m_index(idx)
474             {}
475         };
476 
477         ptr_vector<expr> m_todo;
mark_occurssolve_eqs_tactic::imp478         void mark_occurs(expr_mark& occ, goal const& g, expr* v) {
479             expr_fast_mark2 visited;
480             occ.mark(v, true);
481             visited.mark(v, true);
482             for (unsigned j = 0; j < g.size(); ++j) {
483                 m_todo.push_back(g.form(j));
484             }
485             while (!m_todo.empty()) {
486                 expr* e = m_todo.back();
487                 if (visited.is_marked(e)) {
488                     m_todo.pop_back();
489                     continue;
490                 }
491                 if (is_app(e)) {
492                     bool does_occur = false;
493                     bool all_visited = true;
494                     for (expr* arg : *to_app(e)) {
495                         if (!visited.is_marked(arg)) {
496                             m_todo.push_back(arg);
497                             all_visited = false;
498                         }
499                         else {
500                             does_occur |= occ.is_marked(arg);
501                         }
502                     }
503                     if (all_visited) {
504                         occ.mark(e, does_occur);
505                         visited.mark(e, true);
506                         m_todo.pop_back();
507                     }
508                 }
509                 else if (is_quantifier(e)) {
510                     expr* body = to_quantifier(e)->get_expr();
511                     if (visited.is_marked(body)) {
512                         visited.mark(e, true);
513                         occ.mark(e, occ.is_marked(body));
514                         m_todo.pop_back();
515                     }
516                     else {
517                         m_todo.push_back(body);
518                     }
519                 }
520                 else {
521                     visited.mark(e, true);
522                     m_todo.pop_back();
523                 }
524             }
525         }
526 
is_compatiblesolve_eqs_tactic::imp527         bool is_compatible(goal const& g, unsigned idx, vector<nnf_context> const & path, expr* v, expr* eq) {
528             expr_mark occ;
529             svector<lbool> cache;
530             mark_occurs(occ, g, v);
531             return is_goal_compatible(g, occ, cache, idx, v, eq) && is_path_compatible(occ, cache, path, v, eq);
532         }
533 
is_goal_compatiblesolve_eqs_tactic::imp534         bool is_goal_compatible(goal const& g, expr_mark& occ, svector<lbool>& cache, unsigned idx, expr* v, expr* eq) {
535             bool all_e = false;
536             for (unsigned j = 0; j < g.size(); ++j) {
537                 if (j != idx && !check_eq_compat_rec(occ, cache, g.form(j), v, eq, all_e)) {
538                     TRACE("solve_eqs", tout << "occurs goal " << mk_pp(eq, m()) << "\n";);
539                     return false;
540                 }
541             }
542             return true;
543         }
544 
545         //
546         // all_e := all disjunctions contain eq
547         //
548         // or, all_e -> skip if all disjunctions contain eq
549         // or, all_e -> fail if some disjunction contains v but not eq
550         // or, all_e -> all_e := false if some disjunction does not contain v
551         // and, all_e -> all_e
552         //
553 
is_path_compatiblesolve_eqs_tactic::imp554         bool is_path_compatible(expr_mark& occ, svector<lbool>& cache, vector<nnf_context> const & path, expr* v, expr* eq) {
555             bool all_e = true;
556             auto is_marked = [&](expr* e) {
557                 if (occ.is_marked(e))
558                     return true;
559                 if (m().is_not(e, e) && occ.is_marked(e))
560                     return true;
561                 return false;
562             };
563             for (unsigned i = path.size(); i-- > 0; ) {
564                 auto const& p = path[i];
565                 auto const& args = p.m_args;
566                 if (p.m_is_and && !all_e) {
567                     for (unsigned j = 0; j < args.size(); ++j) {
568                         if (j != p.m_index && is_marked(args[j])) {
569                             TRACE("solve_eqs", tout << "occurs and " << mk_pp(eq, m()) << " " << mk_pp(args[j], m()) << "\n";);
570                             return false;
571                         }
572                     }
573                 }
574                 else if (!p.m_is_and) {
575                     for (unsigned j = 0; j < args.size(); ++j) {
576                         if (j != p.m_index) {
577                             if (occurs(v, args[j])) {
578                                 if (!check_eq_compat_rec(occ, cache, args[j], v, eq, all_e)) {
579                                     TRACE("solve_eqs", tout << "occurs or " << mk_pp(eq, m()) << " " << mk_pp(args[j], m()) << "\n";);
580                                     return false;
581                                 }
582                             }
583                             else {
584                                 all_e = false;
585                             }
586                         }
587                     }
588                 }
589             }
590             return true;
591         }
592 
check_eq_compat_recsolve_eqs_tactic::imp593         bool check_eq_compat_rec(expr_mark& occ, svector<lbool>& cache, expr* f, expr* v, expr* eq, bool& all) {
594             expr_ref_vector args(m());
595             expr* f1 = nullptr;
596             // flattening may introduce fresh negations,
597             // occ is not defined on these negations
598             if (!m().is_not(f) && !occ.is_marked(f)) {
599                 all = false;
600                 return true;
601             }
602             unsigned idx = f->get_id();
603             if (cache.size() > idx && cache[idx] != l_undef) {
604                 return cache[idx] == l_true;
605             }
606             if (m().is_not(f, f1) && m().is_or(f1)) {
607                 flatten_and(f, args);
608                 for (expr* arg : args) {
609                     if (arg == eq) {
610                         cache.reserve(idx+1, l_undef);
611                         cache[idx] = l_true;
612                         return true;
613                     }
614                 }
615             }
616             else if (m().is_or(f)) {
617                 flatten_or(f, args);
618             }
619             else {
620                 return false;
621             }
622 
623             for (expr* arg : args) {
624                 if (!check_eq_compat_rec(occ, cache, arg, v, eq, all)) {
625                     cache.reserve(idx+1, l_undef);
626                     cache[idx] = l_false;
627                     return false;
628                 }
629             }
630             cache.reserve(idx+1, l_undef);
631             cache[idx] = l_true;
632             return true;
633         }
634 
hoist_nnfsolve_eqs_tactic::imp635         void hoist_nnf(goal const& g, expr* f, vector<nnf_context> & path, unsigned idx, unsigned depth, ast_mark& mark) {
636             if (depth > 3 || mark.is_marked(f)) {
637                 return;
638             }
639             mark.mark(f, true);
640             checkpoint();
641             app_ref var(m());
642             expr_ref def(m());
643             proof_ref pr(m());
644             expr_ref_vector args(m());
645             expr* f1 = nullptr;
646 
647             if (m().is_not(f, f1) && m().is_or(f1)) {
648                 flatten_and(f, args);
649                 for (unsigned i = 0; i < args.size(); ++i) {
650                     pr = nullptr;
651                     expr* arg = args.get(i), *lhs = nullptr, *rhs = nullptr;
652                     if (m().is_eq(arg, lhs, rhs)) {
653                         if (trivial_solve1(lhs, rhs, var, def, pr) && is_compatible(g, idx, path, var, arg)) {
654                             insert_solution(g, idx, arg, var, def, pr);
655                         }
656                         else if (trivial_solve1(rhs, lhs, var, def, pr) && is_compatible(g, idx, path, var, arg)) {
657                             insert_solution(g, idx, arg, var, def, pr);
658                         }
659                         else {
660                             IF_VERBOSE(10000,
661                                        verbose_stream() << "eq not solved " << mk_pp(arg, m()) << "\n";
662                                        verbose_stream() << is_uninterp_const(lhs) << " " << !m_candidate_vars.is_marked(lhs) << " "
663                                        << !occurs(lhs, rhs) << " " << check_occs(lhs) << "\n";);
664                         }
665                     }
666                     else {
667                         path.push_back(nnf_context(true, args, i));
668                         hoist_nnf(g, arg, path, idx, depth + 1, mark);
669                         path.pop_back();
670                     }
671                 }
672             }
673             else if (m().is_or(f)) {
674                 flatten_or(f, args);
675                 for (unsigned i = 0; i < args.size(); ++i) {
676                     path.push_back(nnf_context(false, args, i));
677                     hoist_nnf(g, args.get(i), path, idx, depth + 1, mark);
678                     path.pop_back();
679                 }
680             }
681         }
682 
collect_hoistsolve_eqs_tactic::imp683         void collect_hoist(goal const& g) {
684             unsigned size = g.size();
685             ast_mark mark;
686             vector<nnf_context> path;
687             for (unsigned idx = 0; idx < size; idx++) {
688                 checkpoint();
689                 hoist_nnf(g, g.form(idx), path, idx, 0, mark);
690             }
691         }
692 
distribute_and_orsolve_eqs_tactic::imp693         void distribute_and_or(goal & g) {
694             if (m_produce_proofs)
695                 return;
696             unsigned size = g.size();
697             hoist_rewriter_star rw(m());
698             th_rewriter thrw(m());
699             expr_ref tmp(m()), tmp2(m());
700 
701             TRACE("solve_eqs", g.display(tout););
702             for (unsigned idx = 0; !g.inconsistent() && idx < size; idx++) {
703                 checkpoint();
704                 if (g.is_decided_unsat()) break;
705                 expr* f = g.form(idx);
706                 proof_ref pr1(m()), pr2(m());
707                 thrw(f, tmp, pr1);
708                 rw(tmp, tmp2, pr2);
709                 TRACE("solve_eqs", tout << mk_pp(f, m()) << "\n->\n" << tmp << "\n->\n" << tmp2
710                       << "\n" << pr1 << "\n" << pr2 << "\n" << mk_pp(g.pr(idx), m()) << "\n";);
711                 pr1 = m().mk_transitivity(pr1, pr2);
712                 if (!pr1) pr1 = g.pr(idx); else pr1 = m().mk_modus_ponens(g.pr(idx), pr1);
713                 g.update(idx, tmp2, pr1, g.dep(idx));
714             }
715         }
716 
717         expr_mark m_unsafe_vars;
718 
filter_unsafe_varssolve_eqs_tactic::imp719         void filter_unsafe_vars() {
720             m_unsafe_vars.reset();
721             recfun::util rec(m());
722             for (func_decl* f : rec.get_rec_funs())
723                 for (expr* term : subterms::all(expr_ref(rec.get_def(f).get_rhs(), m())))
724                     m_unsafe_vars.mark(term);
725         }
726 
is_safesolve_eqs_tactic::imp727         bool is_safe(expr* f) {
728             return !m_unsafe_vars.is_marked(f);
729         }
730 
        // Fills m_ordered_vars with the candidate variables, ordered so that a variable is
        // appended only after its definition (as stored in m_subst) has been completely
        // traversed; a candidate variable found to close a cycle is removed instead.
        //
        // For every variable of m_vars that is marked in m_candidate_vars, an explicit-stack
        // depth-first traversal is run. Each stack frame is a (term, next-child-index) pair.
        //  - `visiting` marks candidate variables whose definition is currently being walked.
        //    Meeting such a variable again means a cycle: the variable loses its mark in
        //    m_candidate_vars and its entry in m_subst.
        //  - `done` marks terms with ref count > 1 that were fully traversed, so that they
        //    are skipped when reached again.
        //
        // After the traversal, for every variable of m_vars that is no longer marked in
        // m_candidate_vars, the entry of m_candidates at the same index is unmarked in
        // m_candidate_set, and both are appended to m_marked_candidates.
        // m_candidate_vars is reset before returning.
        void sort_vars() {
            SASSERT(m_candidates.size() == m_vars.size());
            TRACE("solve_eqs_bug", tout << "sorting vars...\n";);
            m_ordered_vars.reset();


            // The variables (and their definitions) in m_subst must remain alive until the end of this procedure.
            // Reason: they are scheduled for unmarking in visiting/done.
            // They should remain alive while they are on the stack.
            // To make sure this is the case, whenever a variable (and its definition) is removed from m_subst,
            // they are added to the saved vector.

            expr_ref_vector saved(m());

            expr_fast_mark1 visiting;
            expr_fast_mark2 done;

            // frame.first: term being processed; frame.second: index of the next child to visit
            // (for constants: 0 = not expanded yet, 1 = definition already pushed).
            typedef std::pair<expr *, unsigned> frame;
            svector<frame> todo;
            unsigned num = 0;
            for (app* v : m_vars) {
                checkpoint();
                // skip variables that are not (or no longer) candidates
                if (!m_candidate_vars.is_marked(v))
                    continue;
                todo.push_back(frame(v, 0));
                while (!todo.empty()) {
                start:
                    // (re)load the top of the stack; `goto start` lands here right after a push,
                    // so the newly pushed frame becomes the current one.
                    frame & fr = todo.back();
                    expr * t   = fr.first;
                    m_num_steps++;
                    TRACE("solve_eqs_bug", tout << "processing:\n" << mk_ismt2_pp(t, m()) << "\n";);
                    // shared term that was already fully traversed: nothing left to do
                    if (t->get_ref_count() > 1 && done.is_marked(t)) {
                        todo.pop_back();
                        continue;
                    }
                    switch (t->get_kind()) {
                    case AST_VAR:
                        // bound variables have no children
                        todo.pop_back();
                        break;
                    case AST_QUANTIFIER:
                        num = to_quantifier(t)->get_num_children();
                        // descend into the next child that is not already done
                        while (fr.second < num) {
                            expr * c = to_quantifier(t)->get_child(fr.second);
                            fr.second++;
                            if (c->get_ref_count() > 1 && done.is_marked(c))
                                continue;
                            todo.push_back(frame(c, 0));
                            goto start;
                        }
                        // all children processed; only shared terms are recorded in `done`
                        if (t->get_ref_count() > 1)
                            done.mark(t);
                        todo.pop_back();
                        break;
                    case AST_APP:
                        num = to_app(t)->get_num_args();
                        if (num == 0) {
                            // constant: it may be a candidate variable with a definition in m_subst
                            if (fr.second == 0) {
                                if (m_candidate_vars.is_marked(t)) {
                                    if (visiting.is_marked(t)) {
                                        // cycle detected: remove t
                                        visiting.reset_mark(t);
                                        m_candidate_vars.mark(t, false);
                                        SASSERT(!m_candidate_vars.is_marked(t));

                                        // Must save t and its definition.
                                        // See comment in the beginning of the function
                                        expr * def = nullptr;
                                        proof * pr;
                                        expr_dependency * dep;
                                        m_subst->find(to_app(t), def, pr, dep);
                                        SASSERT(def != 0);
                                        saved.push_back(t);
                                        saved.push_back(def);
                                        //

                                        m_subst->erase(t);
                                    }
                                    else {
                                        // first visit of a candidate variable: mark it as in progress
                                        // and traverse its definition before finishing this frame.
                                        visiting.mark(t);
                                        fr.second = 1;
                                        expr * def = nullptr;
                                        proof * pr;
                                        expr_dependency * dep;
                                        m_subst->find(to_app(t), def, pr, dep);
                                        SASSERT(def != 0);
                                        todo.push_back(frame(def, 0));
                                        goto start;
                                    }
                                }
                            }
                            else {
                                // back from the traversal of the definition of t
                                SASSERT(fr.second == 1);
                                if (m_candidate_vars.is_marked(t)) {
                                    visiting.reset_mark(t);
                                    m_ordered_vars.push_back(to_app(t));
                                }
                                else {
                                    // var was removed from the list of candidate vars to elim cycle
                                    // do nothing
                                }
                            }
                        }
                        else {
                            // proper application: descend into the next argument that is not already done
                            while (fr.second < num) {
                                expr * arg = to_app(t)->get_arg(fr.second);
                                fr.second++;
                                if (arg->get_ref_count() > 1 && done.is_marked(arg))
                                    continue;
                                todo.push_back(frame(arg, 0));
                                goto start;
                            }
                        }
                        // t is finished; only shared terms are recorded in `done`
                        if (t->get_ref_count() > 1)
                            done.mark(t);
                        todo.pop_back();
                        break;
                    default:
                        UNREACHABLE();
                        todo.pop_back();
                        break;
                    }
                }
            }

            // cleanup
            // m_candidates[idx] is the candidate paired with the idx-th variable of m_vars
            // (see the size assertion at the top of the function).
            unsigned idx = 0;
            for (expr* v : m_vars) {
                if (!m_candidate_vars.is_marked(v)) {
                    m_candidate_set.mark(m_candidates[idx], false);
                    m_marked_candidates.push_back(m_candidates[idx]);
                    m_marked_candidates.push_back(v);
                }
                ++idx;
            }

            IF_VERBOSE(10000,
                       verbose_stream() << "ordered vars: ";
                       for (app* v : m_ordered_vars) verbose_stream() << mk_pp(v, m()) << " ";
                       verbose_stream() << "\n";);
            TRACE("solve_eqs",
                  tout << "ordered vars:\n";
                  for (app* v : m_ordered_vars) {
                      SASSERT(m_candidate_vars.is_marked(v));
                      tout << mk_ismt2_pp(v, m()) << " ";
                  }
                  tout << "\n";);
            m_candidate_vars.reset();
        }
879 
normalizesolve_eqs_tactic::imp880         void normalize() {
881             m_norm_subst->reset();
882             m_r->set_substitution(m_norm_subst.get());
883 
884 
885             expr_dependency_ref new_dep(m());
886             for (app * v : m_ordered_vars) {
887                 checkpoint();
888                 expr_ref new_def(m());
889                 proof_ref new_pr(m());
890                 expr * def = nullptr;
891                 proof * pr = nullptr;
892                 expr_dependency * dep = nullptr;
893                 m_subst->find(v, def, pr, dep);
894                 SASSERT(def);
895                 m_r->operator()(def, new_def, new_pr, new_dep);
896                 m_num_steps += m_r->get_num_steps() + 1;
897                 if (m_produce_proofs)
898                     new_pr = m().mk_transitivity(pr, new_pr);
899                 new_dep = m().mk_join(dep, new_dep);
900                 m_norm_subst->insert(v, new_def, new_pr, new_dep);
901                 // we updated the substituting, but we don't need to reset m_r
902                 // because all cached values there do not depend on v.
903             }
904             m_subst->reset();
905             TRACE("solve_eqs",
906                   tout << "after normalizing variables\n";
907                   for (expr * v : m_ordered_vars) {
908                       expr * def = 0;
909                       proof * pr = 0;
910                       expr_dependency * dep = 0;
911                       m_norm_subst->find(v, def, pr, dep);
912                       tout << mk_ismt2_pp(v, m()) << "\n----->\n" << mk_ismt2_pp(def, m()) << "\n\n";
913                   });
914         }
915 
substitutesolve_eqs_tactic::imp916         void substitute(goal & g) {
917             // force the cache of m_r to be reset.
918             m_r->set_substitution(m_norm_subst.get());
919 
920             expr_ref new_f(m());
921             proof_ref new_pr(m());
922             expr_dependency_ref new_dep(m());
923             unsigned size = g.size();
924             for (unsigned idx = 0; idx < size; idx++) {
925                 checkpoint();
926                 expr * f = g.form(idx);
927                 TRACE("gaussian_leak", tout << "processing:\n" << mk_ismt2_pp(f, m()) << "\n";);
928                 if (m_candidate_set.is_marked(f)) {
929                     m_marked_candidates.push_back(f);
930                     // f may be deleted after the following update.
931                     // so, we must remove the mark before doing the update
932                     m_candidate_set.mark(f, false);
933                     SASSERT(!m_candidate_set.is_marked(f));
934                     g.update(idx, m().mk_true(), m().mk_true_proof(), nullptr);
935                     m_num_steps ++;
936                     continue;
937                 }
938 
939                 m_r->operator()(f, new_f, new_pr, new_dep);
940 
941                 TRACE("solve_eqs_subst", tout << mk_ismt2_pp(f, m()) << "\n--->\n" << mk_ismt2_pp(new_f, m()) << "\n";);
942                 m_num_steps += m_r->get_num_steps() + 1;
943                 if (m_produce_proofs)
944                     new_pr = m().mk_modus_ponens(g.pr(idx), new_pr);
945                 if (m_produce_unsat_cores)
946                     new_dep = m().mk_join(g.dep(idx), new_dep);
947 
948                 g.update(idx, new_f, new_pr, new_dep);
949                 if (g.inconsistent())
950                     return;
951             }
952             g.elim_true();
953             TRACE("solve_eqs", g.display(tout << "after applying substitution\n"););
954 #if 0
955             DEBUG_CODE({
956                     for (expr* v : m_ordered_vars) {
957                         for (unsigned j = 0; j < g.size(); j++) {
958                             CASSERT("solve_eqs_bug", !occurs(v, g.form(j)));
959                         }
960                     }});
961 #endif
962         }
963 
save_elim_varssolve_eqs_tactic::imp964         void save_elim_vars(model_converter_ref & mc) {
965             IF_VERBOSE(100, if (!m_ordered_vars.empty()) verbose_stream() << "num. eliminated vars: " << m_ordered_vars.size() << "\n";);
966             m_num_eliminated_vars += m_ordered_vars.size();
967             if (m_produce_models) {
968                 if (!mc.get())
969                     mc = alloc(gmc, m(), "solve-eqs");
970                 for (app* v : m_ordered_vars) {
971                     expr * def = nullptr;
972                     proof * pr;
973                     expr_dependency * dep = nullptr;
974                     m_norm_subst->find(v, def, pr, dep);
975                     SASSERT(def);
976                     static_cast<gmc*>(mc.get())->add(v, def);
977                 }
978             }
979         }
980 
collect_num_occssolve_eqs_tactic::imp981         void collect_num_occs(expr * t, expr_fast_mark1 & visited) {
982             ptr_buffer<app, 128> stack;
983 
984             auto visit = [&](expr* arg) {
985                 if (is_uninterp_const(arg)) {
986                     m_num_occs.insert_if_not_there(arg, 0)++;
987                 }
988                 if (!visited.is_marked(arg) && is_app(arg)) {
989                     visited.mark(arg, true);
990                     stack.push_back(to_app(arg));
991                 }
992             };
993 
994             visit(t);
995 
996             while (!stack.empty()) {
997                 app * t = stack.back();
998                 stack.pop_back();
999                 for (expr* arg : *t)
1000                     visit(arg);
1001             }
1002         }
1003 
collect_num_occssolve_eqs_tactic::imp1004         void collect_num_occs(goal const & g) {
1005             if (m_max_occs == UINT_MAX)
1006                 return; // no need to compute num occs
1007             m_num_occs.reset();
1008             expr_fast_mark1 visited;
1009             unsigned sz = g.size();
1010             for (unsigned i = 0; i < sz; i++)
1011                 collect_num_occs(g.form(i), visited);
1012         }
1013 
get_num_stepssolve_eqs_tactic::imp1014         unsigned get_num_steps() const {
1015             return m_num_steps;
1016         }
1017 
get_num_eliminated_varssolve_eqs_tactic::imp1018         unsigned get_num_eliminated_vars() const {
1019             return m_num_eliminated_vars;
1020         }
1021 
1022         //
1023         // TBD: rewrite the tactic to first apply a topological sorting that
1024         // approximates the dependencies between variables. Then apply
1025         // simplification on top of this sorting, so that it can apply sub-quadratic
1026         // equality and unit propagation.
1027         //
operator ()solve_eqs_tactic::imp1028         void operator()(goal_ref const & g, goal_ref_buffer & result) {
1029             model_converter_ref mc;
1030             tactic_report report("solve_eqs", *g);
1031             TRACE("goal", g->display(tout););
1032             m_produce_models = g->models_enabled();
1033             m_produce_proofs = g->proofs_enabled();
1034             m_produce_unsat_cores = g->unsat_core_enabled();
1035 
1036             if (!g->inconsistent()) {
1037                 m_subst      = alloc(expr_substitution, m(), m_produce_unsat_cores, m_produce_proofs);
1038                 m_norm_subst = alloc(expr_substitution, m(), m_produce_unsat_cores, m_produce_proofs);
1039                 unsigned rounds = 0;
1040 
1041                 filter_unsafe_vars();
1042                 while (rounds < 20) {
1043                     ++rounds;
1044                     if (!m_produce_proofs && m_context_solve && rounds < 3) {
1045                         distribute_and_or(*(g.get()));
1046                     }
1047                     collect_num_occs(*g);
1048                     collect(*g);
1049                     if (!m_produce_proofs && m_context_solve && rounds < 3) {
1050                         collect_hoist(*g);
1051                     }
1052                     if (m_subst->empty()) {
1053                         break;
1054                     }
1055                     sort_vars();
1056                     if (m_ordered_vars.empty()) {
1057                         break;
1058                     }
1059                     normalize();
1060                     substitute(*(g.get()));
1061                     if (g->inconsistent()) {
1062                         break;
1063                     }
1064                     save_elim_vars(mc);
1065                     TRACE("solve_eqs_round", g->display(tout); if (mc) mc->display(tout););
1066                     if (rounds > 10 && m_ordered_vars.size() == 1)
1067                         break;
1068                 }
1069             }
1070             g->inc_depth();
1071             g->add(mc.get());
1072             result.push_back(g.get());
1073         }
1074     };
1075 
1076     imp *      m_imp;
1077     params_ref m_params;
1078 public:
solve_eqs_tactic(ast_manager & m,params_ref const & p,expr_replacer * r,bool owner)1079     solve_eqs_tactic(ast_manager & m, params_ref const & p, expr_replacer * r, bool owner):
1080         m_params(p) {
1081         m_imp = alloc(imp, m, p, r, owner);
1082     }
1083 
translate(ast_manager & m)1084     tactic * translate(ast_manager & m) override {
1085         return alloc(solve_eqs_tactic, m, m_params, mk_expr_simp_replacer(m, m_params), true);
1086     }
1087 
~solve_eqs_tactic()1088     ~solve_eqs_tactic() override {
1089         dealloc(m_imp);
1090     }
1091 
updt_params(params_ref const & p)1092     void updt_params(params_ref const & p) override {
1093         m_params = p;
1094         m_imp->updt_params(p);
1095     }
1096 
collect_param_descrs(param_descrs & r)1097     void collect_param_descrs(param_descrs & r) override {
1098         r.insert("solve_eqs_max_occs", CPK_UINT, "(default: infty) maximum number of occurrences for considering a variable for gaussian eliminations.");
1099         r.insert("theory_solver", CPK_BOOL, "(default: true) use theory solvers.");
1100         r.insert("ite_solver", CPK_BOOL, "(default: true) use if-then-else solver.");
1101         r.insert("context_solve", CPK_BOOL, "(default: false) solve equalities under disjunctions.");
1102     }
1103 
operator ()(goal_ref const & in,goal_ref_buffer & result)1104     void operator()(goal_ref const & in,
1105                     goal_ref_buffer & result) override {
1106         (*m_imp)(in, result);
1107         report_tactic_progress(":num-elim-vars", m_imp->get_num_eliminated_vars());
1108     }
1109 
cleanup()1110     void cleanup() override {
1111         unsigned num_elim_vars = m_imp->m_num_eliminated_vars;
1112         ast_manager & m = m_imp->m();
1113         expr_replacer * r = m_imp->m_r;
1114         if (r)
1115             r->set_substitution(nullptr);
1116         bool owner = m_imp->m_r_owner;
1117         m_imp->m_r_owner  = false; // stole replacer
1118 
1119         imp * d = alloc(imp, m, m_params, r, owner);
1120         d->m_num_eliminated_vars = num_elim_vars;
1121         std::swap(d, m_imp);
1122         dealloc(d);
1123     }
1124 
collect_statistics(statistics & st) const1125     void collect_statistics(statistics & st) const override {
1126         st.update("eliminated vars", m_imp->get_num_eliminated_vars());
1127     }
1128 
reset_statistics()1129     void reset_statistics() override {
1130         m_imp->m_num_eliminated_vars = 0;
1131     }
1132 
1133 };
1134 
mk_solve_eqs_tactic(ast_manager & m,params_ref const & p,expr_replacer * r)1135 tactic * mk_solve_eqs_tactic(ast_manager & m, params_ref const & p, expr_replacer * r) {
1136     if (r == nullptr)
1137         return clean(alloc(solve_eqs_tactic, m, p, mk_expr_simp_replacer(m, p), true));
1138     else
1139         return clean(alloc(solve_eqs_tactic, m, p, r, false));
1140 }
1141