1 /*++
2 Copyright (c) 2020 Microsoft Corporation
3 
4 Module Name:
5 
    dt_solver.cpp
7 
8 Abstract:
9 
    Theory plugin for algebraic datatypes
11 
12 Author:
13 
14     Nikolaj Bjorner (nbjorner) 2020-09-08
15 
16 --*/
17 
18 #include "sat/smt/dt_solver.h"
19 #include "sat/smt/euf_solver.h"
20 #include "sat/smt/array_solver.h"
21 
// Forward declaration of euf::solver.
// NOTE(review): sat/smt/euf_solver.h is included above, so this declaration appears redundant.
namespace euf {
    class solver;
}
25 
26 namespace dt {
27 
solver(euf::solver & ctx,theory_id id)28     solver::solver(euf::solver& ctx, theory_id id) :
29         th_euf_solver(ctx, ctx.get_manager().get_family_name(id), id),
30         dt(m),
31         m_autil(m),
32         m_find(*this),
33         m_args(m)
34     {}
35 
~solver()36     solver::~solver() {
37         std::for_each(m_var_data.begin(), m_var_data.end(), delete_proc<var_data>());
38         m_var_data.reset();
39     }
40 
clone_var(solver & src,theory_var v)41     void solver::clone_var(solver& src, theory_var v) {
42         enode* n = src.ctx.copy(ctx, src.var2enode(v));
43         VERIFY(v == th_euf_solver::mk_var(n));
44         m_var_data.push_back(alloc(var_data));
45         var_data* d_dst = m_var_data[v];
46         var_data* d_src = src.m_var_data[v];
47         ctx.attach_th_var(n, this, v);
48         if (d_src->m_constructor && !d_dst->m_constructor)
49             d_dst->m_constructor = src.ctx.copy(ctx, d_src->m_constructor);
50         for (auto* r : d_src->m_recognizers)
51             d_dst->m_recognizers.push_back(src.ctx.copy(ctx, r));
52     }
53 
clone(euf::solver & dst_ctx)54     euf::th_solver* solver::clone(euf::solver& dst_ctx) {
55         auto* result = alloc(solver, dst_ctx, get_id());
56         for (unsigned v = 0; v < get_num_vars(); ++v)
57             result->clone_var(*this, v);
58         return result;
59     }
60 
final_check_st(solver & s)61     solver::final_check_st::final_check_st(solver& s) : s(s) {
62         SASSERT(s.m_to_unmark1.empty());
63         SASSERT(s.m_to_unmark2.empty());
64         s.m_used_eqs.reset();
65         s.m_dfs.reset();
66     }
67 
~final_check_st()68     solver::final_check_st::~final_check_st() {
69         s.clear_mark();
70     }
71 
clear_mark()72     void solver::clear_mark() {
73         for (enode* n : m_to_unmark1)
74             n->unmark1();
75         for (enode* n : m_to_unmark2)
76             n->unmark2();
77         m_to_unmark1.reset();
78         m_to_unmark2.reset();
79     }
80 
oc_mark_on_stack(enode * n)81     void solver::oc_mark_on_stack(enode* n) {
82         n = n->get_root();
83         n->mark1();
84         m_to_unmark1.push_back(n);
85     }
86 
oc_mark_cycle_free(enode * n)87     void solver::oc_mark_cycle_free(enode* n) {
88         n = n->get_root();
89         n->mark2();
90         m_to_unmark2.push_back(n);
91     }
92 
oc_push_stack(enode * n)93     void solver::oc_push_stack(enode* n) {
94         m_dfs.push_back(std::make_pair(EXIT, n));
95         m_dfs.push_back(std::make_pair(ENTER, n));
96     }
97 
98     /**
99        \brief Assert the axiom (antecedent => lhs = rhs)
100        antecedent may be null_literal
101     */
assert_eq_axiom(enode * lhs,expr * rhs,literal antecedent)102     void solver::assert_eq_axiom(enode* lhs, expr* rhs, literal antecedent) {
103         if (antecedent == sat::null_literal)
104             add_unit(eq_internalize(lhs->get_expr(), rhs));
105         else if (s().value(antecedent) == l_true) {
106             euf::th_propagation* jst = euf::th_propagation::mk(*this, antecedent);
107             ctx.propagate(lhs, e_internalize(rhs), jst);
108         }
109         else
110             add_clause(~antecedent, eq_internalize(lhs->get_expr(), rhs));
111     }
112 
113     /**
114        \brief Assert the equality (= n (c (acc_1 n) ... (acc_m n))) where
115        where acc_i are the accessors of constructor c.
116     */
assert_is_constructor_axiom(enode * n,func_decl * c,literal antecedent)117     void solver::assert_is_constructor_axiom(enode* n, func_decl* c, literal antecedent) {
118         expr* e = n->get_expr();
119         TRACE("dt", tout << "creating axiom (= n (c (acc_1 n) ... (acc_m n))) for\n"
120             << mk_pp(c, m) << " " << mk_pp(e, m) << "\n";);
121         m_stats.m_assert_cnstr++;
122         SASSERT(dt.is_constructor(c));
123         SASSERT(is_datatype(e));
124         SASSERT(c->get_range() == m.get_sort(e));
125         m_args.reset();
126         ptr_vector<func_decl> const& accessors = *dt.get_constructor_accessors(c);
127         SASSERT(c->get_arity() == accessors.size());
128         for (func_decl* d : accessors)
129             m_args.push_back(m.mk_app(d, e));
130         expr_ref con(m.mk_app(c, m_args), m);
131         assert_eq_axiom(n, con, antecedent);
132     }
133 
134     /**
135        \brief Given a constructor n := (c a_1 ... a_m) assert the axioms
136        (= (acc_1 n) a_1)
137        ...
138        (= (acc_m n) a_m)
139     */
assert_accessor_axioms(enode * n)140     void solver::assert_accessor_axioms(enode* n) {
141         m_stats.m_assert_accessor++;
142         expr* e = n->get_expr();
143         SASSERT(is_constructor(n));
144         func_decl* d = n->get_decl();
145         ptr_vector<func_decl> const& accessors = *dt.get_constructor_accessors(d);
146         SASSERT(n->num_args() == accessors.size());
147         unsigned i = 0;
148         for (func_decl* acc : accessors) {
149             app_ref acc_app(m.mk_app(acc, e), m);
150             assert_eq_axiom(n->get_arg(i), acc_app);
151             ++i;
152         }
153     }
154 
155     /**
156        \brief Sign a conflict for r := is_mk(a), c := mk(...), not(r),  and c == a.
157     */
sign_recognizer_conflict(enode * c,enode * r)158     void solver::sign_recognizer_conflict(enode* c, enode* r) {
159         SASSERT(is_constructor(c));
160         SASSERT(is_recognizer(r));
161         SASSERT(dt.get_recognizer_constructor(r->get_decl()) == c->get_decl());
162         SASSERT(c->get_root() == r->get_arg(0)->get_root());
163         TRACE("dt", tout << ctx.bpp(c) << "\n" << ctx.bpp(r) << "\n";);
164         literal l = ctx.enode2literal(r);
165         SASSERT(s().value(l) == l_false);
166         clear_mark();
167         auto* jst = euf::th_propagation::mk(*this, ~l, c, r->get_arg(0));
168         ctx.set_conflict(jst);
169     }
170 
171     /**
172        \brief Given a field update n := { r with field := v } for constructor C, assert the axioms:
173        (=> (is-C r) (= (acc_j n) (acc_j r))) for acc_j != field
174        (=> (is-C r) (= (field n) v))         for acc_j != field
175        (=> (not (is-C r)) (= n r))
176        (=> (is-C r) (is-C n))
177     */
assert_update_field_axioms(enode * n)178     void solver::assert_update_field_axioms(enode* n) {
179         m_stats.m_assert_update_field++;
180         SASSERT(is_update_field(n));
181         expr* own = n->get_expr();
182         expr* arg1 = n->get_arg(0)->get_expr();
183         func_decl* upd = n->get_decl();
184         func_decl* acc = to_func_decl(upd->get_parameter(0).get_ast());
185         func_decl* con = dt.get_accessor_constructor(acc);
186         func_decl* rec = dt.get_constructor_is(con);
187         ptr_vector<func_decl> const& accessors = *dt.get_constructor_accessors(con);
188         app_ref rec_app(m.mk_app(rec, arg1), m);
189         app_ref acc_app(m);
190         literal is_con = mk_literal(rec_app);
191         for (func_decl* acc1 : accessors) {
192             enode* arg;
193             if (acc1 == acc) {
194                 arg = n->get_arg(1);
195             }
196             else {
197                 acc_app = m.mk_app(acc1, arg1);
198                 arg = e_internalize(acc_app);
199             }
200             app_ref acc_own(m.mk_app(acc1, own), m);
201             assert_eq_axiom(arg, acc_own, is_con);
202         }
203         // update_field is identity if 'n' is not created by a matching constructor.
204         assert_eq_axiom(n, arg1, ~is_con);
205         app_ref n_is_con(m.mk_app(rec, own), m);
206         add_clause(~is_con, mk_literal(n_is_con));
207     }
208 
    /**
       \brief Return the datatype theory variable of `n`, creating it if needed.

       When a fresh variable is created, the union-find and m_var_data are
       extended in lock step, and axioms for `n` are asserted:
       - constructor term: recorded as the constructor of its own class, and
         its accessor axioms are asserted;
       - field update: the update-field axioms are asserted;
       - any other term: if its sort has a single constructor, n is equated
         with that constructor; otherwise a case split may be created eagerly,
         depending on dt_lazy_splits.
    */
    euf::theory_var solver::mk_var(enode* n) {
        if (is_attached_to_var(n))
            return n->get_th_var(get_id());
        euf::theory_var r = th_euf_solver::mk_var(n);
        // theory variables and union-find entries must share indices
        VERIFY(r == static_cast<theory_var>(m_find.mk_var()));
        SASSERT(r == static_cast<int>(m_var_data.size()));
        m_var_data.push_back(alloc(var_data));
        var_data* d = m_var_data[r];
        ctx.attach_th_var(n, this, r);
        if (is_constructor(n)) {
            d->m_constructor = n;
            assert_accessor_axioms(n);
        }
        else if (is_update_field(n)) {
            assert_update_field_axioms(n);
        }
        else {
            sort* s = m.get_sort(n->get_expr());
            if (dt.get_datatype_num_constructors(s) == 1)
                assert_is_constructor_axiom(n, dt.get_datatype_constructors(s)->get(0));
            // dt_lazy_splits == 0: always split eagerly;
            // dt_lazy_splits == 1: split eagerly only for finite sorts;
            // in all other cases the split is left to check().
            else if (get_config().m_dt_lazy_splits == 0 || (get_config().m_dt_lazy_splits == 1 && !s->is_infinite()))
                mk_split(r);
        }
        return r;
    }
234 
235 
    /**
       \brief Create a case split for the class of v, i.e. create the literal
       for a recognizer atom (is-C v).

       The non-recursive constructor of v's sort is preferred. If its recognizer
       is not attached to the class yet, the split uses it. If it is attached and
       not false (true, or still unassigned), nothing is done. If it is false,
       the recognizer slots are scanned in constructor order: the first empty
       slot determines C. If a non-false recognizer is met first, or every slot
       holds a false recognizer, nothing is created (in the latter case the
       conflict is detected elsewhere).
    */
    void solver::mk_split(theory_var v) {
        m_stats.m_splits++;

        v = m_find.find(v);
        enode* n = var2enode(v);
        sort* srt = m.get_sort(n->get_expr());
        func_decl* non_rec_c = dt.get_non_rec_constructor(srt);
        unsigned non_rec_idx = dt.get_constructor_idx(non_rec_c);
        var_data* d = m_var_data[v];
        SASSERT(d->m_constructor == nullptr);
        func_decl* r = nullptr;

        TRACE("dt", tout << "non_rec_c: " << non_rec_c->get_name() << " #rec: " << d->m_recognizers.size() << "\n";);

        enode* recognizer = d->m_recognizers.get(non_rec_idx, nullptr);
        if (recognizer == nullptr)
            r = dt.get_constructor_is(non_rec_c);
        else if (ctx.value(recognizer) != l_false)
            // if is l_true, then we are done
            // otherwise wait for recognizer to be assigned.
            return;
        else {
            // look for the first empty slot of d->m_recognizers; stop early if a
            // recognizer that is not false (true or unassigned) is found first.
            unsigned idx = 0;
            ptr_vector<func_decl> const& constructors = *dt.get_datatype_constructors(srt);
            for (enode* curr : d->m_recognizers) {
                if (curr == nullptr) {
                    // found empty slot...
                    r = dt.get_constructor_is(constructors[idx]);
                    break;
                }
                else if (ctx.value(curr) != l_false)
                    return;
                ++idx;
            }
            if (r == nullptr)
                return; // all recognizers are asserted to false... conflict will be detected...
        }
        SASSERT(r != nullptr);
        app_ref r_app(m.mk_app(r, n->get_expr()), m);
        TRACE("dt", tout << "creating split: " << mk_pp(r_app, m) << "\n";);
        mk_literal(r_app);
    }
283 
apply_sort_cnstr(enode * n,sort * s)284     void solver::apply_sort_cnstr(enode* n, sort* s) {
285         force_push();
286         // Remark: If s is an infinite sort, then it is not necessary to create
287         // a theory variable.
288         //
289         // Actually, when the logical context has quantifiers, it is better to
290         // disable this optimization.
291         // Example:
292         //
293         //   (forall (l list) (a Int) (= (len (cons a l)) (+ (len l) 1)))
294         //   (assert (> (len a) 1)
295         //
296         // If the theory variable is not created for 'a', then a wrong model will be generated.
297         TRACE("dt", tout << "apply_sort_cnstr: #" << n->get_expr_id() << " " << mk_pp(n->get_expr(), m) << "\n";);
298         TRACE("dt_bug",
299             tout << "apply_sort_cnstr:\n" << mk_pp(n->get_expr(), m) << " ";
300             tout << dt.is_datatype(s) << " ";
301             if (dt.is_datatype(s)) tout << "is-infinite: " << s->is_infinite() << " ";
302             if (dt.is_datatype(s)) tout << "attached: " << is_attached_to_var(n) << " ";
303             tout << "\n";);
304 
305         if (!is_attached_to_var(n) &&
306             (/*ctx.has_quantifiers()*/ true ||
307                 (dt.is_datatype(s) && dt.has_nested_arrays()) ||
308                 (dt.is_datatype(s) && !s->is_infinite()))) {
309             mk_var(n);
310         }
311     }
312 
313 
new_eq_eh(euf::th_eq const & eq)314     void solver::new_eq_eh(euf::th_eq const& eq) {
315         force_push();
316         m_find.merge(eq.v1(), eq.v2());
317     }
318 
asserted(literal lit)319     void solver::asserted(literal lit) {
320         force_push();
321         enode* n = bool_var2enode(lit.var());
322         if (!is_recognizer(n))
323             return;
324         TRACE("dt", tout << "assigning recognizer: #" << n->get_expr_id() << " " << ctx.bpp(n) << "\n";);
325         SASSERT(n->num_args() == 1);
326         enode* arg = n->get_arg(0);
327         theory_var tv = arg->get_th_var(get_id());
328         tv = m_find.find(tv);
329         var_data* d = m_var_data[tv];
330         func_decl* r = n->get_decl();
331         func_decl* c = dt.get_recognizer_constructor(r);
332         if (!lit.sign()) {
333             SASSERT(tv != euf::null_theory_var);
334             if (d->m_constructor != nullptr && d->m_constructor->get_decl() == c)
335                 return; // do nothing
336             assert_is_constructor_axiom(arg, c, lit);
337         }
338         else if (d->m_constructor == nullptr)                   // make sure a constructor is attached
339             propagate_recognizer(tv, n);
340         else if (d->m_constructor->get_decl() == c)             // conflict
341             sign_recognizer_conflict(d->m_constructor, n);
342     }
343 
    /**
       \brief Attach the recognizer enode `recognizer` (is-C a) to the class of
       theory variable v (its union-find root is used).

       The class's recognizer table is sized to the number of constructors of
       the sort on first use. Only an empty slot is filled:
       - recognizer already true: nothing is stored;
       - recognizer false and the class has constructor C: conflict;
       - recognizer false and the class has another constructor: nothing is stored;
       - otherwise the recognizer is stored (undone on backtracking via the
         trail), and if it is false the remaining recognizers are propagated.
    */
    void solver::add_recognizer(theory_var v, enode* recognizer) {
        SASSERT(is_recognizer(recognizer));
        v = m_find.find(v);
        var_data* d = m_var_data[v];
        sort* s = recognizer->get_decl()->get_domain(0);
        if (d->m_recognizers.empty()) {
            SASSERT(dt.is_datatype(s));
            d->m_recognizers.resize(dt.get_datatype_num_constructors(s), nullptr);
        }
        SASSERT(d->m_recognizers.size() == dt.get_datatype_num_constructors(s));
        unsigned c_idx = dt.get_recognizer_constructor_idx(recognizer->get_decl());
        if (d->m_recognizers[c_idx] == nullptr) {
            lbool val = ctx.value(recognizer);
            TRACE("dt", tout << "adding recognizer to v" << v << " rec: #" << recognizer->get_expr_id() << " val: " << val << "\n";);
            if (val == l_true) {
                // do nothing...
                // If recognizer assignment was already processed, then
                // d->m_constructor is already set.
                // Otherwise, it will be set when asserted is invoked.
                return;
            }
            if (val == l_false && d->m_constructor != nullptr) {
                func_decl* c_decl = dt.get_recognizer_constructor(recognizer->get_decl());
                if (d->m_constructor->get_decl() == c_decl) {
                    // conflict
                    sign_recognizer_conflict(d->m_constructor, recognizer);
                }
                return;
            }
            SASSERT(val == l_undef || (val == l_false && d->m_constructor == nullptr));
            d->m_recognizers[c_idx] = recognizer;
            // reset the slot to nullptr on backtracking
            ctx.push(set_vector_idx_trail<euf::solver, enode>(d->m_recognizers, c_idx));
            if (val == l_false)
                propagate_recognizer(v, recognizer);
        }
    }
380 
    /**
       \brief Propagate a recognizer of v's class that has been assigned false.

       Scans the recognizer slots of v's class, one per constructor:
       - some attached recognizer is true: nothing to do;
       - no slot is empty: conflict, justified by the false recognizers;
       - exactly one slot is empty: the recognizer for that constructor is
         created if needed and propagated as a consequence of the false ones;
       - two or more slots are empty: a case split may be created, depending
         on dt_lazy_splits and whether the sort is finite.
       When a false recognizer's argument differs from n, the equality between
       them is added to the justification.

       NOTE(review): attached recognizers whose value is still l_undef are
       neither counted as empty slots nor added to `lits`. If no slot is empty,
       a conflict is still raised. Confirm that no attached recognizer can be
       unassigned when this is called.
    */
    void solver::propagate_recognizer(theory_var v, enode* recognizer) {
        SASSERT(is_recognizer(recognizer));
        SASSERT(static_cast<int>(m_find.find(v)) == v);
        SASSERT(ctx.value(recognizer) == l_false);
        unsigned num_unassigned = 0;    // number of empty (unattached) slots
        unsigned unassigned_idx = UINT_MAX;
        enode* n = var2enode(v);
        sort* srt = m.get_sort(n->get_expr());
        var_data* d = m_var_data[v];
        if (d->m_recognizers.empty()) {
            theory_var w = recognizer->get_arg(0)->get_th_var(get_id());
            SASSERT(w != euf::null_theory_var);
            add_recognizer(w, recognizer);
        }
        CTRACE("dt", d->m_recognizers.empty(), ctx.display(tout););
        SASSERT(!d->m_recognizers.empty());
        literal_vector lits;
        enode_pair_vector eqs;
        unsigned idx = 0;
        for (enode* r : d->m_recognizers) {
            if (!r) {
                if (num_unassigned == 0)
                    unassigned_idx = idx;
                num_unassigned++;
            }
            else if (ctx.value(r) == l_true)
                return; // nothing to be propagated
            else if (ctx.value(r) == l_false) {
                SASSERT(r->num_args() == 1);
                lits.push_back(~ctx.enode2literal(r));
                if (n != r->get_arg(0)) {
                    // Argument of the current recognizer is not necessarily equal to n.
                    // This can happen when n and r->get_arg(0) are in the same equivalence class.
                    // We must add equality as an assumption to the conflict or propagation
                    SASSERT(n->get_root() == r->get_arg(0)->get_root());
                    eqs.push_back(euf::enode_pair(n, r->get_arg(0)));
                }
            }
            ++idx;
        }
        TRACE("dt", tout << "propagate " << num_unassigned << " eqs: " << eqs.size() << "\n";);
        if (num_unassigned == 0)
            ctx.set_conflict(euf::th_propagation::mk(*this, lits, eqs));
        else if (num_unassigned == 1) {
            // propagate remaining recognizer
            SASSERT(!lits.empty());
            enode* r = d->m_recognizers[unassigned_idx];
            literal consequent;
            if (!r) {
                ptr_vector<func_decl> const& constructors = *dt.get_datatype_constructors(srt);
                func_decl* rec = dt.get_constructor_is(constructors[unassigned_idx]);
                app_ref rec_app(m.mk_app(rec, n->get_expr()), m);
                consequent = mk_literal(rec_app);
            }
            else
                consequent = ctx.enode2literal(r);
            ctx.propagate(consequent, euf::th_propagation::mk(*this, lits, eqs));
        }
        else if (get_config().m_dt_lazy_splits == 0 || (!srt->is_infinite() && get_config().m_dt_lazy_splits == 1))
            // two or more recognizer slots are still empty...
            // if eager splits are enabled... create new case split
            mk_split(v);
    }
447 
merge_eh(theory_var v1,theory_var v2,theory_var,theory_var)448     void solver::merge_eh(theory_var v1, theory_var v2, theory_var, theory_var) {
449         // v1 is the new root
450         TRACE("dt", tout << "merging v" << v1 << " v" << v2 << "\n";);
451         SASSERT(v1 == static_cast<int>(m_find.find(v1)));
452         var_data* d1 = m_var_data[v1];
453         var_data* d2 = m_var_data[v2];
454         auto* con1 = d1->m_constructor;
455         auto* con2 = d2->m_constructor;
456         if (con2 != nullptr) {
457             if (con1 == nullptr) {
458                 ctx.push(set_ptr_trail<euf::solver, enode>(con1));
459                 // check whether there is a recognizer in d1 that conflicts with con2;
460                 if (!d1->m_recognizers.empty()) {
461                     unsigned c_idx = dt.get_constructor_idx(con2->get_decl());
462                     enode* recognizer = d1->m_recognizers[c_idx];
463                     if (recognizer != nullptr && ctx.value(recognizer) == l_false) {
464                         sign_recognizer_conflict(con2, recognizer);
465                         return;
466                     }
467                 }
468                 d1->m_constructor = con2;
469             }
470             else if (con1->get_decl() != con2->get_decl())
471                 add_unit(~eq_internalize(con1->get_expr(), con2->get_expr()));
472         }
473         for (enode* e : d2->m_recognizers)
474             if (e)
475                 add_recognizer(v1, e);
476     }
477 
get_array_args(enode * n)478     ptr_vector<euf::enode> const& solver::get_array_args(enode* n) {
479         m_array_args.reset();
480         array::solver* th = dynamic_cast<array::solver*>(ctx.fid2solver(m_autil.get_family_id()));
481         for (enode* p : th->parent_selects(n))
482             m_array_args.push_back(p);
483         app_ref def(m_autil.mk_default(n->get_expr()), m);
484         m_array_args.push_back(ctx.get_enode(def));
485         return m_array_args;
486     }
487 
488     // Assuming `app` is equal to a constructor term, return the constructor enode
oc_get_cstor(enode * app)489     inline euf::enode* solver::oc_get_cstor(enode* app) {
490         theory_var v = app->get_root()->get_th_var(get_id());
491         SASSERT(v != euf::null_theory_var);
492         v = m_find.find(v);
493         var_data* d = m_var_data[v];
494         SASSERT(d->m_constructor);
495         return d->m_constructor;
496     }
497 
    /**
       \brief Add to m_used_eqs the equalities showing that `child` occurs as an
       argument of the constructor term of parent's class:
       parent = parentc (if they differ), and arg = child for every argument of
       parentc in child's class. For array-valued arguments whose range is a
       datatype, the array's select parents and default are considered too.
       VERIFY fails if no such argument exists.
    */
    void solver::explain_is_child(enode* parent, enode* child) {
        enode* parentc = oc_get_cstor(parent);
        if (parent != parentc)
            m_used_eqs.push_back(enode_pair(parent, parentc));

        // collect equalities on all children that may have been used.
        bool found = false;
        auto add = [&](enode* arg) {
            if (arg->get_root() == child->get_root()) {
                if (arg != child)
                    m_used_eqs.push_back(enode_pair(arg, child));
                found = true;
            }
        };
        for (enode* arg : euf::enode_args(parentc)) {
            add(arg);
            sort* s = m.get_sort(arg->get_expr());
            if (m_autil.is_array(s) && dt.is_datatype(get_array_range(s)))
                for (enode* aarg : get_array_args(arg))
                    add(aarg);
        }
        VERIFY(found);
    }
521 
    /**
       \brief Explain the cycle root -> ... -> app -> root found by the occurs
       check, collecting the justifying equalities into m_used_eqs. The chain
       from app back to root's class is followed through m_parent, which the
       DFS fills with the constructor each explored node was reached from.
    */
    void solver::occurs_check_explain(enode* app, enode* root) {
        TRACE("dt", tout << "occurs_check_explain " << ctx.bpp(app) << " <-> " << ctx.bpp(root) << "\n";);

        // first: explain that root=v, given that app=cstor(...,v,...)

        explain_is_child(app, root);

        // now explain app=cstor(..,v,..) where v=root, and recurse with parent of app
        while (app->get_root() != root->get_root()) {
            enode* parent_app = m_parent[app->get_root()];
            explain_is_child(parent_app, app);
            SASSERT(is_constructor(parent_app));
            app = parent_app;
        }

        SASSERT(app->get_root() == root->get_root());
        if (app != root)
            m_used_eqs.push_back(enode_pair(app, root));

        TRACE("dt",
            tout << "occurs_check\n"; for (enode_pair const& p : m_used_eqs) tout << ctx.bpp(p.first) << " - " << ctx.bpp(p.second) << " ";);
    }
545 
    /**
       \brief ENTER step of the occurs-check DFS for `app`: start exploring the
       subgraph below app's class.

       Returns false if app's class has no theory variable or no constructor.
       Otherwise the class is marked as on-stack (via its constructor term).
       Each datatype argument of the constructor that is not yet known to be
       cycle free is scheduled for exploration, with the constructor recorded
       as its parent in m_parent. For an array-valued argument whose range is a
       datatype, its select parents and default (see get_array_args) are
       scheduled instead. Returns true when an argument that is still on the
       stack is found: that is a cycle, and its explanation has been collected
       into m_used_eqs.
    */
    bool solver::occurs_check_enter(enode* app) {
        app = app->get_root();
        theory_var v = app->get_th_var(get_id());
        if (v == euf::null_theory_var)
            return false;
        v = m_find.find(v);
        var_data* d = m_var_data[v];
        if (!d->m_constructor)
            return false;
        enode* parent = d->m_constructor;
        oc_mark_on_stack(parent);
        for (enode* arg : euf::enode_args(parent)) {
            if (oc_cycle_free(arg))
                continue;
            if (oc_on_stack(arg)) {
                // arg was explored before app, and is still on the stack: cycle
                occurs_check_explain(parent, arg);
                return true;
            }
            // explore `arg` (with parent)
            expr* earg = arg->get_expr();
            sort* s = m.get_sort(earg);
            if (dt.is_datatype(s)) {
                m_parent.insert(arg->get_root(), parent);
                oc_push_stack(arg);
            }
            else if (m_autil.is_array(s) && dt.is_datatype(get_array_range(s))) {
                for (enode* aarg : get_array_args(arg)) {
                    if (oc_cycle_free(aarg))
                        continue;
                    if (oc_on_stack(aarg)) {
                        occurs_check_explain(parent, aarg);
                        return true;
                    }
                    if (is_datatype(aarg)) {
                        m_parent.insert(aarg->get_root(), parent);
                        oc_push_stack(aarg);
                    }
                }
            }
        }
        return false;
    }
590 
    /**
       \brief Check if n can be reached starting from n by following equalities
       and constructor arguments. For example, occurs_check(a1) returns true
       for the following set of equalities:
       a1 = cons(v1, a2)
       a2 = cons(v2, a3)
       a3 = cons(v3, a1)

       The search is an iterative DFS driven by m_dfs. When a cycle is found,
       the marks are cleared and a conflict justified by m_used_eqs is raised.
    */
    bool solver::occurs_check(enode* n) {
        TRACE("dt", tout << "occurs check: " << ctx.bpp(n) << "\n";);
        m_stats.m_occurs_check++;

        bool res = false;
        oc_push_stack(n);

        // DFS traversal from `n`. Look at top element and explore it.
        while (!res && !m_dfs.empty()) {
            stack_op op = m_dfs.back().first;
            enode* app = m_dfs.back().second;
            m_dfs.pop_back();

            if (oc_cycle_free(app))
                continue;

            TRACE("dt", tout << "occurs check loop: " << ctx.bpp(app) << (op == ENTER ? " enter" : " exit") << "\n";);

            switch (op) {
            case ENTER:
                res = occurs_check_enter(app);
                break;

            case EXIT:
                // everything below app has been explored without finding a cycle
                oc_mark_cycle_free(app);
                break;
            }
        }

        if (res) {
            clear_mark();
            ctx.set_conflict(euf::th_propagation::mk(*this, m_used_eqs));
        }
        return res;
    }
633 
    /**
       \brief Final check. Visits every union-find root whose enode is of
       datatype sort, starting at a random offset:
       - if the occurs check finds a cycle, a conflict has been raised and
         CR_CONTINUE is returned immediately;
       - with lazy splits enabled (dt_lazy_splits > 0), a class without a
         constructor gets a case split and the result becomes CR_CONTINUE.
       Returns CR_DONE when no split was created.
       The final_check_st guard resets the DFS state on entry and clears marks on exit.
    */
    sat::check_result solver::check() {
        force_push();
        int num_vars = get_num_vars();
        sat::check_result r = sat::check_result::CR_DONE;
        final_check_st _guard(*this);
        int start = s().rand()();
        for (int i = 0; i < num_vars; i++) {
            theory_var v = (i + start) % num_vars;
            if (v == static_cast<int>(m_find.find(v))) {
                enode* node = var2enode(v);
                if (!is_datatype(node))
                    continue;
                if (!oc_cycle_free(node) && occurs_check(node))
                    // conflict was detected...
                    return sat::check_result::CR_CONTINUE;
                if (get_config().m_dt_lazy_splits > 0) {
                    // using lazy case splits...
                    var_data* d = m_var_data[v];
                    if (d->m_constructor == nullptr) {
                        clear_mark();
                        mk_split(v);
                        r = sat::check_result::CR_CONTINUE;
                    }
                }
            }
        }
        return r;
    }
662 
pop_core(unsigned num_scopes)663     void solver::pop_core(unsigned num_scopes) {
664         th_euf_solver::pop_core(num_scopes);
665         std::for_each(m_var_data.begin() + get_num_vars(), m_var_data.end(), delete_proc<var_data>());
666         m_var_data.shrink(get_num_vars());
667         SASSERT(m_find.get_num_vars() == m_var_data.size());
668         SASSERT(m_find.get_num_vars() == get_num_vars());
669     }
670 
get_antecedents(literal l,sat::ext_justification_idx idx,literal_vector & r,bool probing)671     void solver::get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) {
672         auto& jst = euf::th_propagation::from_index(idx);
673         ctx.get_antecedents(l, jst, r, probing);
674     }
675 
add_value(euf::enode * n,model & mdl,expr_ref_vector & values)676     void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) {
677         theory_var v = n->get_th_var(get_id());
678         v = m_find.find(v);
679         SASSERT(v != euf::null_theory_var);
680         enode* con = m_var_data[v]->m_constructor;
681         func_decl* c_decl = con->get_decl();
682         m_args.reset();
683         for (enode* arg : euf::enode_args(m_var_data[v]->m_constructor))
684             m_args.push_back(values.get(arg->get_root_id()));
685         values.set(n->get_root_id(), m.mk_app(c_decl, m_args));
686     }
687 
add_dep(euf::enode * n,top_sort<euf::enode> & dep)688     void solver::add_dep(euf::enode* n, top_sort<euf::enode>& dep) {
689         theory_var v = n->get_th_var(get_id());
690         for (enode* arg : euf::enode_args(m_var_data[m_find.find(v)]->m_constructor))
691             dep.add(n, arg);
692     }
693 
internalize(expr * e,bool sign,bool root,bool redundant)694     sat::literal solver::internalize(expr* e, bool sign, bool root, bool redundant) {
695         if (!visit_rec(m, e, sign, root, redundant)) {
696             TRACE("dt", tout << mk_pp(e, m) << "\n";);
697             return sat::null_literal;
698         }
699         auto lit = ctx.expr2literal(e);
700         if (sign)
701             lit.neg();
702         return lit;
703     }
704 
internalize(expr * e,bool redundant)705     void solver::internalize(expr* e, bool redundant) {
706         visit_rec(m, e, false, false, redundant);
707     }
708 
visit(expr * e)709     bool solver::visit(expr* e) {
710         if (visited(e))
711             return true;
712         if (!is_app(e) || to_app(e)->get_family_id() != get_id()) {
713             ctx.internalize(e, m_is_redundant);
714             if (is_datatype(e))
715                 mk_var(expr2enode(e));
716             return true;
717         }
718         m_stack.push_back(sat::eframe(e));
719         return false;
720     }
721 
visited(expr * e)722     bool solver::visited(expr* e) {
723         euf::enode* n = expr2enode(e);
724         return n && n->is_attached_to(get_id());
725     }
726 
    /**
       \brief Post-visit step of the recursive internalizer for a term of the
       datatype family. Creates the enode if needed, then theory variables:
       - constructor / field update: a variable for each datatype argument,
         for (default a) of each array argument whose range is a datatype,
         and finally for the term itself;
       - recognizer: a variable for its argument, to which the recognizer is
         then attached;
       - otherwise (accessor): a variable for its argument.
       Always returns true.
    */
    bool solver::post_visit(expr* term, bool sign, bool root) {
        euf::enode* n = expr2enode(term);
        SASSERT(!n || !n->is_attached_to(get_id()));
        if (!n)
            n = mk_enode(term);
        SASSERT(!n->is_attached_to(get_id()));
        if (is_constructor(term) || is_update_field(term)) {
            for (enode* arg : euf::enode_args(n)) {
                sort* s = m.get_sort(arg->get_expr());
                if (dt.is_datatype(s))
                    mk_var(arg);
                else if (m_autil.is_array(s) && dt.is_datatype(get_array_range(s))) {
                    app_ref def(m_autil.mk_default(arg->get_expr()), m);
                    mk_var(e_internalize(def));
                }
            }
            mk_var(n);
        }
        else if (is_recognizer(term)) {
            enode* arg = n->get_arg(0);
            theory_var v = mk_var(arg);
            add_recognizer(v, n);
        }
        else {
            SASSERT(is_accessor(term));
            SASSERT(n->num_args() == 1);
            mk_var(n->get_arg(0));
        }
        return true;
    }
757 
collect_statistics(::statistics & st) const758     void solver::collect_statistics(::statistics& st) const {
759         st.update("datatype occurs check", m_stats.m_occurs_check);
760         st.update("datatype splits", m_stats.m_splits);
761         st.update("datatype constructor ax", m_stats.m_assert_cnstr);
762         st.update("datatype accessor ax", m_stats.m_assert_accessor);
763         st.update("datatype update ax", m_stats.m_assert_update_field);
764     }
765 
display(std::ostream & out) const766     std::ostream& solver::display(std::ostream& out) const {
767         unsigned num_vars = get_num_vars();
768         if (num_vars > 0)
769             out << "Theory datatype:\n";
770         for (unsigned v = 0; v < num_vars; v++)
771             display_var(out, v);
772         return out;
773     }
774 
display_var(std::ostream & out,theory_var v) const775     void solver::display_var(std::ostream& out, theory_var v) const {
776         var_data* d = m_var_data[v];
777         out << "v" << v << " #" << var2expr(v)->get_id() << " -> v" << m_find.find(v) << " ";
778         if (d->m_constructor)
779             out << ctx.bpp(d->m_constructor);
780         else
781             out << "(null)";
782         out << "\n";
783     }
784 }
785