1 /*++
2 Copyright (c) 2009 Microsoft Corporation
3
4 Module Name:
5
    bit2int.cpp
7
8 Abstract:
9
    Routines for simplifying bv2int expressions.
    This pass propagates bv2int over arithmetic operators as far as possible,
    converting arithmetic operations into bit-vector operations.
13
14 Author:
15
16 Nikolaj Bjorner (nbjorner) 2009-08-28
17
18 Revision History:
19
20 --*/
21
22 #include "ast/ast_pp.h"
23 #include "ast/ast_ll_pp.h"
24 #include "ast/for_each_ast.h"
25 #include "ast/rewriter/bit2int.h"
26
bit2int(ast_manager & m)27 bit2int::bit2int(ast_manager & m) :
28 m(m), m_bv_util(m), m_rewriter(m), m_arith_util(m), m_cache(m, false), m_bit0(m) {
29 m_bit0 = m_bv_util.mk_numeral(0,1);
30 }
31
operator ()(expr * n,expr_ref & result,proof_ref & p)32 void bit2int::operator()(expr * n, expr_ref & result, proof_ref& p) {
33 flush_cache();
34 expr_reduce emap(*this);
35 for_each_ast(emap, n);
36 result = get_cached(n);
37 if (m.proofs_enabled() && n != result.get()) {
38 // TBD: rough
39 p = m.mk_rewrite(n, result);
40 }
41 TRACE("bit2int",
42 tout << mk_pp(n, m) << "======>\n" << result << "\n";);
43
44 }
45
46
get_b2i_size(expr * n)47 unsigned bit2int::get_b2i_size(expr* n) {
48 expr* arg = nullptr;
49 VERIFY(m_bv_util.is_bv2int(n, arg));
50 return m_bv_util.get_bv_size(arg);
51 }
52
get_numeral_bits(numeral const & k)53 unsigned bit2int::get_numeral_bits(numeral const& k) {
54 numeral two(2);
55 numeral n(abs(k));
56 unsigned num_bits = 1;
57 n = div(n, two);
58 while (n.is_pos()) {
59 ++num_bits;
60 n = div(n, two);
61 }
62 return num_bits;
63 }
64
align_size(expr * e,unsigned sz,expr_ref & result)65 void bit2int::align_size(expr* e, unsigned sz, expr_ref& result) {
66 unsigned sz1 = m_bv_util.get_bv_size(e);
67 SASSERT(sz1 <= sz);
68 result = m_rewriter.mk_zero_extend(sz - sz1, e);
69 }
70
align_sizes(expr_ref & a,expr_ref & b)71 void bit2int::align_sizes(expr_ref& a, expr_ref& b) {
72 unsigned sz1 = m_bv_util.get_bv_size(a);
73 unsigned sz2 = m_bv_util.get_bv_size(b);
74 if (sz1 > sz2) {
75 b = m_rewriter.mk_zero_extend(sz1 - sz2, b);
76 }
77 else if (sz2 > sz1) {
78 a = m_rewriter.mk_zero_extend(sz2-sz1, a);
79 }
80 }
81
extract_bv(expr * n,unsigned & sz,bool & sign,expr_ref & bv)82 bool bit2int::extract_bv(expr* n, unsigned& sz, bool& sign, expr_ref& bv) {
83 numeral k;
84 bool is_int;
85 expr* r = nullptr;
86 if (m_bv_util.is_bv2int(n, r)) {
87 bv = r;
88 sz = m_bv_util.get_bv_size(bv);
89 sign = false;
90 return true;
91 }
92 else if (m_arith_util.is_numeral(n, k, is_int) && is_int) {
93 sz = get_numeral_bits(k);
94 bv = m_bv_util.mk_numeral(k, m_bv_util.mk_sort(sz));
95 sign = k.is_neg();
96 return true;
97 }
98 else {
99 return false;
100 }
101 }
102
103
mk_add(expr * e1,expr * e2,expr_ref & result)104 bool bit2int::mk_add(expr* e1, expr* e2, expr_ref& result) {
105 unsigned sz1, sz2;
106 bool sign1, sign2;
107 expr_ref tmp1(m), tmp2(m), tmp3(m);
108
109 if (extract_bv(e1, sz1, sign1, tmp1) && !sign1 &&
110 extract_bv(e2, sz2, sign2, tmp2) && !sign2) {
111 unsigned sz;
112 numeral k;
113 if (m_bv_util.is_numeral(tmp1, k, sz) && k.is_zero()) {
114 result = e2;
115 return true;
116 }
117 if (m_bv_util.is_numeral(tmp2, k, sz) && k.is_zero()) {
118 result = e1;
119 return true;
120 }
121 align_sizes(tmp1, tmp2);
122 tmp1 = m_rewriter.mk_zero_extend(1, tmp1);
123 tmp2 = m_rewriter.mk_zero_extend(1, tmp2);
124 SASSERT(m_bv_util.get_bv_size(tmp1) == m_bv_util.get_bv_size(tmp2));
125 tmp3 = m_rewriter.mk_bv_add(tmp1, tmp2);
126 result = m_rewriter.mk_bv2int(tmp3);
127 return true;
128 }
129 return false;
130 }
131
mk_comp(eq_type ty,expr * e1,expr * e2,expr_ref & result)132 bool bit2int::mk_comp(eq_type ty, expr* e1, expr* e2, expr_ref& result) {
133 unsigned sz1, sz2;
134 bool sign1, sign2;
135 expr_ref tmp1(m), tmp2(m), tmp3(m);
136 if (extract_bv(e1, sz1, sign1, tmp1) && !sign1 &&
137 extract_bv(e2, sz2, sign2, tmp2) && !sign2) {
138 align_sizes(tmp1, tmp2);
139 SASSERT(m_bv_util.get_bv_size(tmp1) == m_bv_util.get_bv_size(tmp2));
140 switch(ty) {
141 case lt:
142 tmp3 = m_rewriter.mk_ule(tmp2, tmp1);
143 result = m.mk_not(tmp3);
144 break;
145 case le:
146 result = m_rewriter.mk_ule(tmp1, tmp2);
147 break;
148 case eq:
149 result = m.mk_eq(tmp1, tmp2);
150 break;
151 }
152 return true;
153 }
154 return false;
155 }
156
mk_mul(expr * e1,expr * e2,expr_ref & result)157 bool bit2int::mk_mul(expr* e1, expr* e2, expr_ref& result) {
158 unsigned sz1, sz2;
159 bool sign1, sign2;
160 expr_ref tmp1(m), tmp2(m);
161 expr_ref tmp3(m);
162
163 if (extract_bv(e1, sz1, sign1, tmp1) &&
164 extract_bv(e2, sz2, sign2, tmp2)) {
165 align_sizes(tmp1, tmp2);
166 tmp1 = m_rewriter.mk_zero_extend(m_bv_util.get_bv_size(tmp1), tmp1);
167 tmp2 = m_rewriter.mk_zero_extend(m_bv_util.get_bv_size(tmp2), tmp2);
168
169 SASSERT(m_bv_util.get_bv_size(tmp1) == m_bv_util.get_bv_size(tmp2));
170 tmp3 = m_rewriter.mk_bv_mul(tmp1, tmp2);
171 result = m_rewriter.mk_bv2int(tmp3);
172 if (sign1 != sign2) {
173 result = m_arith_util.mk_uminus(result);
174 }
175 return true;
176 }
177 return false;
178 }
179
is_bv_poly(expr * n,expr_ref & pos,expr_ref & neg)180 bool bit2int::is_bv_poly(expr* n, expr_ref& pos, expr_ref& neg) {
181 ptr_vector<expr> todo;
182 expr_ref tmp(m);
183 numeral k;
184 bool is_int;
185 todo.push_back(n);
186 neg = pos = m_rewriter.mk_bv2int(m_bit0);
187
188 while (!todo.empty()) {
189 n = todo.back();
190 todo.pop_back();
191 expr* arg1 = nullptr, *arg2 = nullptr;
192 if (m_bv_util.is_bv2int(n)) {
193 VERIFY(mk_add(n, pos, pos));
194 }
195 else if (m_arith_util.is_numeral(n, k, is_int) && is_int) {
196 if (k.is_nonneg()) {
197 VERIFY(mk_add(n, pos, pos));
198 }
199 else {
200 tmp = m_arith_util.mk_numeral(-k, true);
201 VERIFY(mk_add(tmp, neg, neg));
202 }
203 }
204 else if (m_arith_util.is_add(n)) {
205 for (expr* arg : *to_app(n)) {
206 todo.push_back(arg);
207 }
208 }
209 else if (m_arith_util.is_mul(n, arg1, arg2) &&
210 m_arith_util.is_numeral(arg1, k, is_int) && is_int && k.is_minus_one() &&
211 m_bv_util.is_bv2int(arg2)) {
212 VERIFY(mk_add(arg2, neg, neg));
213 }
214 else if (m_arith_util.is_mul(n, arg1, arg2) &&
215 m_arith_util.is_numeral(arg2, k, is_int) && is_int && k.is_minus_one() &&
216 m_bv_util.is_bv2int(arg1)) {
217 VERIFY(mk_add(arg1, neg, neg));
218 }
219 else if (m_arith_util.is_uminus(n, arg1) &&
220 m_bv_util.is_bv2int(arg1)) {
221 VERIFY(mk_add(arg1, neg, neg));
222 }
223 else {
224 TRACE("bit2int", tout << "Not a poly: " << mk_pp(n, m) << "\n";);
225 return false;
226 }
227 }
228 return true;
229 }
230
visit(quantifier * q)231 void bit2int::visit(quantifier* q) {
232 expr_ref result(m);
233 result = get_cached(q->get_expr());
234 result = m.update_quantifier(q, result);
235 cache_result(q, result);
236 }
237
visit(app * n)238 void bit2int::visit(app* n) {
239 func_decl* f = n->get_decl();
240 unsigned num_args = n->get_num_args();
241
242 m_args.reset();
243 for (expr* arg : *n) {
244 m_args.push_back(get_cached(arg));
245 }
246
247 expr* const* args = m_args.data();
248
249 bool has_b2i =
250 m_arith_util.is_le(n) || m_arith_util.is_ge(n) || m_arith_util.is_gt(n) ||
251 m_arith_util.is_lt(n) || m.is_eq(n);
252 expr_ref result(m);
253 for (unsigned i = 0; !has_b2i && i < num_args; ++i) {
254 has_b2i = m_bv_util.is_bv2int(args[i]);
255 }
256 if (!has_b2i) {
257 result = m.mk_app(f, num_args, args);
258 cache_result(n, result);
259 return;
260 }
261 //
262 // bv2int(x) + bv2int(y) -> bv2int(pad(x) + pad(y))
263 // bv2int(x) + k -> bv2int(pad(x) + pad(k))
264 // bv2int(x) * bv2int(y) -> bv2int(pad(x) * pad(y))
265 // bv2int(x) * k -> sign(k)*bv2int(pad(x) * pad(k))
266 // bv2int(x) - bv2int(y) <= z -> bv2int(x) <= bv2int(y) + z
267 // bv2int(x) <= z - bv2int(y) -> bv2int(x) + bv2int(y) <= z
268 //
269
270 expr* e1 = nullptr, *e2 = nullptr;
271 expr_ref tmp1(m), tmp2(m);
272 expr_ref tmp3(m);
273 expr_ref pos1(m), neg1(m);
274 expr_ref pos2(m), neg2(m);
275 expr_ref e2bv(m);
276 bool sign2;
277 numeral k;
278 unsigned sz2;
279
280 if (num_args >= 2) {
281 e1 = args[0];
282 e2 = args[1];
283 }
284
285 if (m_arith_util.is_add(n) && num_args >= 1) {
286 result = e1;
287 for (unsigned i = 1; i < num_args; ++i) {
288 e1 = result;
289 e2 = args[i];
290 if (!mk_add(e1, e2, result)) {
291 result = m.mk_app(f, num_args, args);
292 cache_result(n, result);
293 return;
294 }
295 }
296 cache_result(n, result);
297 }
298 else if (m_arith_util.is_mul(n) && num_args >= 1) {
299 result = e1;
300 for (unsigned i = 1; i < num_args; ++i) {
301 e1 = result;
302 e2 = args[i];
303 if (!mk_mul(e1, e2, result)) {
304 result = m.mk_app(f, num_args, args);
305 cache_result(n, result);
306 return;
307 }
308 }
309 cache_result(n, result);
310 }
311 else if (m.is_eq(n) &&
312 is_bv_poly(e1, pos1, neg1) &&
313 is_bv_poly(e2, pos2, neg2) &&
314 mk_add(pos1, neg2, tmp1) &&
315 mk_add(neg1, pos2, tmp2) &&
316 mk_comp(eq, tmp1, tmp2, result)) {
317 cache_result(n, result);
318 }
319 else if (m_arith_util.is_le(n) &&
320 is_bv_poly(e1, pos1, neg1) &&
321 is_bv_poly(e2, pos2, neg2) &&
322 mk_add(pos1, neg2, tmp1) &&
323 mk_add(neg1, pos2, tmp2) &&
324 mk_comp(le, tmp1, tmp2, result)) {
325 cache_result(n, result);
326 }
327 else if (m_arith_util.is_lt(n) &&
328 is_bv_poly(e1, pos1, neg1) &&
329 is_bv_poly(e2, pos2, neg2) &&
330 mk_add(pos1, neg2, tmp1) &&
331 mk_add(neg1, pos2, tmp2) &&
332 mk_comp(lt, tmp1, tmp2, result)) {
333 cache_result(n, result);
334 }
335 else if (m_arith_util.is_ge(n) &&
336 is_bv_poly(e1, pos1, neg1) &&
337 is_bv_poly(e2, pos2, neg2) &&
338 mk_add(pos1, neg2, tmp1) &&
339 mk_add(neg1, pos2, tmp2) &&
340 mk_comp(le, tmp2, tmp1, result)) {
341 cache_result(n, result);
342 }
343 else if (m_arith_util.is_gt(n) &&
344 is_bv_poly(e1, pos1, neg1) &&
345 is_bv_poly(e2, pos2, neg2) &&
346 mk_add(pos1, neg2, tmp1) &&
347 mk_add(neg1, pos2, tmp2) &&
348 mk_comp(lt, tmp2, tmp1, result)) {
349 cache_result(n, result);
350 }
351 else if (m_arith_util.is_mod(n) &&
352 is_bv_poly(e1, pos1, neg1) &&
353 extract_bv(e2, sz2, sign2, e2bv) && !sign2) {
354 //
355 // (pos1 - neg1) mod e2 = (pos1 + (e2 - (neg1 mod e2))) mod e2
356 //
357 unsigned sz_p, sz_n, sz;
358 bool sign_p, sign_n;
359 expr_ref tmp_p(m), tmp_n(m);
360 VERIFY(extract_bv(pos1, sz_p, sign_p, tmp_p));
361 VERIFY(extract_bv(neg1, sz_n, sign_n, tmp_n));
362 SASSERT(!sign_p && !sign_n);
363
364 // pos1 mod e2
365 if (m_bv_util.is_numeral(tmp_n, k, sz) && k.is_zero()) {
366 tmp1 = tmp_p;
367 tmp2 = e2bv;
368 align_sizes(tmp1, tmp2);
369 tmp3 = m_rewriter.mk_bv_urem(tmp1, tmp2);
370 result = m_rewriter.mk_bv2int(tmp3);
371 cache_result(n, result);
372 return;
373 }
374
375 // neg1 mod e2;
376 tmp1 = tmp_n;
377 tmp2 = e2bv;
378 align_sizes(tmp1, tmp2);
379 tmp3 = m_rewriter.mk_bv_urem(tmp1, tmp2);
380 // e2 - (neg1 mod e2)
381 tmp1 = e2bv;
382 tmp2 = tmp3;
383 align_sizes(tmp1, tmp2);
384 tmp3 = m_rewriter.mk_bv_sub(tmp1, tmp2);
385 // pos1 + (e2 - (neg1 mod e2))
386 tmp1 = tmp_p;
387 tmp2 = tmp3;
388 align_sizes(tmp1, tmp2);
389 tmp_p = m_rewriter.mk_zero_extend(1, tmp1);
390 tmp_n = m_rewriter.mk_zero_extend(1, tmp2);
391 tmp1 = m_rewriter.mk_bv_add(tmp_p, tmp_n);
392 // (pos1 + (e2 - (neg1 mod e2))) mod e2
393 tmp2 = e2bv;
394 align_sizes(tmp1, tmp2);
395 tmp3 = m_rewriter.mk_bv_urem(tmp1, tmp2);
396 result = m_rewriter.mk_bv2int(tmp3);
397
398 cache_result(n, result);
399 }
400 else {
401 result = m.mk_app(f, num_args, args);
402 cache_result(n, result);
403 }
404 }
405
get_cached(expr * n) const406 expr * bit2int::get_cached(expr * n) const {
407 expr* r = nullptr;
408 proof* p = nullptr;
409 const_cast<bit2int*>(this)->m_cache.get(n, r, p);
410 CTRACE("bit2int", !r, tout << mk_pp(n, m) << "\n";);
411 return r;
412 }
413
cache_result(expr * n,expr * r)414 void bit2int::cache_result(expr * n, expr * r) {
415 TRACE("bit2int_verbose", tout << "caching:\n" << mk_pp(n, m) <<
416 "======>\n" << mk_ll_pp(r, m) << "\n";);
417 m_cache.insert(n, r, nullptr);
418 }
419