1 /*++
2   Copyright (c) 2020 Microsoft Corporation
3 
4   Module Name:
5 
6    smt_induction.cpp
7 
8   Abstract:
9 
10    Add induction lemmas to context.
11 
12   Author:
13 
14     Nikolaj Bjorner 2020-04-25
15 
16   Notes:
17 
18   - work in absence of recursive functions but instead presence of quantifiers
19     - relax current requirement of model sweeping when terms don't have value simplifications
20   - k-induction
21     - also to deal with mutually recursive datatypes
22   - beyond literal induction lemmas
23   - refine initialization of values when term is equal to constructor application,
24 
25 --*/
26 
27 #include "ast/ast_pp.h"
28 #include "ast/ast_util.h"
29 #include "ast/for_each_expr.h"
30 #include "ast/recfun_decl_plugin.h"
31 #include "ast/datatype_decl_plugin.h"
32 #include "ast/arith_decl_plugin.h"
33 #include "ast/rewriter/value_sweep.h"
34 #include "ast/rewriter/expr_safe_replace.h"
35 #include "smt/smt_context.h"
36 #include "smt/smt_induction.h"
37 
38 using namespace smt;
39 
40 /**
41  * collect literals that are assigned to true,
42  * but evaluate to false under all extensions of
43  * the congruence closure.
44  */
45 
pre_select()46 literal_vector collect_induction_literals::pre_select() {
47     literal_vector result;
48     for (unsigned i = m_literal_index; i < ctx.assigned_literals().size(); ++i) {
49         literal lit = ctx.assigned_literals()[i];
50         smt::bool_var v = lit.var();
51         if (!ctx.has_enode(v)) {
52             continue;
53         }
54         expr* e = ctx.bool_var2expr(v);
55         if (!lit.sign() && m.is_eq(e))
56             continue;
57         result.push_back(lit);
58     }
59     TRACE("induction", ctx.display(tout << "literal index: " << m_literal_index << "\n" << result << "\n"););
60 
61     ctx.push_trail(value_trail<context, unsigned>(m_literal_index));
62     m_literal_index = ctx.assigned_literals().size();
63     return result;
64 }
65 
model_sweep_filter(literal_vector & candidates)66 void collect_induction_literals::model_sweep_filter(literal_vector& candidates) {
67     expr_ref_vector terms(m);
68     for (literal lit : candidates) {
69         terms.push_back(ctx.bool_var2expr(lit.var()));
70     }
71     vector<expr_ref_vector> values;
72     vs(terms, values);
73     unsigned j = 0;
74     for (unsigned i = 0; i < terms.size(); ++i) {
75         literal lit = candidates[i];
76         bool is_viable_candidate = true;
77         for (auto const& vec : values) {
78             if (vec[i] && lit.sign() && m.is_true(vec[i]))
79                 continue;
80             if (vec[i] && !lit.sign() && m.is_false(vec[i]))
81                 continue;
82             is_viable_candidate = false;
83             break;
84         }
85         if (is_viable_candidate)
86             candidates[j++] = lit;
87     }
88     candidates.shrink(j);
89 }
90 
91 
collect_induction_literals(context & ctx,ast_manager & m,value_sweep & vs)92 collect_induction_literals::collect_induction_literals(context& ctx, ast_manager& m, value_sweep& vs):
93     ctx(ctx),
94     m(m),
95     vs(vs),
96     m_literal_index(0)
97 {
98 }
99 
operator ()()100 literal_vector collect_induction_literals::operator()() {
101     literal_vector candidates = pre_select();
102     model_sweep_filter(candidates);
103     return candidates;
104 }
105 
106 
107 // --------------------------------------
108 // induction_lemmas
109 
viable_induction_sort(sort * s)110 bool induction_lemmas::viable_induction_sort(sort* s) {
111     // potentially also induction on integers, sequences
112     return m_dt.is_datatype(s) && m_dt.is_recursive(s);
113 }
114 
viable_induction_parent(enode * p,enode * n)115 bool induction_lemmas::viable_induction_parent(enode* p, enode* n) {
116     app* o = p->get_owner();
117     return
118         m_rec.is_defined(o) ||
119         m_dt.is_constructor(o);
120 }
121 
viable_induction_children(enode * n)122 bool induction_lemmas::viable_induction_children(enode* n) {
123     app* e = n->get_owner();
124     if (m.is_value(e))
125         return false;
126     if (e->get_decl()->is_skolem())
127         return false;
128     if (n->get_num_args() == 0)
129         return true;
130     if (e->get_family_id() == m_rec.get_family_id())
131         return m_rec.is_defined(e);
132     if (e->get_family_id() == m_dt.get_family_id())
133         return m_dt.is_constructor(e);
134     return false;
135 }
136 
viable_induction_term(enode * p,enode * n)137 bool induction_lemmas::viable_induction_term(enode* p, enode* n) {
138     return
139         viable_induction_sort(m.get_sort(n->get_owner())) &&
140         viable_induction_parent(p, n) &&
141         viable_induction_children(n);
142 }
143 
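/**
 * positions in n that can be used for induction.
 * n's sub-enodes are traversed once each, and every argument that passes
 * viable_induction_term with respect to the parent it was reached from is
 * returned at most once. Value terms and skolem constants are rejected by
 * viable_induction_children; equivalence-class roots are not consulted.
 */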
induction_positions(enode * n)150 enode_vector induction_lemmas::induction_positions(enode* n) {
151     enode_vector result;
152     enode_vector todo;
153     auto add_todo = [&](enode* n) {
154         if (!n->is_marked()) {
155             n->set_mark();
156             todo.push_back(n);
157         }
158     };
159     add_todo(n);
160     for (unsigned i = 0; i < todo.size(); ++i) {
161         n = todo[i];
162         for (enode* a : smt::enode::args(n)) {
163             add_todo(a);
164             if (!a->is_marked2() && viable_induction_term(n, a)) {
165                 result.push_back(a);
166                 a->set_mark2();
167             }
168         }
169     }
170     for (enode* n : todo)
171         n->unset_mark();
172     for (enode* n : result)
173         n->unset_mark2();
174     return result;
175 }
176 
177 
178 // Collecting induction positions relative to parent.
induction_positions2(enode * n)179 induction_lemmas::induction_positions_t induction_lemmas::induction_positions2(enode* n) {
180     induction_positions_t result;
181     enode_vector todo;
182     todo.push_back(n);
183     n->set_mark();
184     for (unsigned i = 0; i < todo.size(); ++i) {
185         enode* n = todo[i];
186         unsigned idx = 0;
187         for (enode* a : smt::enode::args(n)) {
188             if (viable_induction_term(n, a)) {
189                 result.push_back(induction_position_t(n, idx));
190             }
191             if (!a->is_marked()) {
192                 a->set_mark();
193                 todo.push_back(a);
194             }
195             ++idx;
196         }
197     }
198     for (enode* n : todo)
199         n->unset_mark();
200     return result;
201 }
202 
initialize_levels(enode * n)203 void induction_lemmas::initialize_levels(enode* n) {
204     expr_ref tmp(n->get_owner(), m);
205     m_depth2terms.reset();
206     m_depth2terms.resize(get_depth(tmp) + 1);
207     m_ts++;
208     for (expr* t : subterms(tmp)) {
209         if (is_app(t)) {
210             m_depth2terms[get_depth(t)].push_back(to_app(t));
211             m_marks.reserve(t->get_id()+1, 0);
212         }
213     }
214 }
215 
induction_combinations(enode * n)216 induction_lemmas::induction_combinations_t induction_lemmas::induction_combinations(enode* n) {
217     initialize_levels(n);
218     induction_combinations_t result;
219     auto pos = induction_positions2(n);
220 
221     if (pos.size() > 6) {
222         induction_positions_t r;
223         for (auto const& p : pos) {
224             if (is_uninterp_const(p.first->get_owner()))
225                 r.push_back(p);
226         }
227         result.push_back(r);
228         return result;
229     }
230     for (unsigned i = 0; i < (1ull << pos.size()); ++i) {
231         induction_positions_t r;
232         for (unsigned j = 0; j < pos.size(); ++j) {
233             if (0 != (i & (1 << j)))
234                 r.push_back(pos[j]);
235         }
236         if (positions_dont_overlap(r))
237             result.push_back(r);
238     }
239     for (auto const& pos : result) {
240         std::cout << "position\n";
241         for (auto const& p : pos) {
242             std::cout << mk_pp(p.first->get_owner(), m) << ":" << p.second << "\n";
243         }
244     }
245     return result;
246 }
247 
positions_dont_overlap(induction_positions_t const & positions)248 bool induction_lemmas::positions_dont_overlap(induction_positions_t const& positions) {
249     if (positions.empty())
250         return false;
251     m_ts++;
252     auto mark = [&](expr* n) { m_marks[n->get_id()] = m_ts; };
253     auto is_marked = [&](expr* n) { return m_marks[n->get_id()] == m_ts; };
254     for (auto p : positions)
255         mark(p.first->get_owner());
256     // no term used for induction contains a subterm also used for induction.
257     for (auto const& terms : m_depth2terms) {
258         for (app* t : terms) {
259             bool has_mark = false;
260             for (expr* arg : *t)
261                 has_mark |= is_marked(arg);
262             if (is_marked(t) && has_mark)
263                 return false;
264             if (has_mark)
265                 mark(t);
266         }
267     }
268     return true;
269 }
270 
271 /**
272    extract substitutions for x into accessor values of the same sort.
273    collect side-conditions for the accessors to be well defined.
274    apply a depth-bounded unfolding of datatype constructors to collect
275    accessor values beyond a first level and for nested (mutually recursive)
276    datatypes.
277  */
mk_hypothesis_substs(unsigned depth,expr * x,cond_substs_t & subst)278 void induction_lemmas::mk_hypothesis_substs(unsigned depth, expr* x, cond_substs_t& subst) {
279     expr_ref_vector conds(m);
280     mk_hypothesis_substs_rec(depth, m.get_sort(x), x, conds, subst);
281 }
282 
mk_hypothesis_substs_rec(unsigned depth,sort * s,expr * y,expr_ref_vector & conds,cond_substs_t & subst)283 void induction_lemmas::mk_hypothesis_substs_rec(unsigned depth, sort* s, expr* y, expr_ref_vector& conds, cond_substs_t& subst) {
284     sort* ys = m.get_sort(y);
285     for (func_decl* c : *m_dt.get_datatype_constructors(ys)) {
286         func_decl* is_c = m_dt.get_constructor_recognizer(c);
287         conds.push_back(m.mk_app(is_c, y));
288         for (func_decl* acc : *m_dt.get_constructor_accessors(c)) {
289             sort* rs = acc->get_range();
290             if (!m_dt.is_datatype(rs) || !m_dt.is_recursive(rs))
291                 continue;
292             expr_ref acc_y(m.mk_app(acc, y), m);
293             if (rs == s) {
294                 subst.push_back(std::make_pair(conds, acc_y));
295             }
296             if (depth > 1) {
297                 mk_hypothesis_substs_rec(depth - 1, s, acc_y, conds, subst);
298             }
299         }
300         conds.pop_back();
301     }
302 }
303 
304 /*
305  * Create simple induction lemmas of the form:
306  *
307  * lit & a.eqs() => alpha
308  * alpha & is-c(sk) => ~beta
309  *
310  * where
311  *       lit   = is a formula containing t
312  *       alpha = a.term(), a variant of lit
313  *               with some occurrences of t replaced by sk
314  *       beta  = alpha[sk/access_k(sk)]
315  * for each constructor c, that is recursive
316  * and contains argument of datatype sort s
317  *
318  * The main claim is that the lemmas are valid and that
319  * they approximate induction reasoning.
320  *
321  * alpha approximates minimal instance of the datatype s where
322  * the instance of s is true. In the limit one can
323  * set beta to all instantiations of smaller values than sk.
324  *
325  */
326 
mk_hypothesis_lemma(expr_ref_vector const & conds,expr_pair_vector const & subst,literal alpha)327 void induction_lemmas::mk_hypothesis_lemma(expr_ref_vector const& conds, expr_pair_vector const& subst, literal alpha) {
328     expr_ref beta(m);
329     ctx.literal2expr(alpha, beta);
330     expr_safe_replace rep(m);
331     for (auto const& p : subst) {
332         rep.insert(p.first, p.second);
333     }
334     rep(beta);                          // set beta := alpha[sk/acc(acc2(sk))]
335     // alpha & is-c(sk) => ~alpha[sk/acc(sk)]
336     literal_vector lits;
337     lits.push_back(~alpha);
338     for (expr* c : conds) lits.push_back(~mk_literal(c));
339     lits.push_back(~mk_literal(beta));
340     add_th_lemma(lits);
341 }
342 
create_hypotheses(unsigned depth,expr_ref_vector const & sks,literal alpha)343 void induction_lemmas::create_hypotheses(unsigned depth, expr_ref_vector const& sks, literal alpha) {
344     if (sks.empty())
345         return;
346 
347     // extract hypothesis substitutions
348     vector<std::pair<expr*, cond_substs_t>> substs;
349     for (expr* sk : sks) {
350         cond_substs_t subst;
351         mk_hypothesis_substs(depth, sk, subst);
352 
353         // append the identity substitution:
354         expr_ref_vector conds(m);
355         subst.push_back(std::make_pair(conds, expr_ref(sk, m)));
356         substs.push_back(std::make_pair(sk, subst));
357     }
358 
359     // create cross-product of instantiations:
360     vector<std::pair<expr_ref_vector, expr_pair_vector>> s1, s2;
361     s1.push_back(std::make_pair(expr_ref_vector(m), expr_pair_vector()));
362     for (auto const& x2cond_sub : substs) {
363         s2.reset();
364         for (auto const& cond_sub : x2cond_sub.second) {
365             for (auto const& cond_subs : s1) {
366                 expr_pair_vector pairs(cond_subs.second);
367                 expr_ref_vector conds(cond_subs.first);
368                 pairs.push_back(std::make_pair(x2cond_sub.first, cond_sub.second));
369                 conds.append(cond_sub.first);
370                 s2.push_back(std::make_pair(conds, pairs));
371             }
372         }
373         s1.swap(s2);
374     }
375     s1.pop_back(); // last substitution is the identity
376 
377     // extract lemmas from instantiations
378     for (auto& p : s1) {
379         mk_hypothesis_lemma(p.first, p.second, alpha);
380     }
381 }
382 
383 
add_th_lemma(literal_vector const & lits)384 void induction_lemmas::add_th_lemma(literal_vector const& lits) {
385     IF_VERBOSE(0, ctx.display_literals_verbose(verbose_stream() << "lemma:\n", lits) << "\n");
386     ctx.mk_clause(lits.size(), lits.c_ptr(), nullptr, smt::CLS_TH_AXIOM);
387     // CLS_TH_LEMMA, but then should re-instance if GC'ed
388     ++m_num_lemmas;
389 }
390 
mk_literal(expr * e)391 literal induction_lemmas::mk_literal(expr* e) {
392     expr_ref _e(e, m);
393     if (!ctx.e_internalized(e)) {
394         ctx.internalize(e, false);
395     }
396     enode* n = ctx.get_enode(e);
397     ctx.mark_as_relevant(n);
398     return ctx.get_literal(e);
399 }
400 
401 
402 
operator ()(literal lit)403 bool induction_lemmas::operator()(literal lit) {
404     enode* r = ctx.bool_var2enode(lit.var());
405 
406 #if 1
407     auto combinations = induction_combinations(r);
408     for (auto const& positions : combinations) {
409         apply_induction(lit, positions);
410     }
411     return !combinations.empty();
412 #else
413     unsigned num = m_num_lemmas;
414     expr_ref_vector sks(m);
415     expr_safe_replace rep(m);
416     // have to be non-overlapping:
417     for (enode* n : induction_positions(r)) {
418         expr* t = n->get_owner();
419         if (is_uninterp_const(t)) { // for now, to avoid overlapping terms
420             sort* s = m.get_sort(t);
421             expr_ref sk(m.mk_fresh_const("sk", s), m);
422             sks.push_back(sk);
423             rep.insert(t, sk);
424         }
425     }
426     expr_ref alpha(m);
427     ctx.literal2expr(lit, alpha);
428     rep(alpha);
429     literal alpha_lit = mk_literal(alpha);
430 
431     // alpha is the minimal instance of induction_positions where lit holds
432     // alpha & is-c(sk) => ~alpha[sk/acc(sk)]
433     create_hypotheses(1, sks, alpha_lit);
434     if (m_num_lemmas == num)
435         return false;
436     // lit => alpha
437     literal_vector lits;
438     lits.push_back(~lit);
439     lits.push_back(alpha_lit);
440     add_th_lemma(lits);
441     return true;
442 #endif
443 }
444 
apply_induction(literal lit,induction_positions_t const & positions)445 void induction_lemmas::apply_induction(literal lit, induction_positions_t const & positions) {
446     unsigned num = m_num_lemmas;
447     obj_map<expr, expr*> term2skolem;
448     expr_ref alpha(m), sk(m);
449     expr_ref_vector sks(m);
450     ctx.literal2expr(lit, alpha);
451     induction_term_and_position_t itp(alpha, positions);
452     bool found = m_skolems.find(itp, itp);
453     if (found) {
454         sks.append(itp.m_skolems.size(), itp.m_skolems.c_ptr());
455     }
456 
457     unsigned i = 0;
458     for (auto const& p : positions) {
459         expr* t = p.first->get_owner()->get_arg(p.second);
460         if (term2skolem.contains(t))
461             continue;
462         if (i == sks.size()) {
463             sk = m.mk_fresh_const("sk", m.get_sort(t));
464             sks.push_back(sk);
465         }
466         else {
467             sk = sks.get(i);
468         }
469         term2skolem.insert(t, sk);
470         ++i;
471     }
472     if (!found) {
473         itp.m_skolems.append(sks.size(), sks.c_ptr());
474         m_trail.push_back(alpha);
475         m_trail.append(sks);
476         m_skolems.insert(itp);
477     }
478 
479     ptr_vector<expr> todo;
480     obj_map<expr, expr*> sub;
481     expr_ref_vector trail(m), args(m);
482     todo.push_back(alpha);
483     // replace occurrences of induction arguments.
484 #if 0
485     std::cout << "positions\n";
486     for (auto const& p : positions)
487         std::cout << mk_pp(p.first->get_owner(), m) << " " << p.second << "\n";
488 #endif
489     while (!todo.empty()) {
490         expr* t = todo.back();
491         if (sub.contains(t)) {
492             todo.pop_back();
493             continue;
494         }
495         SASSERT(is_app(t));
496         args.reset();
497         unsigned sz = todo.size();
498         expr* s = nullptr;
499         for (unsigned i = 0; i < to_app(t)->get_num_args(); ++i) {
500             expr* arg = to_app(t)->get_arg(i);
501             found = false;
502             for (auto const& p : positions) {
503                 if (p.first->get_owner() == t && p.second == i) {
504                     args.push_back(term2skolem[arg]);
505                     found = true;
506                     break;
507                 }
508             }
509             if (found)
510                 continue;
511             if (sub.find(arg, s)) {
512                 args.push_back(s);
513                 continue;
514             }
515             todo.push_back(arg);
516         }
517         if (todo.size() == sz) {
518             s = m.mk_app(to_app(t)->get_decl(), args);
519             trail.push_back(s);
520             sub.insert(t, s);
521             todo.pop_back();
522         }
523     }
524     alpha = sub[alpha];
525     std::cout << "alpha:" << alpha << "\n";
526     literal alpha_lit = mk_literal(alpha);
527 
528     // alpha is the minimal instance of induction_positions where lit holds
529     // alpha & is-c(sk) => ~alpha[sk/acc(sk)]
530     create_hypotheses(1, sks, alpha_lit);
531     if (m_num_lemmas > num) {
532         // lit => alpha
533         literal_vector lits;
534         lits.push_back(~lit);
535         lits.push_back(alpha_lit);
536         add_th_lemma(lits);
537     }
538 }
539 
induction_lemmas(context & ctx,ast_manager & m)540 induction_lemmas::induction_lemmas(context& ctx, ast_manager& m):
541     ctx(ctx),
542     m(m),
543     m_dt(m),
544     m_a(m),
545     m_rec(m),
546     m_num_lemmas(0),
547     m_trail(m)
548 {}
549 
induction(context & ctx,ast_manager & m)550 induction::induction(context& ctx, ast_manager& m):
551     ctx(ctx),
552     m(m),
553     vs(m),
554     m_collect_literals(ctx, m, vs),
555     m_create_lemmas(ctx, m)
556 {}
557 
558 // TBD: use smt_arith_value to also include state from arithmetic solver
init_values()559 void induction::init_values() {
560     for (enode* n : ctx.enodes())
561         if (m.is_value(n->get_owner()))
562             for (enode* r : *n)
563                 if (r != n) {
564                     vs.set_value(r->get_owner(), n->get_owner());
565                 }
566 }
567 
operator ()()568 bool induction::operator()() {
569     bool added_lemma = false;
570     vs.reset_values();
571     init_values();
572     literal_vector candidates = m_collect_literals();
573     for (literal lit : candidates) {
574         if (m_create_lemmas(lit))
575             added_lemma = true;
576     }
577     return added_lemma;
578 }
579 
580 // state contains datatypes + recursive functions
581 // more comprehensive:
582 // state contains integers / datatypes / sequences + recursive function / quantifiers
583 
should_try(context & ctx)584 bool induction::should_try(context& ctx) {
585     recfun::util u(ctx.get_manager());
586     datatype::util dt(ctx.get_manager());
587     theory* adt = ctx.get_theory(dt.get_family_id());
588     return adt && adt->get_num_vars() > 0 && !u.get_rec_funs().empty();
589 }
590