1 /*++
2 Copyright (c) 2019 Microsoft Corporation
3 
4 Module Name:
5 
6     hoist_rewriter.cpp
7 
8 Abstract:
9 
10     Hoist predicates over disjunctions
11 
12 Author:
13 
14     Nikolaj Bjorner (nbjorner) 2019-2-4
15 
16 Notes:
17 
18 --*/
19 
20 
21 #include "ast/rewriter/hoist_rewriter.h"
22 #include "ast/ast_util.h"
23 #include "ast/ast_pp.h"
24 #include "ast/ast_ll_pp.h"
25 
26 
hoist_rewriter(ast_manager & m,params_ref const & p)27 hoist_rewriter::hoist_rewriter(ast_manager & m, params_ref const & p):
28     m_manager(m), m_args1(m), m_args2(m), m_subst(m) {
29     updt_params(p);
30 }
31 
mk_or(unsigned num_args,expr * const * es,expr_ref & result)32 br_status hoist_rewriter::mk_or(unsigned num_args, expr * const * es, expr_ref & result) {
33     if (num_args < 2) {
34         return BR_FAILED;
35     }
36     for (unsigned i = 0; i < num_args; ++i) {
37         if (!is_and(es[i], nullptr)) {
38             return BR_FAILED;
39         }
40     }
41 
42     bool turn = false;
43     m_preds1.reset();
44     m_preds2.reset();
45     m_uf1.reset();
46     m_uf2.reset();
47     m_expr2var.reset();
48     m_var2expr.reset();
49     basic_union_find* uf[2] = { &m_uf1, &m_uf2 };
50     obj_hashtable<expr>* preds[2] = { &m_preds1, &m_preds2 };
51     expr_ref_vector* args[2] = { &m_args1, &m_args2 };
52     VERIFY(is_and(es[0], args[turn]));
53     expr* e1, *e2;
54     for (expr* e : *(args[turn])) {
55         if (m().is_eq(e, e1, e2)) {
56             (*uf)[turn].merge(mk_var(e1), mk_var(e2));
57         }
58         else {
59             (*preds)[turn].insert(e);
60         }
61     }
62     unsigned round = 0;
63     for (unsigned j = 1; j < num_args; ++j) {
64         ++round;
65         m_es.reset();
66         m_mark.reset();
67 
68         bool last = turn;
69         turn = !turn;
70         (*preds)[turn].reset();
71         reset(m_uf0);
72         VERIFY(is_and(es[j], args[turn]));
73 
74         for (expr* e : *args[turn]) {
75             if (m().is_eq(e, e1, e2)) {
76                 m_es.push_back(e1);
77                 m_uf0.merge(mk_var(e1), mk_var(e2));
78             }
79             else if ((*preds)[last].contains(e)) {
80                 (*preds)[turn].insert(e);
81             }
82         }
83 
84         if ((*preds)[turn].empty() && m_es.empty()) {
85             return BR_FAILED;
86         }
87 
88         m_eqs.reset();
89         for (expr* e : m_es) {
90             if (m_mark.is_marked(e)) {
91                 continue;
92             }
93             unsigned u = mk_var(e);
94             unsigned v = u;
95             m_roots.reset();
96             do {
97                 m_mark.mark(e);
98                 unsigned r = (*uf)[last].find(v);
99                 if (m_roots.find(r, e2)) {
100                     m_eqs.push_back(std::make_pair(e, e2));
101                 }
102                 else {
103                     m_roots.insert(r, e);
104                 }
105                 v = m_uf0.next(v);
106                 e = mk_expr(v);
107             }
108             while (u != v);
109         }
110         reset((*uf)[turn]);
111         for (auto const& p : m_eqs)
112             (*uf)[turn].merge(mk_var(p.first), mk_var(p.second));
113         if ((*preds)[turn].empty() && m_eqs.empty())
114             return BR_FAILED;
115     }
116     if (m_eqs.empty()) {
117         result = hoist_predicates((*preds)[turn], num_args, es);
118         return BR_DONE;
119     }
120     // p & eqs & (or fmls)
121     expr_ref_vector fmls(m());
122     m_subst.reset();
123     for (expr * p : (*preds)[turn]) {
124         expr* q = nullptr;
125         if (m().is_not(p, q)) {
126             m_subst.insert(q, m().mk_false());
127         }
128         else {
129             m_subst.insert(p, m().mk_true());
130         }
131         fmls.push_back(p);
132     }
133     for (auto& p : m_eqs) {
134         if (m().is_value(p.first))
135             std::swap(p.first, p.second);
136         m_subst.insert(p.first, p.second);
137         fmls.push_back(m().mk_eq(p.first, p.second));
138     }
139     expr_ref ors(::mk_or(m(), num_args, es), m());
140     m_subst(ors);
141     fmls.push_back(ors);
142     result = mk_and(fmls);
143     TRACE("hoist", tout << ors << " => " << result << "\n";);
144     return BR_DONE;
145 }
146 
mk_var(expr * e)147 unsigned hoist_rewriter::mk_var(expr* e) {
148     unsigned v = 0;
149     if (m_expr2var.find(e, v)) {
150         return v;
151     }
152     m_uf1.mk_var();
153     v = m_uf2.mk_var();
154     SASSERT(v == m_var2expr.size());
155     m_expr2var.insert(e, v);
156     m_var2expr.push_back(e);
157     return v;
158 }
159 
hoist_predicates(obj_hashtable<expr> const & preds,unsigned num_args,expr * const * es)160 expr_ref hoist_rewriter::hoist_predicates(obj_hashtable<expr> const& preds, unsigned num_args, expr* const* es) {
161     expr_ref result(m());
162     expr_ref_vector args(m()), fmls(m());
163     for (unsigned i = 0; i < num_args; ++i) {
164         VERIFY(is_and(es[i], &m_args1));
165         fmls.reset();
166         for (expr* e : m_args1) {
167             if (!preds.contains(e))
168                 fmls.push_back(e);
169         }
170         args.push_back(::mk_and(fmls));
171     }
172     fmls.reset();
173     fmls.push_back(::mk_or(args));
174     for (auto* p : preds)
175         fmls.push_back(p);
176     result = ::mk_and(fmls);
177     return result;
178 }
179 
180 
mk_app_core(func_decl * f,unsigned num_args,expr * const * args,expr_ref & result)181 br_status hoist_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) {
182     switch (f->get_decl_kind()) {
183     case OP_OR:
184         return mk_or(num_args, args, result);
185     default:
186         return BR_FAILED;
187     }
188 }
189 
is_and(expr * e,expr_ref_vector * args)190 bool hoist_rewriter::is_and(expr * e, expr_ref_vector* args) {
191     if (m().is_and(e)) {
192         if (args) {
193             args->reset();
194             args->append(to_app(e)->get_num_args(), to_app(e)->get_args());
195         }
196         return true;
197     }
198     if (m().is_not(e, e) && m().is_or(e)) {
199         if (args) {
200             args->reset();
201             for (expr* arg : *to_app(e)) {
202                 args->push_back(::mk_not(m(), arg));
203             }
204         }
205         return true;
206     }
207     return false;
208 }
209 
210 
reset(basic_union_find & uf)211 void hoist_rewriter::reset(basic_union_find& uf) {
212     uf.reset();
213     for (expr* e : m_var2expr) {
214         (void)e;
215         uf.mk_var();
216     }
217 }
218