1 /*++
2 Copyright (c) 2020 Microsoft Corporation
3 
4 Module Name:
5 
6     q_solver.cpp
7 
8 Abstract:
9 
10     Quantifier solver plugin
11 
12 Author:
13 
14     Nikolaj Bjorner (nbjorner) 2020-09-29
15 
16 --*/
17 
18 #include "ast/ast_util.h"
19 #include "ast/well_sorted.h"
20 #include "ast/rewriter/var_subst.h"
21 #include "ast/normal_forms/pull_quant.h"
22 #include "sat/smt/q_solver.h"
23 #include "sat/smt/euf_solver.h"
24 #include "sat/smt/sat_th.h"
25 
26 
27 namespace q {
28 
solver(euf::solver & ctx,family_id fid)29     solver::solver(euf::solver& ctx, family_id fid) :
30         th_euf_solver(ctx, ctx.get_manager().get_family_name(fid), fid),
31         m_mbqi(ctx,  *this),
32         m_ematch(ctx, *this),
33         m_expanded(ctx.get_manager()),
34         m_der(ctx.get_manager())
35     {
36     }
37 
asserted(sat::literal l)38     void solver::asserted(sat::literal l) {
39         expr* e = bool_var2expr(l.var());
40         if (!is_forall(e) && !is_exists(e))
41             return;
42         quantifier* q = to_quantifier(e);
43 
44         if (l.sign() == is_forall(e)) {
45             sat::literal lit = skolemize(q);
46             add_clause(~l, lit);
47             ctx.add_root(~l, lit);
48         }
49         else if (expand(q)) {
50             for (expr* e : m_expanded) {
51                 sat::literal lit = ctx.internalize(e, l.sign(), false, false);
52                 add_clause(~l, lit);
53                 ctx.add_root(~l, lit);
54             }
55         }
56         else if (is_ground(q->get_expr())) {
57             auto lit = ctx.internalize(q->get_expr(), l.sign(), false, false);
58             add_clause(~l, lit);
59             ctx.add_root(~l, lit);
60         }
61         else {
62             ctx.push_vec(m_universal, l);
63             if (ctx.get_config().m_ematching)
64                 m_ematch.add(q);
65         }
66         m_stats.m_num_quantifier_asserts++;
67     }
68 
check()69     sat::check_result solver::check() {
70         if (ctx.get_config().m_ematching && m_ematch())
71             return sat::check_result::CR_CONTINUE;
72 
73         if (ctx.get_config().m_mbqi) {
74             switch (m_mbqi()) {
75             case l_true:  return sat::check_result::CR_DONE;
76             case l_false: return sat::check_result::CR_CONTINUE;
77             case l_undef: break;
78             }
79         }
80         return sat::check_result::CR_GIVEUP;
81     }
82 
display(std::ostream & out) const83     std::ostream& solver::display(std::ostream& out) const {
84         return m_ematch.display(out);
85     }
86 
display_constraint(std::ostream & out,sat::ext_constraint_idx idx) const87     std::ostream& solver::display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const {
88         return m_ematch.display_constraint(out, idx);
89     }
90 
collect_statistics(statistics & st) const91     void solver::collect_statistics(statistics& st) const {
92         st.update("q asserts", m_stats.m_num_quantifier_asserts);
93         m_mbqi.collect_statistics(st);
94         m_ematch.collect_statistics(st);
95     }
96 
clone(euf::solver & ctx)97     euf::th_solver* solver::clone(euf::solver& ctx) {
98         family_id fid = ctx.get_manager().mk_family_id(symbol("quant"));
99         return alloc(solver, ctx, fid);
100     }
101 
unit_propagate()102     bool solver::unit_propagate() {
103         return m_ematch.unit_propagate();
104     }
105 
mk_var(euf::enode * n)106     euf::theory_var solver::mk_var(euf::enode* n) {
107         auto v = euf::th_euf_solver::mk_var(n);
108         ctx.attach_th_var(n, this, v);
109         return v;
110     }
111 
instantiate(quantifier * _q,bool negate,std::function<expr * (quantifier *,unsigned)> & mk_var)112     sat::literal solver::instantiate(quantifier* _q, bool negate, std::function<expr* (quantifier*, unsigned)>& mk_var) {
113         sat::literal sk;
114         expr_ref tmp(m);
115         quantifier_ref q(_q, m);
116         expr_ref_vector vars(m);
117         if (negate) {
118             q = m.mk_quantifier(
119                 is_forall(q) ? quantifier_kind::exists_k : quantifier_kind::forall_k,
120                 q->get_num_decls(), q->get_decl_sorts(), q->get_decl_names(), m.mk_not(q->get_expr()),
121                 q->get_weight(), q->get_qid(), q->get_skid());
122         }
123         quantifier* q_flat = flatten(q);
124         unsigned sz = q_flat->get_num_decls();
125         vars.resize(sz, nullptr);
126         for (unsigned i = 0; i < sz; ++i)
127             vars[i] = mk_var(q_flat, i);
128         var_subst subst(m);
129         expr_ref body = subst(q_flat->get_expr(), vars);
130         rewrite(body);
131         return mk_literal(body);
132     }
133 
skolemize(quantifier * q)134     sat::literal solver::skolemize(quantifier* q) {
135         std::function<expr* (quantifier*, unsigned)> mk_var = [&](quantifier* q, unsigned i) {
136             return m.mk_fresh_const(q->get_decl_name(i), q->get_decl_sort(i));
137         };
138         return instantiate(q, is_forall(q), mk_var);
139     }
140 
141     /*
142      * Find initial values to instantiate quantifier with so to make it as hard as possible for solver
143      * to find values to free variables.
144      */
specialize(quantifier * q)145     sat::literal solver::specialize(quantifier* q) {
146         std::function<expr* (quantifier*, unsigned)> mk_var = [&](quantifier* q, unsigned i) {
147             return get_unit(q->get_decl_sort(i));
148         };
149         return instantiate(q, is_exists(q), mk_var);
150     }
151 
init_search()152     void solver::init_search() {
153         m_mbqi.init_search();
154     }
155 
internalize(expr * e,bool sign,bool root,bool learned)156     sat::literal solver::internalize(expr* e, bool sign, bool root, bool learned) {
157         SASSERT(is_forall(e) || is_exists(e));
158         sat::bool_var v = ctx.get_si().add_bool_var(e);
159         sat::literal lit = ctx.attach_lit(sat::literal(v, false), e);
160         mk_var(ctx.get_egraph().find(e));
161         if (sign)
162             lit.neg();
163         return lit;
164     }
165 
finalize_model(model & mdl)166     void solver::finalize_model(model& mdl) {
167         m_mbqi.finalize_model(mdl);
168     }
169 
flatten(quantifier * q)170     quantifier* solver::flatten(quantifier* q) {
171         quantifier* q_flat = nullptr;
172         if (!has_quantifiers(q->get_expr()))
173             return q;
174         if (m_flat.find(q, q_flat))
175             return q_flat;
176         proof_ref pr(m);
177         expr_ref  new_q(m);
178         if (is_forall(q)) {
179             pull_quant pull(m);
180             pull(q, new_q, pr);
181             SASSERT(is_well_sorted(m, new_q));
182         }
183         else {
184             new_q = q;
185         }
186         q_flat = to_quantifier(new_q);
187         m.inc_ref(q_flat);
188         m.inc_ref(q);
189         m_flat.insert(q, q_flat);
190         ctx.push(insert_ref2_map<ast_manager, quantifier, quantifier>(m, m_flat, q, q_flat));
191         return q_flat;
192     }
193 
init_units()194     void solver::init_units() {
195         if (!m_unit_table.empty())
196             return;
197         for (euf::enode* n : ctx.get_egraph().nodes()) {
198             if (!n->interpreted() && !m.is_uninterp(n->get_expr()->get_sort()))
199                 continue;
200             expr* e = n->get_expr();
201             sort* s = e->get_sort();
202             if (m_unit_table.contains(s))
203                 continue;
204             m_unit_table.insert(s, e);
205             ctx.push(insert_map<obj_map<sort, expr*>, sort*>(m_unit_table, s));
206         }
207     }
208 
get_unit(sort * s)209     expr* solver::get_unit(sort* s) {
210         expr* u = nullptr;
211         if (m_unit_table.find(s, u))
212             return u;
213         init_units();
214         if (m_unit_table.find(s, u))
215             return u;
216         model mdl(m);
217         expr* val = mdl.get_some_value(s);
218         m.inc_ref(val);
219         m.inc_ref(s);
220         ctx.push(insert_ref2_map<ast_manager, sort, expr>(m, m_unit_table, s, val));
221         return val;
222     }
223 
expand(quantifier * q)224     bool solver::expand(quantifier* q) {
225         expr_ref r(m);
226         proof_ref pr(m);
227         m_der(q, r, pr);
228         m_expanded.reset();
229         if (r != q) {
230             ctx.get_rewriter()(r);
231             m_expanded.push_back(r);
232             return true;
233         }
234         if (is_forall(q))
235             flatten_and(q->get_expr(), m_expanded);
236         else if (is_exists(q))
237             flatten_or(q->get_expr(), m_expanded);
238         else
239             UNREACHABLE();
240 
241         if (m_expanded.size() == 1 && is_forall(q)) {
242             m_expanded.reset();
243             flatten_or(q->get_expr(), m_expanded);
244             expr_ref split1(m), split2(m), e1(m), e2(m);
245             unsigned idx = 0;
246             for (unsigned i = m_expanded.size(); i-- > 0; ) {
247                 expr* arg = m_expanded.get(i);
248                 if (split(arg, split1, split2)) {
249                     if (e1)
250                         return false;
251                     e1 = split1;
252                     e2 = split2;
253                     idx = i;
254                 }
255             }
256             if (!e1)
257                 return false;
258 
259             m_expanded[idx] = e1;
260             e1 = mk_or(m_expanded);
261             m_expanded[idx] = e2;
262             e2 = mk_or(m_expanded);
263             m_expanded.reset();
264             m_expanded.push_back(e1);
265             m_expanded.push_back(e2);
266         }
267         if (m_expanded.size() > 1) {
268             for (unsigned i = m_expanded.size(); i-- > 0; ) {
269                 expr_ref tmp(m.update_quantifier(q, m_expanded.get(i)), m);
270                 ctx.get_rewriter()(tmp);
271                 m_expanded[i] = tmp;
272             }
273             return true;
274         }
275         return false;
276     }
277 
split(expr * arg,expr_ref & e1,expr_ref & e2)278     bool solver::split(expr* arg, expr_ref& e1, expr_ref& e2) {
279         expr* x, * y, * z;
280         if (m.is_not(arg, x) && m.is_or(x, y, z) && is_literal(y) && is_literal(z)) {
281             e1 = mk_not(m, y);
282             e2 = mk_not(m, z);
283             return true;
284         }
285         if (m.is_iff(arg, x, y) && is_literal(x) && is_literal(y)) {
286             e1 = m.mk_implies(x, y);
287             e2 = m.mk_implies(y, x);
288             return true;
289         }
290         if (m.is_and(arg, x, y) && is_literal(x) && is_literal(y)) {
291             e1 = x;
292             e2 = y;
293             return true;
294         }
295         if (m.is_not(arg, z) && m.is_iff(z, x, y) && is_literal(x) && is_literal(y)) {
296             e1 = m.mk_or(x, y);
297             e2 = m.mk_or(mk_not(m, x), mk_not(m, y));
298             return true;
299         }
300         return false;
301     }
302 
is_literal(expr * arg)303     bool solver::is_literal(expr* arg) {
304         m.is_not(arg, arg);
305         return !m.is_and(arg) && !m.is_or(arg) && !m.is_iff(arg) && !m.is_implies(arg);
306     }
307 
get_antecedents(sat::literal l,sat::ext_justification_idx idx,sat::literal_vector & r,bool probing)308     void solver::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) {
309         m_ematch.get_antecedents(l, idx, r, probing);
310     }
311 
312 }
313