1 /*++ 2 Copyright (c) 2020 Microsoft Corporation 3 4 Module Name: 5 6 q_solver.cpp 7 8 Abstract: 9 10 Quantifier solver plugin 11 12 Author: 13 14 Nikolaj Bjorner (nbjorner) 2020-09-29 15 16 --*/ 17 18 #include "ast/ast_util.h" 19 #include "ast/well_sorted.h" 20 #include "ast/rewriter/var_subst.h" 21 #include "ast/normal_forms/pull_quant.h" 22 #include "sat/smt/q_solver.h" 23 #include "sat/smt/euf_solver.h" 24 #include "sat/smt/sat_th.h" 25 26 27 namespace q { 28 solver(euf::solver & ctx,family_id fid)29 solver::solver(euf::solver& ctx, family_id fid) : 30 th_euf_solver(ctx, ctx.get_manager().get_family_name(fid), fid), 31 m_mbqi(ctx, *this), 32 m_ematch(ctx, *this), 33 m_expanded(ctx.get_manager()), 34 m_der(ctx.get_manager()) 35 { 36 } 37 asserted(sat::literal l)38 void solver::asserted(sat::literal l) { 39 expr* e = bool_var2expr(l.var()); 40 if (!is_forall(e) && !is_exists(e)) 41 return; 42 quantifier* q = to_quantifier(e); 43 44 if (l.sign() == is_forall(e)) { 45 sat::literal lit = skolemize(q); 46 add_clause(~l, lit); 47 ctx.add_root(~l, lit); 48 } 49 else if (expand(q)) { 50 for (expr* e : m_expanded) { 51 sat::literal lit = ctx.internalize(e, l.sign(), false, false); 52 add_clause(~l, lit); 53 ctx.add_root(~l, lit); 54 } 55 } 56 else if (is_ground(q->get_expr())) { 57 auto lit = ctx.internalize(q->get_expr(), l.sign(), false, false); 58 add_clause(~l, lit); 59 ctx.add_root(~l, lit); 60 } 61 else { 62 ctx.push_vec(m_universal, l); 63 if (ctx.get_config().m_ematching) 64 m_ematch.add(q); 65 } 66 m_stats.m_num_quantifier_asserts++; 67 } 68 check()69 sat::check_result solver::check() { 70 if (ctx.get_config().m_ematching && m_ematch()) 71 return sat::check_result::CR_CONTINUE; 72 73 if (ctx.get_config().m_mbqi) { 74 switch (m_mbqi()) { 75 case l_true: return sat::check_result::CR_DONE; 76 case l_false: return sat::check_result::CR_CONTINUE; 77 case l_undef: break; 78 } 79 } 80 return sat::check_result::CR_GIVEUP; 81 } 82 display(std::ostream & out) const83 std::ostream& solver::display(std::ostream& out) const { 84 return m_ematch.display(out); 85 } 86 display_constraint(std::ostream & out,sat::ext_constraint_idx idx) const87 std::ostream& solver::display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const { 88 return m_ematch.display_constraint(out, idx); 89 } 90 collect_statistics(statistics & st) const91 void solver::collect_statistics(statistics& st) const { 92 st.update("q asserts", m_stats.m_num_quantifier_asserts); 93 m_mbqi.collect_statistics(st); 94 m_ematch.collect_statistics(st); 95 } 96 clone(euf::solver & ctx)97 euf::th_solver* solver::clone(euf::solver& ctx) { 98 family_id fid = ctx.get_manager().mk_family_id(symbol("quant")); 99 return alloc(solver, ctx, fid); 100 } 101 unit_propagate()102 bool solver::unit_propagate() { 103 return m_ematch.unit_propagate(); 104 } 105 mk_var(euf::enode * n)106 euf::theory_var solver::mk_var(euf::enode* n) { 107 auto v = euf::th_euf_solver::mk_var(n); 108 ctx.attach_th_var(n, this, v); 109 return v; 110 } 111 instantiate(quantifier * _q,bool negate,std::function<expr * (quantifier *,unsigned)> & mk_var)112 sat::literal solver::instantiate(quantifier* _q, bool negate, std::function<expr* (quantifier*, unsigned)>& mk_var) { 113 sat::literal sk; 114 expr_ref tmp(m); 115 quantifier_ref q(_q, m); 116 expr_ref_vector vars(m); 117 if (negate) { 118 q = m.mk_quantifier( 119 is_forall(q) ? quantifier_kind::exists_k : quantifier_kind::forall_k, 120 q->get_num_decls(), q->get_decl_sorts(), q->get_decl_names(), m.mk_not(q->get_expr()), 121 q->get_weight(), q->get_qid(), q->get_skid()); 122 } 123 quantifier* q_flat = flatten(q); 124 unsigned sz = q_flat->get_num_decls(); 125 vars.resize(sz, nullptr); 126 for (unsigned i = 0; i < sz; ++i) 127 vars[i] = mk_var(q_flat, i); 128 var_subst subst(m); 129 expr_ref body = subst(q_flat->get_expr(), vars); 130 rewrite(body); 131 return mk_literal(body); 132 } 133 skolemize(quantifier * q)134 sat::literal solver::skolemize(quantifier* q) { 135 std::function<expr* (quantifier*, unsigned)> mk_var = [&](quantifier* q, unsigned i) { 136 return m.mk_fresh_const(q->get_decl_name(i), q->get_decl_sort(i)); 137 }; 138 return instantiate(q, is_forall(q), mk_var); 139 } 140 141 /* 142 * Find initial values to instantiate quantifier with so to make it as hard as possible for solver 143 * to find values to free variables. 144 */ specialize(quantifier * q)145 sat::literal solver::specialize(quantifier* q) { 146 std::function<expr* (quantifier*, unsigned)> mk_var = [&](quantifier* q, unsigned i) { 147 return get_unit(q->get_decl_sort(i)); 148 }; 149 return instantiate(q, is_exists(q), mk_var); 150 } 151 init_search()152 void solver::init_search() { 153 m_mbqi.init_search(); 154 } 155 internalize(expr * e,bool sign,bool root,bool learned)156 sat::literal solver::internalize(expr* e, bool sign, bool root, bool learned) { 157 SASSERT(is_forall(e) || is_exists(e)); 158 sat::bool_var v = ctx.get_si().add_bool_var(e); 159 sat::literal lit = ctx.attach_lit(sat::literal(v, false), e); 160 mk_var(ctx.get_egraph().find(e)); 161 if (sign) 162 lit.neg(); 163 return lit; 164 } 165 finalize_model(model & mdl)166 void solver::finalize_model(model& mdl) { 167 m_mbqi.finalize_model(mdl); 168 } 169 flatten(quantifier * q)170 quantifier* solver::flatten(quantifier* q) { 171 quantifier* q_flat = nullptr; 172 if (!has_quantifiers(q->get_expr())) 173 return q; 174 if (m_flat.find(q, q_flat)) 175 return q_flat; 176 proof_ref pr(m); 177 expr_ref new_q(m); 178 if (is_forall(q)) { 179 pull_quant pull(m); 180 pull(q, new_q, pr); 181 SASSERT(is_well_sorted(m, new_q)); 182 } 183 else { 184 new_q = q; 185 } 186 q_flat = to_quantifier(new_q); 187 m.inc_ref(q_flat); 188 m.inc_ref(q); 189 m_flat.insert(q, q_flat); 190 ctx.push(insert_ref2_map<ast_manager, quantifier, quantifier>(m, m_flat, q, q_flat)); 191 return q_flat; 192 } 193 init_units()194 void solver::init_units() { 195 if (!m_unit_table.empty()) 196 return; 197 for (euf::enode* n : ctx.get_egraph().nodes()) { 198 if (!n->interpreted() && !m.is_uninterp(n->get_expr()->get_sort())) 199 continue; 200 expr* e = n->get_expr(); 201 sort* s = e->get_sort(); 202 if (m_unit_table.contains(s)) 203 continue; 204 m_unit_table.insert(s, e); 205 ctx.push(insert_map<obj_map<sort, expr*>, sort*>(m_unit_table, s)); 206 } 207 } 208 get_unit(sort * s)209 expr* solver::get_unit(sort* s) { 210 expr* u = nullptr; 211 if (m_unit_table.find(s, u)) 212 return u; 213 init_units(); 214 if (m_unit_table.find(s, u)) 215 return u; 216 model mdl(m); 217 expr* val = mdl.get_some_value(s); 218 m.inc_ref(val); 219 m.inc_ref(s); 220 ctx.push(insert_ref2_map<ast_manager, sort, expr>(m, m_unit_table, s, val)); 221 return val; 222 } 223 expand(quantifier * q)224 bool solver::expand(quantifier* q) { 225 expr_ref r(m); 226 proof_ref pr(m); 227 m_der(q, r, pr); 228 m_expanded.reset(); 229 if (r != q) { 230 ctx.get_rewriter()(r); 231 m_expanded.push_back(r); 232 return true; 233 } 234 if (is_forall(q)) 235 flatten_and(q->get_expr(), m_expanded); 236 else if (is_exists(q)) 237 flatten_or(q->get_expr(), m_expanded); 238 else 239 UNREACHABLE(); 240 241 if (m_expanded.size() == 1 && is_forall(q)) { 242 m_expanded.reset(); 243 flatten_or(q->get_expr(), m_expanded); 244 expr_ref split1(m), split2(m), e1(m), e2(m); 245 unsigned idx = 0; 246 for (unsigned i = m_expanded.size(); i-- > 0; ) { 247 expr* arg = m_expanded.get(i); 248 if (split(arg, split1, split2)) { 249 if (e1) 250 return false; 251 e1 = split1; 252 e2 = split2; 253 idx = i; 254 } 255 } 256 if (!e1) 257 return false; 258 259 m_expanded[idx] = e1; 260 e1 = mk_or(m_expanded); 261 m_expanded[idx] = e2; 262 e2 = mk_or(m_expanded); 263 m_expanded.reset(); 264 m_expanded.push_back(e1); 265 m_expanded.push_back(e2); 266 } 267 if (m_expanded.size() > 1) { 268 for (unsigned i = m_expanded.size(); i-- > 0; ) { 269 expr_ref tmp(m.update_quantifier(q, m_expanded.get(i)), m); 270 ctx.get_rewriter()(tmp); 271 m_expanded[i] = tmp; 272 } 273 return true; 274 } 275 return false; 276 } 277 split(expr * arg,expr_ref & e1,expr_ref & e2)278 bool solver::split(expr* arg, expr_ref& e1, expr_ref& e2) { 279 expr* x, * y, * z; 280 if (m.is_not(arg, x) && m.is_or(x, y, z) && is_literal(y) && is_literal(z)) { 281 e1 = mk_not(m, y); 282 e2 = mk_not(m, z); 283 return true; 284 } 285 if (m.is_iff(arg, x, y) && is_literal(x) && is_literal(y)) { 286 e1 = m.mk_implies(x, y); 287 e2 = m.mk_implies(y, x); 288 return true; 289 } 290 if (m.is_and(arg, x, y) && is_literal(x) && is_literal(y)) { 291 e1 = x; 292 e2 = y; 293 return true; 294 } 295 if (m.is_not(arg, z) && m.is_iff(z, x, y) && is_literal(x) && is_literal(y)) { 296 e1 = m.mk_or(x, y); 297 e2 = m.mk_or(mk_not(m, x), mk_not(m, y)); 298 return true; 299 } 300 return false; 301 } 302 is_literal(expr * arg)303 bool solver::is_literal(expr* arg) { 304 m.is_not(arg, arg); 305 return !m.is_and(arg) && !m.is_or(arg) && !m.is_iff(arg) && !m.is_implies(arg); 306 } 307 get_antecedents(sat::literal l,sat::ext_justification_idx idx,sat::literal_vector & r,bool probing)308 void solver::get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) { 309 m_ematch.get_antecedents(l, idx, r, probing); 310 } 311 312 } 313