1 /*++
2 Copyright (c) 2018 Microsoft Corporation, Simon Cruanes
3 
4 Module Name:
5 
6     theory_recfun.cpp
7 
8 Abstract:
9 
10     Theory responsible for unrolling recursive functions
11 
12 Author:
13 
14     Simon Cruanes December 2017
15 
16 Revision History:
17 
18 --*/
19 
20 #include "util/stats.h"
21 #include "ast/ast_util.h"
22 #include "ast/for_each_expr.h"
23 #include "smt/theory_recfun.h"
24 
25 
// Trace helper: streams `x` followed by a newline under the "recfun" trace tag.
#define TRACEFN(x) TRACE("recfun", tout << x << '\n';)
27 
28 namespace smt {
29 
theory_recfun(context & ctx)30     theory_recfun::theory_recfun(context& ctx)
31         : theory(ctx, ctx.get_manager().mk_family_id("recfun")),
32           m_plugin(*reinterpret_cast<recfun::decl::plugin*>(m.get_plugin(get_family_id()))),
33           m_util(m_plugin.u()),
34           m_disabled_guards(m),
35           m_enabled_guards(m),
36           m_preds(m),
37           m_num_rounds(0),
38           m_q_case_expand(),
39           m_q_body_expand() {
40         m_num_rounds = 0;
41         }
42 
~theory_recfun()43     theory_recfun::~theory_recfun() {
44         reset_eh();
45     }
46 
get_name() const47     char const * theory_recfun::get_name() const { return "recfun"; }
48 
mk_fresh(context * new_ctx)49     theory* theory_recfun::mk_fresh(context* new_ctx) {
50         return alloc(theory_recfun, *new_ctx);
51     }
52 
init_search_eh()53     void theory_recfun::init_search_eh() {
54     }
55 
internalize_atom(app * atom,bool gate_ctx)56     bool theory_recfun::internalize_atom(app * atom, bool gate_ctx) {
57         force_push();
58         TRACEFN(mk_pp(atom, m));
59         if (!u().has_defs()) {
60             return false;
61         }
62         for (expr * arg : *atom) {
63             ctx.internalize(arg, false);
64         }
65         if (!ctx.e_internalized(atom)) {
66             ctx.mk_enode(atom, false, true, false);
67         }
68         if (!ctx.b_internalized(atom)) {
69             bool_var v = ctx.mk_bool_var(atom);
70             ctx.set_var_theory(v, get_id());
71         }
72         if (!ctx.relevancy() && u().is_defined(atom)) {
73             push_case_expand(alloc(case_expansion, u(), atom));
74         }
75         return true;
76     }
77 
internalize_term(app * term)78     bool theory_recfun::internalize_term(app * term) {
79         force_push();
80         if (!u().has_defs()) {
81             return false;
82         }
83         for (expr* e : *term) {
84             ctx.internalize(e, false);
85         }
86         // the internalization of the arguments may have triggered the internalization of term.
87         if (!ctx.e_internalized(term)) {
88             ctx.mk_enode(term, false, false, true);
89             if (!ctx.relevancy() && u().is_defined(term)) {
90                 push_case_expand(alloc(case_expansion, u(), term));
91             }
92         }
93 
94         return true;
95     }
96 
reset_queues()97     void theory_recfun::reset_queues() {
98         for (auto* e : m_q_case_expand) {
99             dealloc(e);
100         }
101         m_q_case_expand.reset();
102         for (auto* e : m_q_body_expand) {
103             dealloc(e);
104         }
105         m_q_body_expand.reset();
106         m_q_clauses.clear();
107     }
108 
reset_eh()109     void theory_recfun::reset_eh() {
110         reset_queues();
111         m_stats.reset();
112         theory::reset_eh();
113         m_disabled_guards.reset();
114         m_enabled_guards.reset();
115         m_q_guards.reset();
116         for (auto & kv : m_guard2pending) {
117             dealloc(kv.m_value);
118         }
119         m_guard2pending.reset();
120     }
121 
122     /*
123      * when `n` becomes relevant, if it's `f(t1...tn)` with `f` defined,
124      * then case-expand `n`. If it's a macro we can also immediately
125      * body-expand it.
126      */
relevant_eh(app * n)127     void theory_recfun::relevant_eh(app * n) {
128         SASSERT(ctx.relevancy());
129         TRACEFN("relevant_eh: (defined) " <<  u().is_defined(n) << " " << mk_pp(n, m));
130         if (u().is_defined(n) && u().has_defs()) {
131             push_case_expand(alloc(case_expansion, u(), n));
132         }
133     }
134 
push_scope_eh()135     void theory_recfun::push_scope_eh() {
136         if (lazy_push())
137             return;
138         theory::push_scope_eh();
139         m_preds_lim.push_back(m_preds.size());
140     }
141 
    /**
     * Pop `num_scopes` scopes (skipped when lazy_pop() reports the pop as
     * handled lazily).
     *
     * All queued, not-yet-propagated case/body expansions and clauses are
     * discarded via reset_queues() (m_q_guards is not affected), and the
     * per-scope predicate limits are truncated.
     */
    void theory_recfun::pop_scope_eh(unsigned num_scopes) {
        if (lazy_pop(num_scopes))
            return;
        theory::pop_scope_eh(num_scopes);
        reset_queues();

        // restore depth book-keeping
        unsigned new_lim = m_preds_lim.size()-num_scopes;
        // NOTE: the block below is compiled out, so m_preds and m_pred_depth
        // are NOT rolled back on pop; depths recorded inside popped scopes
        // persist. Only m_preds_lim is shrunk.
#if 0
        // depth tracking of recursive unfolding is
        // turned off when enabling this code:
        unsigned start = m_preds_lim[new_lim];
        for (unsigned i = start; i < m_preds.size(); ++i) {
            m_pred_depth.remove(m_preds.get(i));
        }
        m_preds.resize(start);
#endif
        m_preds_lim.shrink(new_lim);
    }
161 
restart_eh()162     void theory_recfun::restart_eh() {
163         TRACEFN("restart");
164         reset_queues();
165         theory::restart_eh();
166     }
167 
can_propagate()168     bool theory_recfun::can_propagate() {
169         return
170             !m_q_case_expand.empty() ||
171             !m_q_body_expand.empty() ||
172             !m_q_clauses.empty() ||
173             !m_q_guards.empty();
174     }
175 
    /**
     * Drain pending work, in this order:
     *   1. guards re-enabled by should_research (m_q_guards): each is
     *      activated with the guard vector stored for it in m_guard2pending;
     *   2. queued clauses (m_q_clauses) are added as theory axioms;
     *   3. case expansions: macros are body-expanded immediately, defined
     *      functions get their case axioms; each entry is deallocated;
     *   4. body expansions: each is asserted and deallocated.
     * Clauses queued by disable_guard during step 3 stay in m_q_clauses for
     * the next call (can_propagate() then remains true).
     */
    void theory_recfun::propagate() {

        for (expr* g : m_q_guards) {
            // entries are negated guards `not g'`; activate the inner g'
            expr* ng = nullptr;
            VERIFY(m.is_not(g, ng));
            activate_guard(ng, *m_guard2pending[g]);
        }
        m_q_guards.reset();

        for (literal_vector & c : m_q_clauses) {
            TRACEFN("add axiom " << pp_lits(ctx, c));
            ctx.mk_th_axiom(get_id(), c);
        }
        m_q_clauses.clear();

        // Index-based on purpose: size() is re-read every iteration, so case
        // expansions appended while processing (mk_literal internalizes new
        // terms, which may queue more expansions when relevancy is off) are
        // handled in this same pass.
        for (unsigned i = 0; i < m_q_case_expand.size(); ++i) {
            case_expansion* e = m_q_case_expand[i];
            if (e->m_def->is_fun_macro()) {
                // body expand immediately
                assert_macro_axiom(*e);
            }
            else {
                // case expand
                SASSERT(e->m_def->is_fun_defined());
                assert_case_axioms(*e);
            }
            dealloc(e);
            m_q_case_expand[i] = nullptr;
        }
        // NOTE(review): this count also includes macro entries, which are
        // counted separately in m_macro_expansions.
        m_stats.m_case_expansions += m_q_case_expand.size();
        m_q_case_expand.reset();

        for (unsigned i = 0; i < m_q_body_expand.size(); ++i) {
            assert_body_axiom(*m_q_body_expand[i]);
            dealloc(m_q_body_expand[i]);
            m_q_body_expand[i] = nullptr;
        }
        m_stats.m_body_expansions += m_q_body_expand.size();
        m_q_body_expand.reset();
    }
216 
217     /**
218      * make clause `depth_limit => ~guard`
219      * the guard appears at a depth below the current cutoff.
220      */
disable_guard(expr * guard,expr_ref_vector const & guards)221     void theory_recfun::disable_guard(expr* guard, expr_ref_vector const& guards) {
222         expr_ref nguard(m.mk_not(guard), m);
223         if (is_disabled_guard(nguard))
224             return;
225         SASSERT(!is_enabled_guard(nguard));
226         literal_vector c;
227         app_ref dlimit = m_util.mk_num_rounds_pred(m_num_rounds);
228         c.push_back(~mk_literal(dlimit));
229         c.push_back(~mk_literal(guard));
230         m_disabled_guards.push_back(nguard);
231         SASSERT(!m_guard2pending.contains(nguard));
232         m_guard2pending.insert(nguard, alloc(expr_ref_vector, guards));
233         TRACEFN("add clause\n" << pp_lits(ctx, c));
234         m_q_clauses.push_back(std::move(c));
235     }
236 
237     /**
238      * retrieve depth associated with predicate or expression.
239      */
get_depth(expr * e)240     unsigned theory_recfun::get_depth(expr* e) {
241         SASSERT(u().is_defined(e) || u().is_case_pred(e));
242         unsigned d = 0;
243         m_pred_depth.find(e, d);
244         return d;
245     }
246 
247     /**
248      * Update depth of subterms of e with respect to d.
249      */
set_depth_rec(unsigned d,expr * e)250     void theory_recfun::set_depth_rec(unsigned d, expr* e) {
251         struct insert_c {
252             theory_recfun& th;
253             unsigned m_depth;
254             insert_c(theory_recfun& th, unsigned d): th(th), m_depth(d) {}
255             void operator()(app* e) { th.set_depth(m_depth, e); }
256             void operator()(quantifier*) {}
257             void operator()(var*) {}
258         };
259         insert_c proc(*this, d);
260         for_each_expr(proc, e);
261     }
262 
set_depth(unsigned depth,expr * e)263     void theory_recfun::set_depth(unsigned depth, expr* e) {
264         if ((u().is_defined(e) || u().is_case_pred(e)) && !m_pred_depth.contains(e)) {
265             m_pred_depth.insert(e, depth);
266             m_preds.push_back(e);
267         }
268     }
269 
270     /**
271      * if `is_true` and `v = C_f_i(t1...tn)`,
272      *    then body-expand i-th case of `f(t1...tn)`
273      */
assign_eh(bool_var v,bool is_true)274     void theory_recfun::assign_eh(bool_var v, bool is_true) {
275         expr* e = ctx.bool_var2expr(v);
276         if (is_true && u().is_case_pred(e)) {
277             TRACEFN("assign_case_pred_true " << mk_pp(e, m));
278             // body-expand
279             push_body_expand(alloc(body_expansion, u(), to_app(e)));
280         }
281     }
282 
283      // replace `vars` by `args` in `e`
apply_args(unsigned depth,recfun::vars const & vars,ptr_vector<expr> const & args,expr * e)284     expr_ref theory_recfun::apply_args(
285         unsigned depth,
286         recfun::vars const & vars,
287         ptr_vector<expr> const & args,
288         expr * e) {
289         SASSERT(is_standard_order(vars));
290         var_subst subst(m, true);
291         expr_ref new_body(m);
292         new_body = subst(e, args.size(), args.c_ptr());
293         ctx.get_rewriter()(new_body); // simplify
294         set_depth_rec(depth + 1, new_body);
295         return new_body;
296     }
297 
mk_literal(expr * e)298     literal theory_recfun::mk_literal(expr* e) {
299         ctx.internalize(e, false);
300         literal lit = ctx.get_literal(e);
301         ctx.mark_as_relevant(lit);
302         return lit;
303     }
304 
mk_eq_lit(expr * l,expr * r)305     literal theory_recfun::mk_eq_lit(expr* l, expr* r) {
306         literal lit;
307         if (m.is_true(r) || m.is_false(r)) {
308             std::swap(l, r);
309         }
310         if (m.is_true(l)) {
311             lit = mk_literal(r);
312         }
313         else if (m.is_false(l)) {
314             lit = ~mk_literal(r);
315         }
316         else {
317             lit = mk_eq(l, r, false);
318         }
319         ctx.mark_as_relevant(lit);
320         return lit;
321     }
322 
323     /**
324      * For functions f(args) that are given as macros f(vs) = rhs
325      *
326      * 1. substitute `e.args` for `vs` into the macro rhs
327      * 2. add unit clause `f(args) = rhs`
328      */
assert_macro_axiom(case_expansion & e)329     void theory_recfun::assert_macro_axiom(case_expansion & e) {
330         m_stats.m_macro_expansions++;
331         TRACEFN("case expansion " << pp_case_expansion(e, m));
332         SASSERT(e.m_def->is_fun_macro());
333         auto & vars = e.m_def->get_vars();
334         expr_ref lhs(e.m_lhs, m);
335         unsigned depth = get_depth(e.m_lhs);
336         expr_ref rhs(apply_args(depth, vars, e.m_args, e.m_def->get_rhs()), m);
337         literal lit = mk_eq_lit(lhs, rhs);
338         std::function<literal(void)> fn = [&]() { return lit; };
339         scoped_trace_stream _tr(*this, fn);
340         ctx.mk_th_axiom(get_id(), 1, &lit);
341         TRACEFN("macro expansion yields " << pp_lit(ctx, lit));
342     }
343 
    /**
     * Add case axioms for every case expansion path of `e` (a defined,
     * non-macro function application).
     *
     * For each case i, with case predicate p_i applied to `e.m_args`:
     *   - immediate cases are body-expanded right away and then activated;
     *   - cases whose guard is already enabled are activated, i.e.
     *     `p_i(args) <=> And(guards_i)` is asserted (CNF on the fly, see
     *     activate_guard);
     *   - all remaining cases are disabled via disable_guard, deferring their
     *     activation until should_research re-enables them.
     * Finally the disjunction `p_1(args) \/ ... \/ p_n(args)` is asserted.
     */
    void theory_recfun::assert_case_axioms(case_expansion & e) {
        TRACEFN("assert_case_axioms "<< pp_case_expansion(e,m)
                << " with " << e.m_def->get_cases().size() << " cases");
        SASSERT(e.m_def->is_fun_defined());
        // add case-axioms for all case-paths
        // assert this was not defined before.
        literal_vector preds;
        auto & vars = e.m_def->get_vars();

        // Only consumed by the (currently disabled) induction hook below.
        unsigned max_depth = 0;
        for (recfun::case_def const & c : e.m_def->get_cases()) {
            // applied predicate to `args`
            app_ref pred_applied = c.apply_case_predicate(e.m_args);
            SASSERT(u().owns_app(pred_applied));
            literal concl = mk_literal(pred_applied);
            preds.push_back(concl);

            // Depth of the expanded application; the case predicate inherits
            // it (set_depth only records a depth for terms that lack one).
            unsigned depth = get_depth(e.m_lhs);
            set_depth(depth, pred_applied);
            expr_ref_vector guards(m);
            for (auto & g : c.get_guards()) {
                guards.push_back(apply_args(depth, vars, e.m_args, g));
            }
            if (c.is_immediate()) {
                // no defined-function dependency: expand the body now, then
                // fall through to activate the guard
                body_expansion be(pred_applied, c, e.m_args);
                assert_body_axiom(be);
            }
            else if (!is_enabled_guard(pred_applied)) {
                // defer: block this case until should_research enables it
                disable_guard(pred_applied, guards);
                max_depth = std::max(depth, max_depth);
                continue;
            }
            activate_guard(pred_applied, guards);
        }
        // the disjunction of branches is asserted
        // to close the available cases.
        {
            scoped_trace_stream _tr2(*this, preds);
            ctx.mk_th_axiom(get_id(), preds);
        }
        (void)max_depth;
        // add_induction_lemmas(max_depth);
    }
394 
add_induction_lemmas(unsigned depth)395     void theory_recfun::add_induction_lemmas(unsigned depth) {
396         if (depth > 4 && ctx.get_fparams().m_induction && induction::should_try(ctx)) {
397             ctx.get_induction()();
398         }
399     }
400 
activate_guard(expr * pred_applied,expr_ref_vector const & guards)401     void theory_recfun::activate_guard(expr* pred_applied, expr_ref_vector const& guards) {
402         literal concl = mk_literal(pred_applied);
403         literal_vector lguards;
404         lguards.push_back(concl);
405         for (expr* ga : guards) {
406             literal guard = mk_literal(ga);
407             lguards.push_back(~guard);
408             literal c[2] = {~concl, guard};
409             std::function<literal_vector(void)> fn = [&]() { return literal_vector(2, c); };
410             scoped_trace_stream _tr(*this, fn);
411             ctx.mk_th_axiom(get_id(), 2, c);
412         }
413         std::function<literal_vector(void)> fn1 = [&]() { return lguards; };
414         scoped_trace_stream _tr1(*this, fn1);
415         ctx.mk_th_axiom(get_id(), lguards);
416     }
417 
418     /**
419      * For a guarded definition guards => f(vars) = rhs
420      * and occurrence f(args)
421      *
422      * substitute `args` for `vars` in guards, and rhs
423      * add axiom guards[args/vars] => f(args) = rhs[args/vars]
424      *
425      */
assert_body_axiom(body_expansion & e)426     void theory_recfun::assert_body_axiom(body_expansion & e) {
427         recfun::def & d = *e.m_cdef->get_def();
428         auto & vars = d.get_vars();
429         auto & args = e.m_args;
430         SASSERT(is_standard_order(vars));
431         unsigned depth = get_depth(e.m_pred);
432         expr_ref lhs(u().mk_fun_defined(d, args), m);
433         expr_ref rhs = apply_args(depth, vars, args, e.m_cdef->get_rhs());
434         literal_vector clause;
435         for (auto & g : e.m_cdef->get_guards()) {
436             expr_ref guard = apply_args(depth, vars, args, g);
437             clause.push_back(~mk_literal(guard));
438             if (clause.back() == true_literal) {
439                 TRACEFN("body " << pp_body_expansion(e,m) << "\n" << clause << "\n" << guard);
440                 return;
441             }
442             if (clause.back() == false_literal) {
443                 clause.pop_back();
444             }
445         }
446         clause.push_back(mk_eq_lit(lhs, rhs));
447         std::function<literal_vector(void)> fn = [&]() { return clause; };
448         scoped_trace_stream _tr(*this, fn);
449         ctx.mk_th_axiom(get_id(), clause);
450         TRACEFN("body " << pp_body_expansion(e,m));
451         TRACEFN(pp_lits(ctx, clause));
452     }
453 
final_check_eh()454     final_check_status theory_recfun::final_check_eh() {
455         TRACEFN("final\n");
456         if (can_propagate()) {
457             propagate();
458             return FC_CONTINUE;
459         }
460         return FC_DONE;
461     }
462 
add_theory_assumptions(expr_ref_vector & assumptions)463     void theory_recfun::add_theory_assumptions(expr_ref_vector & assumptions) {
464         if (u().has_defs() || !m_disabled_guards.empty()) {
465             app_ref dlimit = m_util.mk_num_rounds_pred(m_num_rounds);
466             TRACEFN("add_theory_assumption " << dlimit);
467             assumptions.push_back(dlimit);
468             assumptions.append(m_disabled_guards);
469         }
470     }
471 
    /**
     * Decide whether to search again after an unsat answer under assumptions.
     *
     * Returns true iff `unsat_core` contains the num-rounds predicate or a
     * disabled guard. In that case m_num_rounds is incremented and, if any
     * disabled guards occurred in the core, one of those with the smallest
     * recorded depth is picked at random (reservoir-style: the k-th tie at the
     * current minimum depth replaces the pick when get_random_value() % k == 0).
     * The pick is moved from m_disabled_guards to m_enabled_guards and queued
     * in m_q_guards so the next propagate() activates it.
     */
    bool theory_recfun::should_research(expr_ref_vector & unsat_core) {
        bool found = false;
        expr* to_delete = nullptr;
        unsigned n = 0;                       // number of ties seen at current_depth
        unsigned current_depth = UINT_MAX;
        for (auto & e : unsat_core) {
            if (is_disabled_guard(e)) {
                found = true;
                expr* ne = nullptr;
                VERIFY(m.is_not(e, ne));
                unsigned depth = get_depth(ne);
                if (depth < current_depth)
                    n = 0;                    // strictly shallower: restart the tie count
                // short-circuit keeps RNG draws limited to candidates at or
                // below the current minimum depth
                if (depth <= current_depth && (ctx.get_random_value() % (++n)) == 0) {
                    to_delete = e;
                    current_depth = depth;
                }
            }
            else if (u().is_num_rounds(e)) {
                found = true;
            }
        }
        if (found) {
            m_num_rounds++;
            if (to_delete) {
                // to_delete is still referenced by unsat_core (an
                // expr_ref_vector), so erasing it here does not free it.
                m_disabled_guards.erase(to_delete);
                m_enabled_guards.push_back(to_delete);
                m_q_guards.push_back(to_delete);
                IF_VERBOSE(1, verbose_stream() << "(smt.recfun :enable-guard " << mk_pp(to_delete, m) << ")\n");
            }
            else {
                IF_VERBOSE(1, verbose_stream() << "(smt.recfun :increment-round)\n");
            }
        }
        return found;
    }
509 
display(std::ostream & out) const510     void theory_recfun::display(std::ostream & out) const {
511         out << "recfun\n";
512         out << "disabled guards:\n" << m_disabled_guards << "\n";
513     }
514 
collect_statistics(::statistics & st) const515     void theory_recfun::collect_statistics(::statistics & st) const {
516         st.update("recfun macro expansion", m_stats.m_macro_expansions);
517         st.update("recfun case expansion", m_stats.m_case_expansions);
518         st.update("recfun body expansion", m_stats.m_body_expansions);
519     }
520 
operator <<(std::ostream & out,theory_recfun::pp_case_expansion const & e)521     std::ostream& operator<<(std::ostream & out, theory_recfun::pp_case_expansion const & e) {
522         return out << "case_exp(" << mk_pp(e.e.m_lhs, e.m) << ")";
523     }
524 
operator <<(std::ostream & out,theory_recfun::pp_body_expansion const & e)525     std::ostream& operator<<(std::ostream & out, theory_recfun::pp_body_expansion const & e) {
526         out << "body_exp(" << e.e.m_cdef->get_decl()->get_name();
527         for (auto* t : e.e.m_args) {
528             out << " " << mk_pp(t,e.m);
529         }
530         return out << ")";
531     }
532 }
533