1 /*++
2 Copyright (c) 2012 Microsoft Corporation
3 
4 Module Name:
5 
6     smt_implied_equalities.cpp
7 
8 Abstract:
9 
10     Procedure for obtaining implied equalities relative to the
11     state of a solver.
12 
13 Author:
14 
15     Nikolaj Bjorner (nbjorner) 2012-02-29
16 
17 Revision History:
18 
19 
20 --*/
21 
22 #include "smt/smt_implied_equalities.h"
23 #include "util/union_find.h"
24 #include "ast/ast_pp.h"
25 #include "ast/array_decl_plugin.h"
26 #include "util/uint_set.h"
27 #include "smt/smt_value_sort.h"
28 #include "model/model_smt2_pp.h"
29 #include "util/stopwatch.h"
30 #include "model/model.h"
31 #include "solver/solver.h"
32 
33 namespace {
34 
35     class get_implied_equalities_impl {
36 
        // Manager used to build and inspect every expression handled by this class.
        ast_manager&                       m;
        // Solver that answers the push / assert / check_sat / pop queries issued below.
        solver&                            m_solver;
        // Context object handed to m_uf when it is constructed.
        union_find_default_ctx             m_df;
        // Equivalence classes over term ids; operator() creates one variable per input term.
        union_find<union_find_default_ctx> m_uf;
        // Helper for recognizing array sorts and as-array terms.
        array_util                         m_array_util;
        // Timer around the pairwise solver checks of the "basic" procedures.
        stopwatch                          m_stats_timer;
        // Number of pairwise solver checks issued by the "basic" procedures.
        unsigned                           m_stats_calls;
        // Timer around the model-value grouping phase of the model-based procedure.
        stopwatch                          m_stats_val_eq_timer;
        // Static counterparts of the timers above, shared by all instances of this class.
        static stopwatch                   s_timer;
        static stopwatch                   s_stats_val_eq_timer;
47 
48         struct term_id {
49             expr_ref term;
50             unsigned id;
term_id__anoncb4f4c600111::get_implied_equalities_impl::term_id51             term_id(expr_ref t, unsigned id): term(t), id(id) {}
52         };
53 
54         typedef vector<term_id> term_ids;
55 
56         typedef obj_map<sort, term_ids> sort2term_ids; // partition of terms by sort.
57 
partition_terms(unsigned num_terms,expr * const * terms,sort2term_ids & termids)58         void partition_terms(unsigned num_terms, expr* const* terms, sort2term_ids& termids) {
59             for (unsigned i = 0; i < num_terms; ++i) {
60                 sort* s = m.get_sort(terms[i]);
61                 term_ids& vec = termids.insert_if_not_there(s, term_ids());
62                 vec.push_back(term_id(expr_ref(terms[i],m), i));
63             }
64         }
65 
66         /**
67            \brief Basic implied equalities method.
68            It performs a simple N^2 loop over all pairs of terms.
69 
70            n1, .., n_k,
71            t1, .., t_l
72         */
73 
get_implied_equalities_filter_basic(uint_set const & non_values,term_ids & terms)74         void get_implied_equalities_filter_basic(uint_set const& non_values, term_ids& terms) {
75             m_stats_timer.start();
76             uint_set root_indices;
77             for (unsigned j = 0; j < terms.size(); ++j) {
78                 if (terms[j].id == m_uf.find(terms[j].id)) {
79                     root_indices.insert(j);
80                 }
81             }
82             uint_set::iterator it = non_values.begin(), end = non_values.end();
83 
84             for (; it != end; ++it) {
85                 unsigned i = *it;
86                 expr* t = terms[i].term;
87                 uint_set::iterator it2 = root_indices.begin(), end2 = root_indices.end();
88                 bool found_root_value = false;
89                 for (; it2 != end2; ++it2) {
90                     unsigned j = *it2;
91                     if (j == i) continue;
92                     if (j < i && non_values.contains(j)) continue;
93                     if (found_root_value && !non_values.contains(j)) continue;
94                     expr* s = terms[j].term;
95                     SASSERT(m.get_sort(t) == m.get_sort(s));
96                     ++m_stats_calls;
97                     m_solver.push();
98                     m_solver.assert_expr(m.mk_not(m.mk_eq(s, t)));
99                     bool is_eq = l_false == m_solver.check_sat(0,nullptr);
100                     m_solver.pop(1);
101                     TRACE("get_implied_equalities", tout << mk_pp(t, m) << " = " << mk_pp(s, m) << " " << (is_eq?"eq":"unrelated") << "\n";);
102                     if (is_eq) {
103                         m_uf.merge(terms[i].id, terms[j].id);
104                         if (!non_values.contains(j)) {
105                             found_root_value = true;
106                         }
107                     }
108                 }
109             }
110             m_stats_timer.stop();
111         }
112 
get_implied_equalities_basic(term_ids & terms)113         void get_implied_equalities_basic(term_ids& terms) {
114             for (unsigned i = 0; i < terms.size(); ++i) {
115                 if (terms[i].id != m_uf.find(terms[i].id)) {
116                     continue;
117                 }
118                 expr* t = terms[i].term;
119                 for (unsigned j = 0; j < i; ++j) {
120                     expr* s = terms[j].term;
121                     SASSERT(m.get_sort(t) == m.get_sort(s));
122                     ++m_stats_calls;
123                     m_stats_timer.start();
124                     m_solver.push();
125                     m_solver.assert_expr(m.mk_not(m.mk_eq(s, t)));
126                     bool is_eq = l_false == m_solver.check_sat(0,nullptr);
127                     m_solver.pop(1);
128                     m_stats_timer.stop();
129                     TRACE("get_implied_equalities", tout << mk_pp(t, m) << " = " << mk_pp(s, m) << " " << (is_eq?"eq":"unrelated") << "\n";);
130                     if (is_eq) {
131                         m_uf.merge(terms[i].id, terms[j].id);
132                         break;
133                     }
134                 }
135             }
136         }
137 
138         /**
139            \brief Extract implied equalities for a collection of terms in the current context.
140 
141            The routine relies on model values being unique for equal terms.
142            So in particular, arrays that are equal should be canonized to the same value.
143            This is not the case for Z3's models of arrays.
144            Arrays are treated by extensionality: introduce a fresh index and compare
145            the select of the arrays.
146         */
get_implied_equalities_model_based(model_ref & model,term_ids & terms)147         void get_implied_equalities_model_based(model_ref& model, term_ids& terms) {
148 
149             SASSERT(!terms.empty());
150 
151             sort* srt = m.get_sort(terms[0].term);
152 
153             if (m_array_util.is_array(srt)) {
154 
155                 m_solver.push();
156                 unsigned arity = get_array_arity(srt);
157                 expr_ref_vector args(m);
158                 args.push_back(nullptr);
159                 for (unsigned i = 0; i < arity; ++i) {
160                     sort* srt_i = get_array_domain(srt, i);
161                     expr* idx = m.mk_fresh_const("index", srt_i);
162                     args.push_back(idx);
163                 }
164                 for (unsigned i = 0; i < terms.size(); ++i) {
165                     args[0] = terms[i].term;
166                     terms[i].term = m.mk_app(m_array_util.get_family_id(), OP_SELECT, 0, nullptr, args.size(), args.c_ptr());
167                 }
168                 assert_relevant(terms);
169                 VERIFY(m_solver.check_sat(0,nullptr) != l_false);
170                 model_ref model1;
171                 m_solver.get_model(model1);
172                 SASSERT(model1.get());
173                 get_implied_equalities_model_based(model1, terms);
174                 m_solver.pop(1);
175                 return;
176             }
177 
178             uint_set non_values;
179 
180             if (!smt::is_value_sort(m, srt)) {
181                 for (unsigned i = 0; i < terms.size(); ++i) {
182                     non_values.insert(i);
183                 }
184                 get_implied_equalities_filter_basic(non_values, terms);
185                 //get_implied_equalities_basic(terms);
186                 return;
187             }
188 
189             expr_ref_vector vals(m);
190             expr_ref vl(m), eq(m);
191             obj_map<expr, unsigned_vector>  vals_map;
192 
193             m_stats_val_eq_timer.start();
194             s_stats_val_eq_timer.start();
195 
196             params_ref p;
197             p.set_bool("produce_models", false);
198             m_solver.updt_params(p);
199 
200             for (unsigned i = 0; i < terms.size(); ++i) {
201                 expr* t = terms[i].term;
202                 vl = (*model)(t);
203                 TRACE("get_implied_equalities", tout << mk_pp(t, m) << " |-> " << mk_pp(vl, m) << "\n";);
204                 reduce_value(model, vl);
205                 if (!m.is_value(vl)) {
206                     TRACE("get_implied_equalities", tout << "Not a value: " << mk_pp(vl, m) << "\n";);
207                     non_values.insert(i);
208                     continue;
209                 }
210                 vals.push_back(vl);
211                 unsigned_vector& vec = vals_map.insert_if_not_there(vl, unsigned_vector());
212                 bool found = false;
213 
214                 for (unsigned j = 0; !found && j < vec.size(); ++j) {
215                     expr* s = terms[vec[j]].term;
216                     m_solver.push();
217                     m_solver.assert_expr(m.mk_not(m.mk_eq(t, s)));
218                     lbool is_sat = m_solver.check_sat(0,nullptr);
219                     m_solver.pop(1);
220                     TRACE("get_implied_equalities", tout << mk_pp(t, m) << " = " << mk_pp(s, m) << " " << is_sat << "\n";);
221                     if (is_sat == l_false) {
222                         found = true;
223                         m_uf.merge(terms[i].id, terms[vec[j]].id);
224                     }
225                 }
226                 if (!found) {
227                     vec.push_back(i);
228                 }
229             }
230             m_stats_val_eq_timer.stop();
231             s_stats_val_eq_timer.stop();
232             p.set_bool("produce_models", true);
233             m_solver.updt_params(p);
234 
235 
236             if (!non_values.empty()) {
237                 TRACE("get_implied_equalities", model_smt2_pp(tout, m, *model, 0););
238                 get_implied_equalities_filter_basic(non_values, terms);
239                 //get_implied_equalities_basic(terms);
240             }
241         }
242 
243 
get_implied_equalities_core(model_ref & model,term_ids & terms)244         void get_implied_equalities_core(model_ref& model, term_ids& terms) {
245             get_implied_equalities_model_based(model, terms);
246             //get_implied_equalities_basic(terms);
247         }
248 
249 
assert_relevant(unsigned num_terms,expr * const * terms)250         void assert_relevant(unsigned num_terms, expr* const* terms) {
251             for (unsigned i = 0; i < num_terms; ++i) {
252                 sort* srt = m.get_sort(terms[i]);
253                 if (!m_array_util.is_array(srt)) {
254                     m_solver.assert_expr(m.mk_app(m.mk_func_decl(symbol("Relevant!"), 1, &srt, m.mk_bool_sort()), terms[i]));
255                 }
256             }
257         }
258 
assert_relevant(term_ids & terms)259         void assert_relevant(term_ids& terms) {
260             for (unsigned i = 0; i < terms.size(); ++i) {
261                 expr* t = terms[i].term;
262                 sort* srt = m.get_sort(t);
263                 if (!m_array_util.is_array(srt)) {
264                     m_solver.assert_expr(m.mk_app(m.mk_func_decl(symbol("Relevant!"), 1, &srt, m.mk_bool_sort()), t));
265                 }
266             }
267         }
268 
reduce_value(model_ref & model,expr_ref & vl)269         void reduce_value(model_ref& model, expr_ref& vl) {
270             expr* c, *e1, *e2;
271             while (m.is_ite(vl, c, e1, e2)) {
272                 lbool r = reduce_cond(model, c);
273                 switch(r) {
274                 case l_true:
275                     vl = e1;
276                     break;
277                 case l_false:
278                     vl = e2;
279                     break;
280                 default:
281                     return;
282                 }
283             }
284         }
285 
reduce_cond(model_ref & model,expr * e)286         lbool reduce_cond(model_ref& model, expr* e) {
287             expr* e1 = nullptr, *e2 = nullptr;
288             if (m.is_eq(e, e1, e2) && m_array_util.is_as_array(e1) && m_array_util.is_as_array(e2)) {
289                 if (e1 == e2) {
290                     return l_true;
291                 }
292                 func_decl* f1 = m_array_util.get_as_array_func_decl(to_app(e1));
293                 func_decl* f2 = m_array_util.get_as_array_func_decl(to_app(e2));
294                 func_interp* fi1 = model->get_func_interp(f1);
295                 func_interp* fi2 = model->get_func_interp(f2);
296                 if (fi1 == fi2) {
297                     return l_true;
298                 }
299                 unsigned n1 = fi1->num_entries();
300                 for (unsigned i = 0; i < n1; ++i) {
301                     func_entry const* h1 = fi1->get_entry(i);
302                     for (unsigned j = 0; j < fi1->get_arity(); ++j) {
303                         if (!m.is_value(h1->get_arg(j))) {
304                             return l_undef;
305                         }
306                     }
307                     func_entry* h2 = fi2->get_entry(h1->get_args());
308                     if (h2 &&
309                         h1->get_result() != h2->get_result() &&
310                         m.is_value(h1->get_result()) &&
311                         m.is_value(h2->get_result())) {
312                         return l_false;
313                     }
314                 }
315             }
316             return l_undef;
317         }
318 
319     public:
320 
get_implied_equalities_impl(ast_manager & m,solver & s)321         get_implied_equalities_impl(ast_manager& m, solver& s) : m(m), m_solver(s), m_uf(m_df), m_array_util(m), m_stats_calls(0) {}
322 
operator ()(unsigned num_terms,expr * const * terms,unsigned * class_ids)323         lbool operator()(unsigned num_terms, expr* const* terms, unsigned* class_ids) {
324             params_ref p;
325             p.set_bool("produce_models", true);
326             m_solver.updt_params(p);
327             sort2term_ids termids;
328             stopwatch timer;
329             timer.start();
330             s_timer.start();
331 
332             for (unsigned i = 0; i < num_terms; ++i) {
333                 m_uf.mk_var();
334             }
335 
336             m_solver.push();
337             assert_relevant(num_terms, terms);
338             lbool is_sat = m_solver.check_sat(0,nullptr);
339 
340             if (is_sat != l_false) {
341                 model_ref model;
342                 m_solver.get_model(model);
343                 SASSERT(model.get());
344 
345                 partition_terms(num_terms, terms, termids);
346                 sort2term_ids::iterator it = termids.begin(), end = termids.end();
347                 for (; it != end; ++it) {
348                     term_ids& term_ids = it->m_value;
349                     get_implied_equalities_core(model, term_ids);
350                     for (unsigned i = 0; i < term_ids.size(); ++i) {
351                         class_ids[term_ids[i].id] = m_uf.find(term_ids[i].id);
352                     }
353                 }
354                 TRACE("get_implied_equalities",
355                       for (unsigned i = 0; i < num_terms; ++i) {
356                           tout << mk_pp(terms[i], m) << " |-> " << class_ids[i] << "\n";
357                       });
358             }
359             m_solver.pop(1);
360             timer.stop();
361             s_timer.stop();
362             IF_VERBOSE(1, verbose_stream()  << s_timer.get_seconds() << "\t" << num_terms << "\t"
363                        << timer.get_seconds()   << "\t" << m_stats_calls << "\t"
364                        << m_stats_timer.get_seconds() << "\t"
365                        << m_stats_val_eq_timer.get_seconds() << "\t"
366                        << s_stats_val_eq_timer.get_seconds() << "\n";);
367             return is_sat;
368         }
369     };
370 
    // Out-of-class definitions of the static timers declared in
    // get_implied_equalities_impl.
    stopwatch get_implied_equalities_impl::s_timer;
    stopwatch get_implied_equalities_impl::s_stats_val_eq_timer;
373 }
374 
375 namespace smt {
implied_equalities(ast_manager & m,solver & solver,unsigned num_terms,expr * const * terms,unsigned * class_ids)376     lbool implied_equalities(ast_manager& m, solver& solver, unsigned num_terms, expr* const* terms, unsigned* class_ids) {
377         get_implied_equalities_impl gi(m, solver);
378         return gi(num_terms, terms, class_ids);
379     }
380 }
381 
382 
383 
384 
385 
386 
387 
388 #if 0
    // maxsat class for internal purposes.
    // NOTE(review): placeholder only -- operator() ignores soft_cnstrs and
    // always answers l_undef.
    class maxsat {
        ast_manager& m;
        solver&      m_solver;
    public:
        maxsat(solver& s) : m(s.m()), m_solver(s) {}

        // Stub: no optimization is performed.
        lbool operator()(ptr_vector<expr>& soft_cnstrs) {
            return l_undef;
        }

    };
401 
    // Union-find over expressions: each expression is mapped to a dense index
    // (m_term2idx / m_idx2term) and classes are merged through m_uf.
    // NOTE(review): sketch code -- `map::obj_map_entry` is not declared in
    // this file, and var() never creates a union-find variable for a new index;
    // confirm both before enabling.
    class term_equivs {
        union_find_default_ctx             m_df;
        union_find<union_find_default_ctx> m_uf;
        obj_map<expr,unsigned>             m_term2idx;
        ptr_vector<expr>                   m_idx2term;

    public:
        term_equivs(): m_uf(m_df) {}

        // Put t and s into the same class.
        void merge(expr* t, expr* s) {
            m_uf.merge(var(t), var(s));
        }
    private:
        // Index of t; a term seen for the first time gets the next free index
        // and is appended to m_idx2term.
        unsigned var(expr* t) {
            map::obj_map_entry* e = m_term2idx.insert_if_not_there(t, m_idx2term.size());
            unsigned idx = e->get_data().m_value;
            if (idx == m_idx2term.size()) {
                m_idx2term.push_back(t);
            }
            return idx;
        }
    };
424 
425     /**
426        \brief class to find implied equalities.
427 
428        It implements the following half-naive algorithm.
429        The algorithm is half-naive because the terms being checked for equivalence class membership
430        are foreign and it is up to the theory integration whether pairs of interface equalities
431        are checked. The idea is that the model-based combination would avoid useless equality literals
432        in the core.
433        An alternative algorithm could use 'distinct' and an efficient solver for 'distinct'.
434 
435        Given terms t1, ..., tn, of the same type.
436        - assert f(t1) = 1, .., f(tn) = n.
437        - find MAX-SAT set A1, let the other literals be in B.
438        - find MAX-SAT set of B, put it in A2, etc.
439        - we now have MAX-SAT sets A1, A2, ... A_m.
440        - terms in each set A_i can be different, but cannot be different at the same time as elements in A_{i+1}.
441        - for i = m to 2 do:
442        -   Let A = A_i B = A_{i-1}
443        -   assert g(A) = 0, g(B) = 1
444        -   find MAX-SAT set C over this constraint.
445        -   For each element t from A\C
446        -           check if g(t) = 0 and g(B) = 1 is unsat
447        -           minimize core, if there is pair such that
448        -           g(t) = 0, g(b) = 1 is unsat, then equality is forced.
449     */
450 
    // Unfinished skeleton of the MAX-SAT based procedure.
    // NOTE(review): does not compile as written -- `epxr` is a typo for `expr`,
    // `m_bv` and `equivs` are not declared, term_equivs has no reset(),
    // initialize() declares `g` twice and is missing a `;` after the
    // assignment to `g`.
    class implied_equalities_finder {
        ast_manager& m;
        solver&      m_solver;
        term_equivs  m_find;
        expr_ref_vector m_refs;
        obj_map<expr,expr*> m_fs; // t_i -> f(t_i) = i
        obj_map<expr,epxr*> m_gs; // t_i -> g(t_i)

    public:
        implied_equalities_finder(solver& solver): m(solver.m()), m_solver(solver), m_refs(m) {}

        // Stub: always answers l_undef and leaves class_ids untouched.
        lbool operator()(unsigned num_terms, expr* const* terms, unsigned* class_ids) {
            m_find.reset();
            //
            return l_undef;
        }
    private:

        // For every term t_i: assert (f(t_i) = i) <=> proxy_i for a fresh Boolean
        // proxy_i, and record proxy_i in m_fs and g(t_i) in m_gs.
        void initialize(unsigned num_terms, expr* const* terms) {
            sort_ref bv(m);
            expr_ref eq(m), g(m), eq_proxy(m);
            symbol f("f"), g("g");
            // Bit-width for the bit-vector sort that holds the term indices.
            unsigned log_terms = 1, nt = num_terms;
            while (nt > 0) { log_terms++; nt /= 2; }

            bv = m_bv.mk_bv_sort(log_terms);
            for (unsigned i = 0; i < num_terms; ++i) {
                expr* t = terms[i];
                sort* s = m.get_sort(t);
                eq = m.mk_eq(m.mk_app(m.mk_func_decl(f, 1, &s, bv), t), m_bv.mk_numeral(rational(i), bv));
                eq_proxy = m.mk_fresh_const("f", m.mk_bool_sort());
                m_solver.assert_expr(m.mk_iff(eq, eq_proxy));
                g = m.mk_app(m.mk_func_decl(g, 1, &s, bv), t)
                m_fs.insert(t, eq_proxy);
                m_gs.insert(t, g);
            }
        }

        //
        // For each t in src, check if t can be different from all s in dst.
        // - if it can, then add t to dst.
        // - if it cannot, then record equivalence class.
        //
        // NOTE(review): not implemented; the body is empty.
        void merge_classes(expr_ref_vector& src, expr_ref_vector& dst, equivs& eqs) {

        }
    };
498 
499     lbool implied_equalities_core_based(
500         solver& solver,
501         unsigned num_terms, expr* const* terms,
502         unsigned* class_ids,
503         unsigned num_assumptions, expr * const * assumptions) {
504         implied_equalities_finder ief(solver);
505 
506         solver.push();
507         for (unsigned i = 0; i < num_assumptions; ++i) {
508             solver.assert_expr(assumptions[i]);
509         }
510         lbool is_sat = ief(num_terms, terms, class_ids);
511         solver.pop(1);
512 
513         return is_sat;
514     }
515 
516         /**
517            \brief Extract implied equalities for a collection of terms in the current context.
518 
519            The routine uses a partition refinement approach.
520            It assumes that all terms have the same sort.
521 
522            Initially, create the equalities E_1: t0 = t1, E_2: t1 = t2, ..., E_n: t_{n-1} = t_n
523 
524            Check if ! (E_1 & E_2 & ... & E_n) is satisfiable.
525 
526            if it is unsat, then all terms are equal.
527            Otherwise, partition the terms by the equalities that are true in the current model,
528            iterate.
529 
530 
531            This version does not attempt to be economical on how many equalities are introduced and the
532            size of the resulting clauses. The more advanced version of this approach re-uses
533            equalities from a previous iteration and also represents a binary tree of propositional variables
534            that cover multiple equalities. Eg.,
535 
536                  E_12 => E_1 & E_2,   E_34 => E_3 & E_4, ...
537 
538 
539         */
540 
        // NOTE(review): unfinished -- the non-l_false branch refers to `model`,
        // which is not declared in this function, evaluates the terms rather
        // than the equality literals, and performs no partitioning yet
        // (`tems2` is never filled).
        void get_implied_equalities_eq_based(term_ids& terms) {
            expr_ref_vector eqs(m);
            if (terms.size() == 1) {
                return;
            }
            m_solver.push();
            // One fresh literal per adjacent pair: E_i implies t_i = t_{i+1}.
            for (unsigned i = 0; i + 1 < terms.size(); ++i) {
                expr* eq = m.mk_eq(terms[i].term, terms[i+1].term);
                expr* eq_lit = m.mk_fresh_const("E", m.mk_bool_sort());
                eqs.push_back(eq_lit);
                m_solver.assert_expr(m.mk_implies(eq_lit, eq));
            }
            // Ask whether the literals can fail to hold all together.
            m_solver.assert_expr(m.mk_not(m.mk_and(eqs.size(), eqs.c_ptr())));
            lbool is_sat = m_solver.check_sat(0,0);
            switch(is_sat) {
            case l_false:
                // Merge the whole chain of adjacent terms into one class.
                for (unsigned i = 0; i + 1 < terms.size(); ++i) {
                    m_uf.merge(terms[i].id, terms[i+1].id);
                }
                break;
            default: {
                term_ids tems2;
                for (unsigned i = 0; i + 1 < terms.size(); ++i) {
                    expr_ref vl(m);
                    model->eval(terms[i].term, vl);
                    if (m.is_false(vl)) {

                    }
                }
                break;
            }
            }
            m_solver.pop(1);
        }
575 
576 
577 #endif
578 
579 
580 
581 
582 
583