1 /*++
2 Copyright (c) 2009 Microsoft Corporation
3 
4 Module Name:
5 
6     bit2cpp.cpp
7 
8 Abstract:
9 
10     Routines for simplifying bit2int expressions.
11     This propagates bv2int over arithmetical symbols as much as possible,
12     converting arithmetic operations into bit-vector operations.
13 
14 Author:
15 
16     Nikolaj Bjorner (nbjorner) 2009-08-28
17 
18 Revision History:
19 
20 --*/
21 
22 #include "ast/ast_pp.h"
23 #include "ast/ast_ll_pp.h"
24 #include "ast/for_each_ast.h"
25 #include "ast/rewriter/bit2int.h"
26 
bit2int(ast_manager & m)27 bit2int::bit2int(ast_manager & m) :
28     m(m), m_bv_util(m), m_rewriter(m), m_arith_util(m), m_cache(m, false), m_bit0(m) {
29     m_bit0     = m_bv_util.mk_numeral(0,1);
30 }
31 
operator ()(expr * n,expr_ref & result,proof_ref & p)32 void bit2int::operator()(expr * n, expr_ref & result, proof_ref& p) {
33     flush_cache();
34     expr_reduce emap(*this);
35     for_each_ast(emap, n);
36     result = get_cached(n);
37     if (m.proofs_enabled() && n != result.get()) {
38         // TBD: rough
39         p = m.mk_rewrite(n, result);
40     }
41     TRACE("bit2int",
42           tout << mk_pp(n, m) << "======>\n" << result << "\n";);
43 
44 }
45 
46 
get_b2i_size(expr * n)47 unsigned bit2int::get_b2i_size(expr* n) {
48     expr* arg = nullptr;
49     VERIFY(m_bv_util.is_bv2int(n, arg));
50     return m_bv_util.get_bv_size(arg);
51 }
52 
get_numeral_bits(numeral const & k)53 unsigned bit2int::get_numeral_bits(numeral const& k) {
54     numeral two(2);
55     numeral n(abs(k));
56     unsigned num_bits = 1;
57     n = div(n, two);
58     while (n.is_pos()) {
59         ++num_bits;
60         n = div(n, two);
61     }
62     return num_bits;
63 }
64 
align_size(expr * e,unsigned sz,expr_ref & result)65 void bit2int::align_size(expr* e, unsigned sz, expr_ref& result) {
66     unsigned sz1 = m_bv_util.get_bv_size(e);
67     SASSERT(sz1 <= sz);
68     result = m_rewriter.mk_zero_extend(sz - sz1, e);
69 }
70 
align_sizes(expr_ref & a,expr_ref & b)71 void bit2int::align_sizes(expr_ref& a, expr_ref& b) {
72     unsigned sz1 = m_bv_util.get_bv_size(a);
73     unsigned sz2 = m_bv_util.get_bv_size(b);
74     if (sz1 > sz2) {
75         b = m_rewriter.mk_zero_extend(sz1 - sz2, b);
76     }
77     else if (sz2 > sz1) {
78         a = m_rewriter.mk_zero_extend(sz2-sz1, a);
79     }
80 }
81 
extract_bv(expr * n,unsigned & sz,bool & sign,expr_ref & bv)82 bool bit2int::extract_bv(expr* n, unsigned& sz, bool& sign, expr_ref& bv) {
83     numeral k;
84     bool is_int;
85     expr* r = nullptr;
86     if (m_bv_util.is_bv2int(n, r)) {
87         bv = r;
88         sz = m_bv_util.get_bv_size(bv);
89         sign = false;
90         return true;
91     }
92     else if (m_arith_util.is_numeral(n, k, is_int) && is_int) {
93         sz = get_numeral_bits(k);
94         bv = m_bv_util.mk_numeral(k, m_bv_util.mk_sort(sz));
95         sign = k.is_neg();
96         return true;
97     }
98     else {
99         return false;
100     }
101 }
102 
103 
mk_add(expr * e1,expr * e2,expr_ref & result)104 bool bit2int::mk_add(expr* e1, expr* e2, expr_ref& result) {
105     unsigned sz1, sz2;
106     bool sign1, sign2;
107     expr_ref tmp1(m), tmp2(m), tmp3(m);
108 
109     if (extract_bv(e1, sz1, sign1, tmp1) && !sign1 &&
110         extract_bv(e2, sz2, sign2, tmp2) && !sign2) {
111         unsigned sz;
112         numeral k;
113         if (m_bv_util.is_numeral(tmp1, k, sz) && k.is_zero()) {
114             result = e2;
115             return true;
116         }
117         if (m_bv_util.is_numeral(tmp2, k, sz) && k.is_zero()) {
118             result = e1;
119             return true;
120         }
121         align_sizes(tmp1, tmp2);
122         tmp1 = m_rewriter.mk_zero_extend(1, tmp1);
123         tmp2 = m_rewriter.mk_zero_extend(1, tmp2);
124         SASSERT(m_bv_util.get_bv_size(tmp1) == m_bv_util.get_bv_size(tmp2));
125         tmp3 = m_rewriter.mk_bv_add(tmp1, tmp2);
126         result = m_rewriter.mk_bv2int(tmp3);
127         return true;
128     }
129     return false;
130 }
131 
mk_comp(eq_type ty,expr * e1,expr * e2,expr_ref & result)132 bool bit2int::mk_comp(eq_type ty, expr* e1, expr* e2, expr_ref& result) {
133     unsigned sz1, sz2;
134     bool sign1, sign2;
135     expr_ref tmp1(m), tmp2(m), tmp3(m);
136     if (extract_bv(e1, sz1, sign1, tmp1) && !sign1 &&
137         extract_bv(e2, sz2, sign2, tmp2) && !sign2) {
138         align_sizes(tmp1, tmp2);
139         SASSERT(m_bv_util.get_bv_size(tmp1) == m_bv_util.get_bv_size(tmp2));
140         switch(ty) {
141         case lt:
142             tmp3 = m_rewriter.mk_ule(tmp2, tmp1);
143             result = m.mk_not(tmp3);
144             break;
145         case le:
146             result = m_rewriter.mk_ule(tmp1, tmp2);
147             break;
148         case eq:
149             result = m.mk_eq(tmp1, tmp2);
150             break;
151         }
152         return true;
153     }
154     return false;
155 }
156 
mk_mul(expr * e1,expr * e2,expr_ref & result)157 bool bit2int::mk_mul(expr* e1, expr* e2, expr_ref& result) {
158     unsigned sz1, sz2;
159     bool sign1, sign2;
160     expr_ref tmp1(m), tmp2(m);
161     expr_ref tmp3(m);
162 
163     if (extract_bv(e1, sz1, sign1, tmp1) &&
164         extract_bv(e2, sz2, sign2, tmp2)) {
165         align_sizes(tmp1, tmp2);
166         tmp1 = m_rewriter.mk_zero_extend(m_bv_util.get_bv_size(tmp1), tmp1);
167         tmp2 = m_rewriter.mk_zero_extend(m_bv_util.get_bv_size(tmp2), tmp2);
168 
169         SASSERT(m_bv_util.get_bv_size(tmp1) == m_bv_util.get_bv_size(tmp2));
170         tmp3 = m_rewriter.mk_bv_mul(tmp1, tmp2);
171         result = m_rewriter.mk_bv2int(tmp3);
172         if (sign1 != sign2) {
173             result = m_arith_util.mk_uminus(result);
174         }
175         return true;
176     }
177     return false;
178 }
179 
is_bv_poly(expr * n,expr_ref & pos,expr_ref & neg)180 bool bit2int::is_bv_poly(expr* n, expr_ref& pos, expr_ref& neg) {
181     ptr_vector<expr> todo;
182     expr_ref tmp(m);
183     numeral k;
184     bool is_int;
185     todo.push_back(n);
186     neg = pos = m_rewriter.mk_bv2int(m_bit0);
187 
188     while (!todo.empty()) {
189         n = todo.back();
190         todo.pop_back();
191         expr* arg1 = nullptr, *arg2 = nullptr;
192         if (m_bv_util.is_bv2int(n)) {
193             VERIFY(mk_add(n, pos, pos));
194         }
195         else if (m_arith_util.is_numeral(n, k, is_int) && is_int) {
196             if (k.is_nonneg()) {
197                 VERIFY(mk_add(n, pos, pos));
198             }
199             else {
200                 tmp = m_arith_util.mk_numeral(-k, true);
201                 VERIFY(mk_add(tmp, neg, neg));
202             }
203         }
204         else if (m_arith_util.is_add(n)) {
205             for (expr* arg : *to_app(n)) {
206                 todo.push_back(arg);
207             }
208         }
209         else if (m_arith_util.is_mul(n, arg1, arg2) &&
210                  m_arith_util.is_numeral(arg1, k, is_int) && is_int && k.is_minus_one() &&
211                  m_bv_util.is_bv2int(arg2)) {
212             VERIFY(mk_add(arg2, neg, neg));
213         }
214         else if (m_arith_util.is_mul(n, arg1, arg2) &&
215                  m_arith_util.is_numeral(arg2, k, is_int) && is_int && k.is_minus_one() &&
216                  m_bv_util.is_bv2int(arg1)) {
217             VERIFY(mk_add(arg1, neg, neg));
218         }
219         else if (m_arith_util.is_uminus(n, arg1) &&
220                  m_bv_util.is_bv2int(arg1)) {
221             VERIFY(mk_add(arg1, neg, neg));
222         }
223         else {
224             TRACE("bit2int", tout << "Not a poly: " << mk_pp(n, m) << "\n";);
225             return false;
226         }
227     }
228     return true;
229 }
230 
visit(quantifier * q)231 void bit2int::visit(quantifier* q) {
232     expr_ref result(m);
233     result = get_cached(q->get_expr());
234     result = m.update_quantifier(q, result);
235     cache_result(q, result);
236 }
237 
visit(app * n)238 void bit2int::visit(app* n) {
239     func_decl* f = n->get_decl();
240     unsigned num_args = n->get_num_args();
241 
242     m_args.reset();
243     for (expr* arg : *n) {
244         m_args.push_back(get_cached(arg));
245     }
246 
247     expr* const* args = m_args.data();
248 
249     bool has_b2i =
250         m_arith_util.is_le(n) || m_arith_util.is_ge(n) || m_arith_util.is_gt(n) ||
251         m_arith_util.is_lt(n) || m.is_eq(n);
252     expr_ref result(m);
253     for (unsigned i = 0; !has_b2i && i < num_args; ++i) {
254         has_b2i = m_bv_util.is_bv2int(args[i]);
255     }
256     if (!has_b2i) {
257         result = m.mk_app(f, num_args, args);
258         cache_result(n, result);
259         return;
260     }
261     //
262     // bv2int(x) + bv2int(y) -> bv2int(pad(x) + pad(y))
263     // bv2int(x) + k         -> bv2int(pad(x) + pad(k))
264     // bv2int(x) * bv2int(y) -> bv2int(pad(x) * pad(y))
265     // bv2int(x) * k         -> sign(k)*bv2int(pad(x) * pad(k))
266     // bv2int(x) - bv2int(y) <= z -> bv2int(x) <= bv2int(y) + z
267     // bv2int(x) <= z - bv2int(y) -> bv2int(x) + bv2int(y) <= z
268     //
269 
270     expr* e1 = nullptr, *e2 = nullptr;
271     expr_ref tmp1(m), tmp2(m);
272     expr_ref tmp3(m);
273     expr_ref pos1(m), neg1(m);
274     expr_ref pos2(m), neg2(m);
275     expr_ref e2bv(m);
276     bool sign2;
277     numeral k;
278     unsigned sz2;
279 
280     if (num_args >= 2) {
281         e1 = args[0];
282         e2 = args[1];
283     }
284 
285     if (m_arith_util.is_add(n) && num_args >= 1) {
286         result = e1;
287         for (unsigned i = 1; i < num_args; ++i) {
288             e1 = result;
289             e2 = args[i];
290             if (!mk_add(e1, e2, result)) {
291                 result = m.mk_app(f, num_args, args);
292                 cache_result(n, result);
293                 return;
294             }
295         }
296         cache_result(n, result);
297     }
298     else if (m_arith_util.is_mul(n) && num_args >= 1) {
299         result = e1;
300         for (unsigned i = 1; i < num_args; ++i) {
301             e1 = result;
302             e2 = args[i];
303             if (!mk_mul(e1, e2, result)) {
304                 result = m.mk_app(f, num_args, args);
305                 cache_result(n, result);
306                 return;
307             }
308         }
309         cache_result(n, result);
310     }
311     else if (m.is_eq(n) &&
312              is_bv_poly(e1, pos1, neg1) &&
313              is_bv_poly(e2, pos2, neg2) &&
314              mk_add(pos1, neg2, tmp1) &&
315              mk_add(neg1, pos2, tmp2) &&
316              mk_comp(eq, tmp1, tmp2, result)) {
317         cache_result(n, result);
318     }
319     else if (m_arith_util.is_le(n) &&
320              is_bv_poly(e1, pos1, neg1) &&
321              is_bv_poly(e2, pos2, neg2) &&
322              mk_add(pos1, neg2, tmp1) &&
323              mk_add(neg1, pos2, tmp2) &&
324              mk_comp(le, tmp1, tmp2, result)) {
325         cache_result(n, result);
326     }
327     else if (m_arith_util.is_lt(n) &&
328              is_bv_poly(e1, pos1, neg1) &&
329              is_bv_poly(e2, pos2, neg2) &&
330              mk_add(pos1, neg2, tmp1) &&
331              mk_add(neg1, pos2, tmp2) &&
332              mk_comp(lt, tmp1, tmp2, result)) {
333         cache_result(n, result);
334     }
335     else if (m_arith_util.is_ge(n) &&
336              is_bv_poly(e1, pos1, neg1) &&
337              is_bv_poly(e2, pos2, neg2) &&
338              mk_add(pos1, neg2, tmp1) &&
339              mk_add(neg1, pos2, tmp2) &&
340              mk_comp(le, tmp2, tmp1, result)) {
341         cache_result(n, result);
342     }
343     else if (m_arith_util.is_gt(n) &&
344              is_bv_poly(e1, pos1, neg1) &&
345              is_bv_poly(e2, pos2, neg2) &&
346              mk_add(pos1, neg2, tmp1) &&
347              mk_add(neg1, pos2, tmp2) &&
348              mk_comp(lt, tmp2, tmp1, result)) {
349         cache_result(n, result);
350     }
351     else if (m_arith_util.is_mod(n) &&
352              is_bv_poly(e1, pos1, neg1) &&
353              extract_bv(e2, sz2, sign2, e2bv) && !sign2) {
354         //
355         // (pos1 - neg1) mod e2 = (pos1 + (e2 - (neg1 mod e2))) mod e2
356         //
357         unsigned sz_p, sz_n, sz;
358         bool sign_p, sign_n;
359         expr_ref tmp_p(m), tmp_n(m);
360         VERIFY(extract_bv(pos1, sz_p, sign_p, tmp_p));
361         VERIFY(extract_bv(neg1, sz_n, sign_n, tmp_n));
362         SASSERT(!sign_p && !sign_n);
363 
364         // pos1 mod e2
365         if (m_bv_util.is_numeral(tmp_n, k, sz) && k.is_zero()) {
366             tmp1 = tmp_p;
367             tmp2 = e2bv;
368             align_sizes(tmp1, tmp2);
369             tmp3 = m_rewriter.mk_bv_urem(tmp1, tmp2);
370             result = m_rewriter.mk_bv2int(tmp3);
371             cache_result(n, result);
372             return;
373         }
374 
375         // neg1 mod e2;
376         tmp1 = tmp_n;
377         tmp2 = e2bv;
378         align_sizes(tmp1, tmp2);
379         tmp3 = m_rewriter.mk_bv_urem(tmp1, tmp2);
380         // e2 - (neg1 mod e2)
381         tmp1 = e2bv;
382         tmp2 = tmp3;
383         align_sizes(tmp1, tmp2);
384         tmp3 = m_rewriter.mk_bv_sub(tmp1, tmp2);
385         // pos1 + (e2 - (neg1 mod e2))
386         tmp1 = tmp_p;
387         tmp2 = tmp3;
388         align_sizes(tmp1, tmp2);
389         tmp_p = m_rewriter.mk_zero_extend(1, tmp1);
390         tmp_n = m_rewriter.mk_zero_extend(1, tmp2);
391         tmp1 = m_rewriter.mk_bv_add(tmp_p, tmp_n);
392         // (pos1 + (e2 - (neg1 mod e2))) mod e2
393         tmp2 = e2bv;
394         align_sizes(tmp1, tmp2);
395         tmp3 = m_rewriter.mk_bv_urem(tmp1, tmp2);
396         result = m_rewriter.mk_bv2int(tmp3);
397 
398         cache_result(n, result);
399     }
400     else {
401         result = m.mk_app(f, num_args, args);
402         cache_result(n, result);
403     }
404 }
405 
get_cached(expr * n) const406 expr * bit2int::get_cached(expr * n) const {
407     expr* r = nullptr;
408     proof* p = nullptr;
409     const_cast<bit2int*>(this)->m_cache.get(n, r, p);
410     CTRACE("bit2int", !r, tout << mk_pp(n, m) << "\n";);
411     return r;
412 }
413 
cache_result(expr * n,expr * r)414 void bit2int::cache_result(expr * n, expr * r) {
415     TRACE("bit2int_verbose", tout << "caching:\n" << mk_pp(n, m) <<
416           "======>\n" << mk_ll_pp(r, m) << "\n";);
417     m_cache.insert(n, r, nullptr);
418 }
419