1 /**
2 Copyright (c) 2017 Microsoft Corporation
3 
4 Module Name:
5 
6     solver_pool.cpp
7 
8 Abstract:
9 
10    Maintain a pool of solvers
11 
12 Author:
13 
14     Nikolaj Bjorner
15 
16 Notes:
17 
18 --*/
19 
20 #include "solver/solver_pool.h"
21 #include "solver/solver_na2as.h"
22 #include "ast/proofs/proof_utils.h"
23 #include "ast/ast_util.h"
24 
25 class pool_solver : public solver_na2as {
26     solver_pool&       m_pool;
27     app_ref            m_pred;
28     proof_ref          m_proof;
29     ref<solver>        m_base;
30     expr_ref_vector    m_assertions;
31     unsigned           m_head;
32     expr_ref_vector    m_flat;
33     bool               m_pushed;
34     bool               m_in_delayed_scope;
35     bool               m_dump_benchmarks;
36     double             m_dump_threshold;
37     unsigned           m_dump_counter;
38 
39 
is_virtual() const40     bool is_virtual() const { return !m.is_true(m_pred); }
41 public:
pool_solver(solver * b,solver_pool & pool,app_ref & pred)42     pool_solver(solver* b, solver_pool& pool, app_ref& pred):
43         solver_na2as(pred.get_manager()),
44         m_pool(pool),
45         m_pred(pred),
46         m_proof(m),
47         m_base(b),
48         m_assertions(m),
49         m_head(0),
50         m_flat(m),
51         m_pushed(false),
52         m_in_delayed_scope(false),
53         m_dump_benchmarks(false),
54         m_dump_threshold(5.0),
55         m_dump_counter(0) {
56         if (is_virtual()) {
57             solver_na2as::assert_expr_core2(m.mk_true(), pred);
58         }
59         updt_params(m_base->get_params());
60     }
61 
~pool_solver()62     ~pool_solver() override {
63         if (m_pushed) pop(get_scope_level());
64         if (is_virtual()) {
65             m_pred = m.mk_not(m_pred);
66             m_base->assert_expr(m_pred);
67         }
68     }
69 
base_solver()70     solver* base_solver() { return m_base.get(); }
71 
translate(ast_manager & m,params_ref const & p)72     solver* translate(ast_manager& m, params_ref const& p) override { UNREACHABLE(); return nullptr; }
updt_params(params_ref const & p)73     void updt_params(params_ref const& p) override {
74         solver::updt_params(p); m_base->updt_params(p);
75         m_dump_benchmarks = solver::get_params().get_bool("dump_benchmarks", false);
76         m_dump_threshold = solver::get_params().get_double("dump_threshold", 5.0);
77     }
push_params()78     void push_params() override {m_base->push_params();}
pop_params()79     void pop_params() override {m_base->pop_params();}
80 
collect_param_descrs(param_descrs & r)81     void collect_param_descrs(param_descrs & r) override { m_base->collect_param_descrs(r); }
collect_statistics(statistics & st) const82     void collect_statistics(statistics & st) const override { m_base->collect_statistics(st); }
get_num_assertions() const83     unsigned get_num_assertions() const override { return m_base->get_num_assertions(); }
get_assertion(unsigned idx) const84     expr * get_assertion(unsigned idx) const override { return m_base->get_assertion(idx); }
85 
get_unsat_core(expr_ref_vector & r)86     void get_unsat_core(expr_ref_vector& r) override {
87         m_base->get_unsat_core(r);
88         unsigned j = 0;
89         for (unsigned i = 0; i < r.size(); ++i)
90             if (m_pred != r.get(i))
91                 r[j++] = r.get(i);
92         r.shrink(j);
93     }
94 
get_num_assumptions() const95     unsigned get_num_assumptions() const override {
96         unsigned sz = solver_na2as::get_num_assumptions();
97         return is_virtual() ? sz - 1 : sz;
98     }
99 
100 
get_proof()101     proof * get_proof() override {
102         scoped_watch _t_(m_pool.m_proof_watch);
103         if (!m_proof.get()) {
104             m_proof = m_base->get_proof();
105             if (m_proof) {
106                 elim_aux_assertions pc(m_pred);
107                 pc(m, m_proof, m_proof);
108             }
109         }
110         return m_proof;
111     }
112 
internalize_assertions()113     void internalize_assertions() {
114         SASSERT(!m_pushed || m_head == m_assertions.size());
115         for (unsigned sz = m_assertions.size(); m_head < sz; ++m_head) {
116             expr_ref f(m);
117             f = m.mk_implies(m_pred, (m_assertions.get(m_head)));
118             m_base->assert_expr(f);
119         }
120     }
121 
get_levels(ptr_vector<expr> const & vars,unsigned_vector & depth)122     void get_levels(ptr_vector<expr> const& vars, unsigned_vector& depth) override {
123         m_base->get_levels(vars, depth);
124     }
125 
get_trail()126     expr_ref_vector get_trail() override {
127         return m_base->get_trail();
128     }
129 
check_sat_core2(unsigned num_assumptions,expr * const * assumptions)130     lbool check_sat_core2(unsigned num_assumptions, expr * const * assumptions) override {
131         SASSERT(!m_pushed || get_scope_level() > 0);
132         m_proof.reset();
133         scoped_watch _t_(m_pool.m_check_watch);
134         m_pool.m_stats.m_num_checks++;
135 
136         stopwatch sw;
137         sw.start();
138         internalize_assertions();
139         lbool res = m_base->check_sat(num_assumptions, assumptions);
140         sw.stop();
141         switch (res) {
142         case l_true:
143             m_pool.m_check_sat_watch.add(sw);
144             m_pool.m_stats.m_num_sat_checks++;
145             break;
146         case l_undef:
147             m_pool.m_check_undef_watch.add(sw);
148             m_pool.m_stats.m_num_undef_checks++;
149             break;
150         default:
151             break;
152         }
153         set_status(res);
154 
155         if (m_dump_benchmarks && sw.get_seconds() >= m_dump_threshold) {
156             expr_ref_vector cube(m, num_assumptions, assumptions);
157             vector<expr_ref_vector> clauses;
158             dump_benchmark(cube, clauses, res, sw.get_seconds());
159         }
160         return res;
161     }
162 
check_sat_cc_core(expr_ref_vector const & cube,vector<expr_ref_vector> const & clauses)163     lbool check_sat_cc_core(expr_ref_vector const & cube,
164                             vector<expr_ref_vector> const & clauses) override {
165         SASSERT(!m_pushed || get_scope_level() > 0);
166         m_proof.reset();
167         scoped_watch _t_(m_pool.m_check_watch);
168         m_pool.m_stats.m_num_checks++;
169 
170         stopwatch sw;
171         sw.start();
172         internalize_assertions();
173         lbool res = m_base->check_sat_cc(cube, clauses);
174         sw.stop();
175         switch (res) {
176         case l_true:
177             m_pool.m_check_sat_watch.add(sw);
178             m_pool.m_stats.m_num_sat_checks++;
179             break;
180         case l_undef:
181             m_pool.m_check_undef_watch.add(sw);
182             m_pool.m_stats.m_num_undef_checks++;
183             break;
184         default:
185             break;
186         }
187         set_status(res);
188 
189         if (m_dump_benchmarks && sw.get_seconds() >= m_dump_threshold) {
190             dump_benchmark(cube, clauses, res, sw.get_seconds());
191         }
192         return res;
193     }
194 
push_core()195     void push_core() override {
196         SASSERT(!m_pushed || get_scope_level() > 0);
197         if (m_in_delayed_scope) {
198             // second push
199             internalize_assertions();
200             m_base->push();
201             m_pushed = true;
202             m_in_delayed_scope = false;
203         }
204 
205         if (!m_pushed) {
206             m_in_delayed_scope = true;
207         }
208         else {
209             SASSERT(!m_in_delayed_scope);
210             m_base->push();
211         }
212     }
213 
pop_core(unsigned n)214     void pop_core(unsigned n) override {
215         unsigned lvl = get_scope_level();
216         SASSERT(!m_pushed || lvl > 0);
217         if (m_pushed) {
218             SASSERT(!m_in_delayed_scope);
219             m_base->pop(n);
220             m_pushed = lvl - n > 0;
221         }
222         else {
223             m_in_delayed_scope = lvl - n > 0;
224         }
225     }
226 
assert_expr_core(expr * e)227     void assert_expr_core(expr * e) override {
228         SASSERT(!m_pushed || get_scope_level() > 0);
229         if (m.is_true(e)) return;
230         if (m_in_delayed_scope) {
231             internalize_assertions();
232             m_base->push();
233             m_pushed = true;
234             m_in_delayed_scope = false;
235         }
236 
237         if (m_pushed) {
238             m_base->assert_expr(e);
239         }
240         else {
241             m_flat.push_back(e);
242             flatten_and(m_flat);
243             m_assertions.append(m_flat);
244             m_flat.reset();
245         }
246     }
247 
get_model_core(model_ref & _m)248     void get_model_core(model_ref & _m) override { m_base->get_model_core(_m); }
249 
get_assumption(unsigned idx) const250     expr * get_assumption(unsigned idx) const override {
251         return solver_na2as::get_assumption(idx + is_virtual());
252     }
253 
reason_unknown() const254     std::string reason_unknown() const override { return m_base->reason_unknown(); }
set_reason_unknown(char const * msg)255     void set_reason_unknown(char const* msg) override { return m_base->set_reason_unknown(msg); }
get_labels(svector<symbol> & r)256     void get_labels(svector<symbol> & r) override { return m_base->get_labels(r); }
set_progress_callback(progress_callback * callback)257     void set_progress_callback(progress_callback * callback) override { m_base->set_progress_callback(callback); }
258 
cube(expr_ref_vector & vars,unsigned)259     expr_ref_vector cube(expr_ref_vector& vars, unsigned ) override { return expr_ref_vector(m); }
260 
get_manager() const261     ast_manager& get_manager() const override { return m_base->get_manager(); }
262 
refresh(solver * new_base)263     void refresh(solver* new_base) {
264         SASSERT(!m_pushed);
265         m_head = 0;
266         m_base = new_base;
267     }
268 
reset()269     void reset() {
270         SASSERT(!m_pushed);
271         m_head = 0;
272         m_assertions.reset();
273         m_pool.refresh(m_base.get());
274     }
275 
276 private:
277 
dump_benchmark(const expr_ref_vector & cube,vector<expr_ref_vector> const & clauses,lbool last_status,double last_time)278     void dump_benchmark(const expr_ref_vector &cube, vector<expr_ref_vector> const & clauses,
279                         lbool last_status, double last_time) {
280         std::string file_name = mk_file_name();
281         std::ofstream out(file_name);
282         STRACE("spacer.ind_gen", tout << "Dumping benchmark to " << file_name << "\n";);
283         if (!out) {
284             IF_VERBOSE(0, verbose_stream() << "could not open file " << file_name << " for output\n");
285             return;
286         }
287 
288         out << "(set-info :status " << lbool2status(last_status) << ")\n";
289         m_base->display(out, cube.size(), cube.c_ptr());
290         for (auto const& clause : clauses) {
291             out << ";; extra clause\n";
292             out << "(assert (or ";
293             for (auto *lit : clause) out << mk_pp(lit, m) << " ";
294             out << "))\n";
295         }
296 
297         out << "(check-sat";
298         for (auto * lit : cube) out << " " << mk_pp(lit, m) << "\n";
299         out << ")\n";
300 
301         out << "(exit)\n";
302         ::statistics st;
303         m_base->collect_statistics(st);
304         st.update("time", last_time);
305         st.display_smt2(out);
306         m_base->get_params().display(out);
307         out.close();
308     }
309 
lbool2status(lbool r) const310     char const* lbool2status(lbool r) const {
311         switch (r) {
312         case l_true:  return "sat";
313         case l_false: return "unsat";
314         case l_undef: return "unknown";
315         }
316         return "?";
317     }
318 
mk_file_name()319     std::string mk_file_name() {
320         std::stringstream file_name;
321         file_name << "pool_solver";
322         if (is_virtual()) file_name << "_" << m_pred->get_decl()->get_name();
323         file_name << "_" << (m_dump_counter++) << ".smt2";
324         return file_name.str();
325     }
326 
327 };
328 
solver_pool(solver * base_solver,unsigned num_pools)329 solver_pool::solver_pool(solver* base_solver, unsigned num_pools):
330     m_base_solver(base_solver),
331     m_num_pools(num_pools),
332     m_current_pool(0)
333 {
334     SASSERT(num_pools > 0);
335 }
336 
337 
get_base_solvers() const338 ptr_vector<solver> solver_pool::get_base_solvers() const {
339     ptr_vector<solver> solvers;
340     for (solver* s0 : m_solvers) {
341         pool_solver* s = dynamic_cast<pool_solver*>(s0);
342         if (!solvers.contains(s->base_solver())) {
343             solvers.push_back(s->base_solver());
344         }
345     }
346     return solvers;
347 }
348 
updt_params(const params_ref & p)349 void solver_pool::updt_params(const params_ref &p) {
350     m_base_solver->updt_params(p);
351     for (solver *s : m_solvers) s->updt_params(p);
352 }
collect_statistics(statistics & st) const353 void solver_pool::collect_statistics(statistics &st) const {
354     ptr_vector<solver> solvers = get_base_solvers();
355     for (solver* s : solvers) s->collect_statistics(st);
356     st.update("time.pool_solver.smt.total", m_check_watch.get_seconds());
357     st.update("time.pool_solver.smt.total.sat", m_check_sat_watch.get_seconds());
358     st.update("time.pool_solver.smt.total.undef", m_check_undef_watch.get_seconds());
359     st.update("time.pool_solver.proof", m_proof_watch.get_seconds());
360     st.update("pool_solver.checks", m_stats.m_num_checks);
361     st.update("pool_solver.checks.sat", m_stats.m_num_sat_checks);
362     st.update("pool_solver.checks.undef", m_stats.m_num_undef_checks);
363 }
364 
reset_statistics()365 void solver_pool::reset_statistics() {
366 #if 0
367     ptr_vector<solver> solvers = get_base_solvers();
368     for (solver* s : solvers) {
369         s->reset_statistics();
370     }
371 #endif
372     m_stats.reset();
373     m_check_sat_watch.reset();
374     m_check_undef_watch.reset();
375     m_check_watch.reset();
376     m_proof_watch.reset();
377 }
378 
379 /**
380    \brief Create a fresh solver instance.
381    The first num_pools solvers are independent and
382    use a fresh instance of the base solver.
383    Subsequent solvers reuse the first num_polls base solvers, rotating
384    among the first num_pools.
385 */
mk_solver()386 solver* solver_pool::mk_solver() {
387     ref<solver> base_solver;
388     ast_manager& m = m_base_solver->get_manager();
389     if (m_solvers.size() < m_num_pools) {
390         base_solver = m_base_solver->translate(m, m_base_solver->get_params());
391     }
392     else {
393         solver* s = m_solvers[(m_current_pool++) % m_num_pools];
394         base_solver = dynamic_cast<pool_solver*>(s)->base_solver();
395     }
396     std::stringstream name;
397     name << "vsolver#" << m_solvers.size();
398     app_ref pred(m.mk_const(symbol(name.str()), m.mk_bool_sort()), m);
399     pool_solver* solver = alloc(pool_solver, base_solver.get(), *this, pred);
400     m_solvers.push_back(solver);
401     return solver;
402 }
403 
reset_solver(solver * s)404 void solver_pool::reset_solver(solver* s) {
405     pool_solver* ps = dynamic_cast<pool_solver*>(s);
406     SASSERT(ps);
407     if (ps) ps->reset();
408 }
409 
refresh(solver * base_solver)410 void solver_pool::refresh(solver* base_solver) {
411     ast_manager& m = m_base_solver->get_manager();
412     ref<solver> new_base = m_base_solver->translate(m, m_base_solver->get_params());
413     for (solver* s0 : m_solvers) {
414         pool_solver* s = dynamic_cast<pool_solver*>(s0);
415         if (base_solver == s->base_solver()) {
416             s->refresh(new_base.get());
417         }
418     }
419 }
420