/*++
Copyright (c) 2020 Microsoft Corporation

Module Name:

    dt_solver.h

Abstract:

    Theory plugin for algebraic datatypes

Author:

    Nikolaj Bjorner (nbjorner) 2020-09-08

--*/
17 #pragma once
18 
19 #include "sat/smt/sat_th.h"
20 #include "ast/datatype_decl_plugin.h"
21 #include "ast/array_decl_plugin.h"
22 
// Forward declaration of the EUF core solver, so it can be named in the
// declarations below without pulling in its header.
namespace euf { class solver; }
26 
27 namespace dt {
28 
29     class solver : public euf::th_euf_solver {
30         typedef euf::theory_var theory_var;
31         typedef euf::theory_id theory_id;
32         typedef euf::enode enode;
33         typedef euf::enode_pair enode_pair;
34         typedef euf::enode_pair_vector enode_pair_vector;
35         typedef sat::bool_var bool_var;
36         typedef sat::literal literal;
37         typedef sat::literal_vector literal_vector;
38         typedef union_find<solver, euf::solver>  dt_union_find;
39 
40         struct var_data {
41             ptr_vector<enode> m_recognizers; //!< recognizers of this equivalence class that are being watched.
42             enode *           m_constructor; //!< constructor of this equivalence class, 0 if there is no constructor in the eqc.
var_datavar_data43             var_data():
44                 m_constructor(nullptr) {
45             }
46         };
47 
48         // class for managing state of final_check
49         class final_check_st {
50             solver& s;
51         public:
52             final_check_st(solver& s);
53             ~final_check_st();
54         };
55 
56         struct stats {
57             unsigned   m_occurs_check, m_splits;
58             unsigned   m_assert_cnstr, m_assert_accessor, m_assert_update_field;
resetstats59             void reset() { memset(this, 0, sizeof(*this)); }
statsstats60             stats() { reset(); }
61         };
62 
63         datatype_util         dt;
64         array_util            m_autil;
65         stats                 m_stats;
66         ptr_vector<var_data>  m_var_data;
67         dt_union_find         m_find;
68         expr_ref_vector       m_args;
69 
is_constructor(expr * f)70         bool is_constructor(expr * f) const { return dt.is_constructor(f); }
is_recognizer(expr * f)71         bool is_recognizer(expr * f) const { return dt.is_recognizer(f); }
is_accessor(expr * f)72         bool is_accessor(expr * f) const { return dt.is_accessor(f); }
is_update_field(expr * f)73         bool is_update_field(expr * f) const { return dt.is_update_field(f); }
74 
is_constructor(enode * n)75         bool is_constructor(enode * n) const { return is_constructor(n->get_expr()); }
is_recognizer(enode * n)76         bool is_recognizer(enode * n) const { return is_recognizer(n->get_expr()); }
is_accessor(enode * n)77         bool is_accessor(enode * n) const { return is_accessor(n->get_expr()); }
is_update_field(enode * n)78         bool is_update_field(enode * n) const { return dt.is_update_field(n->get_expr()); }
79 
is_datatype(expr * e)80         bool is_datatype(expr* e) const { return dt.is_datatype(m.get_sort(e)); }
is_datatype(enode * n)81         bool is_datatype(enode* n) const { return is_datatype(n->get_expr()); }
82 
83         void assert_eq_axiom(enode * lhs, expr * rhs, literal antecedent = sat::null_literal);
84         void assert_is_constructor_axiom(enode * n, func_decl * c, literal antecedent = sat::null_literal);
85         void assert_accessor_axioms(enode * n);
86         void assert_update_field_axioms(enode * n);
87         void add_recognizer(theory_var v, enode * recognizer);
88         void propagate_recognizer(theory_var v, enode * r);
89         void sign_recognizer_conflict(enode * c, enode * r);
90 
91         typedef enum { ENTER, EXIT } stack_op;
92         typedef obj_map<enode, enode*> parent_tbl;
93         typedef std::pair<stack_op, enode*> stack_entry;
94 
95         ptr_vector<enode>     m_to_unmark1;
96         ptr_vector<enode>     m_to_unmark2;
97         enode_pair_vector     m_used_eqs; // conflict, if any
98         parent_tbl            m_parent; // parent explanation for occurs_check
99         svector<stack_entry>  m_dfs; // stack for DFS for occurs_check
100 
101         void clear_mark();
102 
103         void oc_mark_on_stack(enode * n);
oc_on_stack(enode * n)104         bool oc_on_stack(enode * n) const { return n->get_root()->is_marked1(); }
105 
106         void oc_mark_cycle_free(enode * n);
oc_cycle_free(enode * n)107         bool oc_cycle_free(enode * n) const { return n->get_root()->is_marked2(); }
108 
109         void oc_push_stack(enode * n);
110         ptr_vector<enode> m_array_args;
111         ptr_vector<enode> const& get_array_args(enode* n);
112 
113         void pop_core(unsigned n) override;
114 
115         enode * oc_get_cstor(enode * n);
116         bool occurs_check(enode * n);
117         bool occurs_check_enter(enode * n);
118         void occurs_check_explain(enode * top, enode * root);
119         void explain_is_child(enode* parent, enode* child);
120 
121         void mk_split(theory_var v);
122 
123         void display_var(std::ostream & out, theory_var v) const;
124 
125         // internalize
126         bool visit(expr* e) override;
127         bool visited(expr* e) override;
128         bool post_visit(expr* e, bool sign, bool root) override;
129         void clone_var(solver& src, theory_var v);
130 
131     public:
132         solver(euf::solver& ctx, theory_id id);
133         ~solver() override;
134 
is_external(bool_var v)135         bool is_external(bool_var v) override { return false; }
136         void get_antecedents(literal l, sat::ext_justification_idx idx, literal_vector& r, bool probing) override;
137         void asserted(literal l) override;
138         sat::check_result check() override;
139 
140         std::ostream& display(std::ostream& out) const override;
display_justification(std::ostream & out,sat::ext_justification_idx idx)141         std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { return euf::th_propagation::from_index(idx).display(out); }
display_constraint(std::ostream & out,sat::ext_constraint_idx idx)142         std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { return display_justification(out, idx); }
143         void collect_statistics(statistics& st) const override;
144         euf::th_solver* clone(euf::solver& ctx) override;
145         void new_eq_eh(euf::th_eq const& eq) override;
unit_propagate()146         bool unit_propagate() override { return false; }
147         void add_value(euf::enode* n, model& mdl, expr_ref_vector& values) override;
148         void add_dep(euf::enode* n, top_sort<euf::enode>& dep) override;
149         sat::literal internalize(expr* e, bool sign, bool root, bool redundant) override;
150         void internalize(expr* e, bool redundant) override;
151         euf::theory_var mk_var(euf::enode* n) override;
152         void apply_sort_cnstr(euf::enode* n, sort* s) override;
is_shared(theory_var v)153         bool is_shared(theory_var v) const override { return false; }
154 
155         void merge_eh(theory_var, theory_var, theory_var v1, theory_var v2);
after_merge_eh(theory_var r1,theory_var r2,theory_var v1,theory_var v2)156         void after_merge_eh(theory_var r1, theory_var r2, theory_var v1, theory_var v2) {}
unmerge_eh(theory_var v1,theory_var v2)157         void unmerge_eh(theory_var v1, theory_var v2) {}
158     };
159 }
160