1 /*++
2 Copyright (c) 2020 Microsoft Corporation
3 
4 Module Name:
5 
6     bv_delay_internalize.cpp
7 
8 Abstract:
9 
    Check relevant bv nodes whose internalization was delayed and, if required, axiomatize them.
11 
12 Author:
13 
14     Nikolaj Bjorner (nbjorner) 2020-09-22
15 
16 --*/
17 
18 #include "sat/smt/bv_solver.h"
19 #include "sat/smt/euf_solver.h"
20 
21 namespace bv {
22 
check_delay_internalized(expr * e)23     bool solver::check_delay_internalized(expr* e) {
24         if (!ctx.is_relevant(e))
25             return true;
26         if (get_internalize_mode(e) != internalize_mode::delay_i)
27             return true;
28         SASSERT(bv.is_bv(e));
29         switch (to_app(e)->get_decl_kind()) {
30         case OP_BMUL:
31             return check_mul(to_app(e));
32         case OP_BSMUL_NO_OVFL:
33         case OP_BSMUL_NO_UDFL:
34         case OP_BUMUL_NO_OVFL:
35             return check_bool_eval(expr2enode(e));
36         default:
37             return check_bv_eval(expr2enode(e));
38         }
39         return true;
40     }
41 
should_bit_blast(app * e)42     bool solver::should_bit_blast(app* e) {
43         if (bv.get_bv_size(e) <= 12)
44             return true;
45         unsigned num_vars = e->get_num_args();
46         for (expr* arg : *e)
47             if (!m.is_value(arg))
48                 --num_vars;
49         if (num_vars <= 1)
50             return true;
51         if (bv.is_bv_add(e) && num_vars * bv.get_bv_size(e) <= 64)
52             return true;
53         return false;
54     }
55 
eval_args(euf::enode * n,expr_ref_vector & args)56     expr_ref solver::eval_args(euf::enode* n, expr_ref_vector& args) {
57         for (euf::enode* arg : euf::enode_args(n))
58             args.push_back(eval_bv(arg));
59         expr_ref r(m.mk_app(n->get_decl(), args), m);
60         ctx.get_rewriter()(r);
61         return r;
62     }
63 
eval_bv(euf::enode * n)64     expr_ref solver::eval_bv(euf::enode* n) {
65         rational val;
66         theory_var v = n->get_th_var(get_id());
67         SASSERT(get_fixed_value(v, val));
68         VERIFY(get_fixed_value(v, val));
69         return expr_ref(bv.mk_numeral(val, get_bv_size(v)), m);
70     }
71 
check_mul(app * e)72     bool solver::check_mul(app* e) {
73         SASSERT(e->get_num_args() >= 2);
74         expr_ref_vector args(m);
75         euf::enode* n = expr2enode(e);
76         if (!reflect())
77             return false;
78         auto r1 = eval_bv(n);
79         auto r2 = eval_args(n, args);
80         if (r1 == r2)
81             return true;
82 
83         TRACE("bv", tout << mk_bounded_pp(e, m) << " evaluates to " << r1 << " arguments: " << args << "\n";);
84         // check x*0 = 0
85         if (!check_mul_zero(e, args, r1, r2))
86             return false;
87 
88         // check x*1 = x
89         if (!check_mul_one(e, args, r1, r2))
90             return false;
91 
92         // Add propagation axiom for arguments
93         if (!check_mul_invertibility(e, args, r1))
94             return false;
95 
96         // Some other possible approaches:
97         // algebraic rules:
98         // x*(y+z), and there are nodes for x*y or x*z -> x*(y+z) = x*y + x*z
99         // compute S-polys for a set of constraints.
100 
101         // Hensel lifting:
102         // The idea is dual to fixing high-order bits. Fix the low order bits where multiplication
103         // is correct, and propagate on the next bit that shows a discrepancy.
104 
105         // check Montgommery properties: (x*y) mod p = (x mod p)*(y mod p) for small primes p
106 
107         // check ranges lo <= x <= hi, lo' <= y <= hi', lo*lo' < x*y <= hi*hi' for non-overflowing values.
108 
109         // check tangets hi >= y >= y0 and hi' >= x => x*y >= x*y0
110 
111 
112         if (m_cheap_axioms)
113             return true;
114 
115         set_delay_internalize(e, internalize_mode::no_delay_i);
116         internalize_circuit(e);
117         return false;
118     }
119 
120     /**
121      * Add invertibility condition for multiplication
122      *
123      * x * y = z => (y | -y) & z = z
124      *
125      * This propagator relates to Niemetz and Preiner's consistency and invertibility conditions.
126      * The idea is that the side-conditions for ensuring invertibility are valid
127      * and in some cases are cheap to bit-blast. For multiplication, we include only
128      * the _consistency_ condition because the side-constraints for invertibility
129      * appear expensive (to paraphrase FMCAD 2020 paper):
130      *  x * s = t => (s = 0 or mcb(x << c, y << c))
131      *
132      *  for c = ctz(s) and y = (t >> c) / (s >> c)
133      *
134      * mcb(x,t/s) just mean that the bit-vectors are compatible as ternary bit-vectors,
135      * which for propagation means that they are the same.
136      */
137 
check_mul_invertibility(app * n,expr_ref_vector const & arg_values,expr * value)138     bool solver::check_mul_invertibility(app* n, expr_ref_vector const& arg_values, expr* value) {
139 
140         expr_ref inv(m);
141 
142         auto invert = [&](expr* s, expr* t) {
143             return bv.mk_bv_and(bv.mk_bv_or(s, bv.mk_bv_neg(s)), t);
144         };
145         auto check_invert = [&](expr* s) {
146             inv = invert(s, value);
147             ctx.get_rewriter()(inv);
148             return inv == value;
149         };
150         auto add_inv = [&](expr* s) {
151             inv = invert(s, n);
152             TRACE("bv", tout << "enforce " << inv << "\n";);
153             add_unit(eq_internalize(inv, n));
154         };
155         bool ok = true;
156         for (unsigned i = 0; i < arg_values.size(); ++i) {
157             if (!check_invert(arg_values[i])) {
158                 add_inv(n->get_arg(i));
159                 ok = false;
160             }
161         }
162         return ok;
163     }
164 
165 
166 
167     /*
168     * Check that multiplication with 0 is correctly propagated.
169     * If not, create algebraic axioms enforcing 0*x = 0 and x*0 = 0
170     *
171     * z = 0, then lsb(x) + 1 + lsb(y) + 1 >= sz
172 
173     */
check_mul_zero(app * n,expr_ref_vector const & arg_values,expr * mul_value,expr * arg_value)174     bool solver::check_mul_zero(app* n, expr_ref_vector const& arg_values, expr* mul_value, expr* arg_value) {
175         SASSERT(mul_value != arg_value);
176         SASSERT(!(bv.is_zero(mul_value) && bv.is_zero(arg_value)));
177         if (bv.is_zero(arg_value)) {
178             unsigned sz = n->get_num_args();
179             expr_ref_vector args(m, sz, n->get_args());
180             for (unsigned i = 0; i < sz && !s().inconsistent(); ++i) {
181 
182                 args[i] = arg_value;
183                 expr_ref r(m.mk_app(n->get_decl(), args), m);
184                 set_delay_internalize(r, internalize_mode::init_bits_only_i); // do not bit-blast this multiplier.
185                 args[i] = n->get_arg(i);
186                 add_unit(eq_internalize(r, arg_value));
187             }
188             IF_VERBOSE(2, verbose_stream() << "delay internalize @" << s().scope_lvl() << "\n");
189             return false;
190         }
191         if (bv.is_zero(mul_value)) {
192             return true;
193 #if 0
194             vector<expr_ref_vector> lsb_bits;
195             for (expr* arg : *n) {
196                 expr_ref_vector bits(m);
197                 encode_lsb_tail(arg, bits);
198                 lsb_bits.push_back(bits);
199             }
200             expr_ref_vector zs(m);
201             literal_vector lits;
202             expr_ref eq(m.mk_eq(n, mul_value), m);
203             lits.push_back(~b_internalize(eq));
204 
205             for (unsigned i = 0; i < lsb_bits.size(); ++i) {
206             }
207             expr_ref z(m.mk_or(zs), m);
208             add_clause(lits);
209             // sum of lsb should be at least sz
210             return true;
211 #endif
212         }
213         return true;
214     }
215 
216     /***
217     * check that 1*y = y, x*1 = x
218     */
check_mul_one(app * n,expr_ref_vector const & arg_values,expr * mul_value,expr * arg_value)219     bool solver::check_mul_one(app* n, expr_ref_vector const& arg_values, expr* mul_value, expr* arg_value) {
220         if (arg_values.size() != 2)
221             return true;
222         if (bv.is_one(arg_values[0])) {
223             expr_ref mul1(m.mk_app(n->get_decl(), arg_values[0], n->get_arg(1)), m);
224             set_delay_internalize(mul1, internalize_mode::init_bits_only_i);
225             add_unit(eq_internalize(mul1, n->get_arg(1)));
226             TRACE("bv", tout << mul1 << "\n";);
227             return false;
228         }
229         if (bv.is_one(arg_values[1])) {
230             expr_ref mul1(m.mk_app(n->get_decl(), n->get_arg(0), arg_values[1]), m);
231             set_delay_internalize(mul1, internalize_mode::init_bits_only_i);
232             add_unit(eq_internalize(mul1, n->get_arg(0)));
233             TRACE("bv", tout << mul1 << "\n";);
234             return false;
235         }
236         return true;
237     }
238 
239 
240     /**
241     * The i'th bit in xs is 1 if the most significant bit of x is i or higher.
242     */
encode_msb_tail(expr * x,expr_ref_vector & xs)243     void solver::encode_msb_tail(expr* x, expr_ref_vector& xs) {
244         theory_var v = expr2enode(x)->get_th_var(get_id());
245         sat::literal_vector const& bits = m_bits[v];
246         if (bits.empty())
247             return;
248         expr_ref tmp = literal2expr(bits.back());
249         for (unsigned i = bits.size() - 1; i-- > 0; ) {
250             auto b = bits[i];
251             tmp = m.mk_or(literal2expr(b), tmp);
252             xs.push_back(tmp);
253         }
254     };
255 
256     /**
257      * The i'th bit in xs is 1 if the least significant bit of x is i or lower.
258      */
encode_lsb_tail(expr * x,expr_ref_vector & xs)259     void solver::encode_lsb_tail(expr* x, expr_ref_vector& xs) {
260         theory_var v = expr2enode(x)->get_th_var(get_id());
261         sat::literal_vector const& bits = m_bits[v];
262         if (bits.empty())
263             return;
264         expr_ref tmp = literal2expr(bits[0]);
265         for (unsigned i = 1; i < bits.size(); ++i) {
266             auto b = bits[i];
267             tmp = m.mk_or(literal2expr(b), tmp);
268             xs.push_back(tmp);
269         }
270     };
271 
    /**
     * Check non-overflow of unsigned multiplication.
     *
     * Called from check_bool_eval when the truth value of the no-overflow
     * literal n (passed as value) disagrees with evaluating n on the argument
     * values arg_values. msb0/msb1 below are the bit counts
     * (rational::get_num_bits) of the two argument values, and sz is their
     * bit-width. The intended properties are:
     *
     *   no_overflow(x, y) => msb(x) + msb(y) <= sz
     *   msb(x) + msb(y) < sz => no_overflow(x, y)
     *
     * Returns false if a clause was added, true otherwise.
     *
     * NOTE(review): with msb = bit count, the first implication does not hold
     * in general. For sz = 4, x = 2 and y = 4 the bit counts sum to 5 > 4, but
     * 2 * 4 = 8 does not overflow. Confirm that the clauses added below are valid.
     * NOTE(review): this path may be unreachable. get_internalize_mode returns
     * no_delay_i for terms that are not bit-vector sorted, and this predicate is
     * Boolean. TODO confirm.
     */
    bool solver::check_umul_no_overflow(app* n, expr_ref_vector const& arg_values, expr* value) {
        SASSERT(arg_values.size() == 2);
        SASSERT(m.is_true(value) || m.is_false(value));
        rational v0, v1;
        unsigned sz;
        VERIFY(bv.is_numeral(arg_values[0], v0, sz));
        VERIFY(bv.is_numeral(arg_values[1], v1));
        unsigned msb0 = v0.get_num_bits();
        unsigned msb1 = v1.get_num_bits();
        expr_ref_vector xs(m), ys(m), zs(m); // zs is currently unused

        if (m.is_true(value) && msb0 + msb1 > sz && !v0.is_zero() && !v1.is_zero()) {
            // n is true in the assignment although the argument values are "large":
            // for each split i add ~n | ~xs[i-1] | ~ys[sz-i].
            // NOTE(review): encode_msb_tail appends bits.size() - 1 entries, so for
            // sz-bit arguments xs/ys hold sz - 1 entries. Then xs.get(i - 1) at
            // i == sz and ys.get(sz - i) at i == 1 index one past the end. Confirm.
            sat::literal no_overflow = expr2literal(n);
            encode_msb_tail(n->get_arg(0), xs);
            encode_msb_tail(n->get_arg(1), ys);
            for (unsigned i = 1; i <= sz; ++i) {
                sat::literal bit0 = mk_literal(xs.get(i - 1));
                sat::literal bit1 = mk_literal(ys.get(sz - i));
                add_clause(~no_overflow, ~bit0, ~bit1);
            }
            return false;
        }
        else if (m.is_false(value) && msb0 + msb1 < sz) {
            // n is false although the argument values are "small": add the clause
            // n | OR_i (xs[i-1] & ys[sz-i-1]) for i = 1 .. sz-1.
            encode_msb_tail(n->get_arg(0), xs);
            encode_msb_tail(n->get_arg(1), ys);
            sat::literal_vector lits;
            lits.push_back(expr2literal(n));
            for (unsigned i = 1; i < sz; ++i) {
                expr_ref msb_ge_sz(m.mk_and(xs.get(i - 1), ys.get(sz - i - 1)), m);
                lits.push_back(mk_literal(msb_ge_sz));
            }
            add_clause(lits);
            return false;
        }
        return true;
    }
314 
check_bv_eval(euf::enode * n)315     bool solver::check_bv_eval(euf::enode* n) {
316         expr_ref_vector args(m);
317         app* a = n->get_app();
318         SASSERT(bv.is_bv(a));
319         auto r1 = eval_bv(n);
320         auto r2 = eval_args(n, args);
321         if (r1 == r2)
322             return true;
323         if (m_cheap_axioms)
324             return true;
325         set_delay_internalize(a, internalize_mode::no_delay_i);
326         internalize_circuit(a);
327         return false;
328     }
329 
check_bool_eval(euf::enode * n)330     bool solver::check_bool_eval(euf::enode* n) {
331         expr_ref_vector args(m);
332         SASSERT(m.is_bool(n->get_expr()));
333         sat::literal lit = expr2literal(n->get_expr());
334         expr* r1 = m.mk_bool_val(s().value(lit) == l_true);
335         auto r2 = eval_args(n, args);
336         if (r1 == r2)
337             return true;
338         app* a = n->get_app();
339         if (bv.is_bv_umul_no_ovfl(a) && !check_umul_no_overflow(a, args, r1))
340             return false;
341         if (m_cheap_axioms)
342             return true;
343         set_delay_internalize(a, internalize_mode::no_delay_i);
344         internalize_circuit(a);
345         return false;
346     }
347 
set_delay_internalize(expr * e,internalize_mode mode)348     void solver::set_delay_internalize(expr* e, internalize_mode mode) {
349         if (!m_delay_internalize.contains(e))
350             ctx.push(insert_obj_map<expr, internalize_mode>(m_delay_internalize, e));
351         else
352             ctx.push(remove_obj_map<expr, internalize_mode>(m_delay_internalize, e, m_delay_internalize[e]));
353         m_delay_internalize.insert(e, mode);
354     }
355 
get_internalize_mode(expr * e)356     solver::internalize_mode solver::get_internalize_mode(expr* e) {
357         if (!bv.is_bv(e))
358             return internalize_mode::no_delay_i;
359         if (!get_config().m_bv_delay)
360             return internalize_mode::no_delay_i;
361         if (!reflect())
362             return internalize_mode::no_delay_i;
363         internalize_mode mode;
364         switch (to_app(e)->get_decl_kind()) {
365         case OP_BMUL:
366         case OP_BSMUL_NO_OVFL:
367         case OP_BSMUL_NO_UDFL:
368         case OP_BUMUL_NO_OVFL:
369         case OP_BSMOD_I:
370         case OP_BUREM_I:
371         case OP_BSREM_I:
372         case OP_BUDIV_I:
373         case OP_BSDIV_I:
374         case OP_BADD:
375             if (should_bit_blast(to_app(e)))
376                 return internalize_mode::no_delay_i;
377             mode = internalize_mode::delay_i;
378             if (!m_delay_internalize.find(e, mode))
379                 m_delay_internalize.insert(e, mode);
380             return mode;
381         default:
382             return internalize_mode::no_delay_i;
383         }
384     }
385 }
386