1 /*++
2 Copyright (c) 2013 Microsoft Corporation
3 
4 Module Name:
5 
6     smt_farkas_util.cpp
7 
8 Abstract:
9 
10     Utility for combining inequalities using coefficients obtained from Farkas lemmas.
11 
12 Author:
13 
14     Nikolaj Bjorner (nbjorner) 2013-11-2.
15 
16 Revision History:
17 
18     NB. This utility is specialized to proofs generated by the arithmetic solvers.
19 
20 --*/
21 
22 #include "smt/smt_farkas_util.h"
23 #include "ast/ast_pp.h"
24 #include "ast/rewriter/th_rewriter.h"
25 #include "ast/rewriter/bool_rewriter.h"
26 
27 
28 namespace smt {
29 
farkas_util(ast_manager & m)30     farkas_util::farkas_util(ast_manager& m):
31         m(m),
32         a(m),
33         m_ineqs(m),
34         m_split_literals(false),
35         m_time(0) {
36     }
37 
mk_coerce(expr * & e1,expr * & e2)38     void farkas_util::mk_coerce(expr*& e1, expr*& e2) {
39         if (a.is_int(e1) && a.is_real(e2)) {
40             e1 = a.mk_to_real(e1);
41         }
42         else if (a.is_int(e2) && a.is_real(e1)) {
43             e2 = a.mk_to_real(e2);
44         }
45     }
46 
47     // TBD: arith_decl_util now supports coercion, so this should be deprecated.
mk_add(expr * e1,expr * e2)48     app* farkas_util::mk_add(expr* e1, expr* e2) {
49         mk_coerce(e1, e2);
50         return a.mk_add(e1, e2);
51     }
52 
mk_mul(expr * e1,expr * e2)53     app* farkas_util::mk_mul(expr* e1, expr* e2) {
54         mk_coerce(e1, e2);
55         return a.mk_mul(e1, e2);
56     }
57 
mk_le(expr * e1,expr * e2)58     app* farkas_util::mk_le(expr* e1, expr* e2) {
59         mk_coerce(e1, e2);
60         return a.mk_le(e1, e2);
61     }
62 
mk_ge(expr * e1,expr * e2)63     app* farkas_util::mk_ge(expr* e1, expr* e2) {
64         mk_coerce(e1, e2);
65         return a.mk_ge(e1, e2);
66     }
67 
mk_gt(expr * e1,expr * e2)68     app* farkas_util::mk_gt(expr* e1, expr* e2) {
69         mk_coerce(e1, e2);
70         return a.mk_gt(e1, e2);
71     }
72 
mk_lt(expr * e1,expr * e2)73     app* farkas_util::mk_lt(expr* e1, expr* e2) {
74         mk_coerce(e1, e2);
75         return a.mk_lt(e1, e2);
76     }
77 
mul(rational const & c,expr * e,expr_ref & res)78     void farkas_util::mul(rational const& c, expr* e, expr_ref& res) {
79         expr_ref tmp(m);
80         if (c.is_one()) {
81             tmp = e;
82         }
83         else {
84             tmp = mk_mul(a.mk_numeral(c, c.is_int() && a.is_int(e)), e);
85         }
86         res = mk_add(res, tmp);
87     }
88 
is_int_sort(app * c)89     bool farkas_util::is_int_sort(app* c) {
90         SASSERT(m.is_eq(c) || a.is_le(c) || a.is_lt(c) || a.is_gt(c) || a.is_ge(c));
91         SASSERT(a.is_int(c->get_arg(0)) || a.is_real(c->get_arg(0)));
92         return a.is_int(c->get_arg(0));
93     }
94 
is_int_sort()95     bool farkas_util::is_int_sort() {
96         SASSERT(!m_ineqs.empty());
97         return is_int_sort(m_ineqs[0].get());
98     }
99 
normalize_coeffs()100     void farkas_util::normalize_coeffs() {
101         rational l(1);
102         for (unsigned i = 0; i < m_coeffs.size(); ++i) {
103             l = lcm(l, denominator(m_coeffs[i]));
104         }
105         if (!l.is_one()) {
106             for (unsigned i = 0; i < m_coeffs.size(); ++i) {
107                 m_coeffs[i] *= l;
108             }
109         }
110         m_normalize_factor = l;
111     }
112 
mk_one()113     app* farkas_util::mk_one() {
114         return a.mk_numeral(rational(1), true);
115     }
116 
fix_sign(bool is_pos,app * c)117     app* farkas_util::fix_sign(bool is_pos, app* c) {
118         expr* x, *y;
119         SASSERT(m.is_eq(c) || a.is_le(c) || a.is_lt(c) || a.is_gt(c) || a.is_ge(c));
120         bool is_int = is_int_sort(c);
121         if (is_int && is_pos && (a.is_lt(c, x, y) || a.is_gt(c, y, x))) {
122             return mk_le(mk_add(x, mk_one()), y);
123         }
124         if (is_int && !is_pos && (a.is_le(c, x, y) || a.is_ge(c, y, x))) {
125             // !(x <= y) <=> x > y <=> x >= y + 1
126             return mk_ge(x, mk_add(y, mk_one()));
127         }
128         if (is_pos) {
129             return c;
130         }
131         if (a.is_le(c, x, y)) return mk_gt(x, y);
132         if (a.is_lt(c, x, y)) return mk_ge(x, y);
133         if (a.is_ge(c, x, y)) return mk_lt(x, y);
134         if (a.is_gt(c, x, y)) return mk_le(x, y);
135         UNREACHABLE();
136         return c;
137     }
138 
partition_ineqs()139     void farkas_util::partition_ineqs() {
140         m_reps.reset();
141         m_his.reset();
142         ++m_time;
143         for (unsigned i = 0; i < m_ineqs.size(); ++i) {
144             m_reps.push_back(process_term(m_ineqs[i].get()));
145         }
146         unsigned head = 0;
147         while (head < m_ineqs.size()) {
148             unsigned r = find(m_reps[head]);
149             unsigned tail = head;
150             for (unsigned i = head+1; i < m_ineqs.size(); ++i) {
151                 if (find(m_reps[i]) == r) {
152                     ++tail;
153                     if (tail != i) {
154                         SASSERT(tail < i);
155                         std::swap(m_reps[tail], m_reps[i]);
156                         app_ref tmp(m);
157                         tmp = m_ineqs[i].get();
158                         m_ineqs[i] = m_ineqs[tail].get();
159                         m_ineqs[tail] = tmp;
160                         std::swap(m_coeffs[tail], m_coeffs[i]);
161                     }
162                 }
163             }
164             head = tail + 1;
165             m_his.push_back(head);
166         }
167     }
168 
find(unsigned idx)169     unsigned farkas_util::find(unsigned idx) {
170         if (m_ts.size() <= idx) {
171             m_roots.resize(idx+1);
172             m_size.resize(idx+1);
173             m_ts.resize(idx+1);
174             m_roots[idx] = idx;
175             m_ts[idx] = m_time;
176             m_size[idx] = 1;
177             return idx;
178         }
179         if (m_ts[idx] != m_time) {
180             m_size[idx] = 1;
181             m_ts[idx]    = m_time;
182             m_roots[idx] = idx;
183             return idx;
184         }
185         while (true) {
186             if (m_roots[idx] == idx) {
187                 return idx;
188             }
189             idx = m_roots[idx];
190         }
191     }
192 
merge(unsigned i,unsigned j)193     void farkas_util::merge(unsigned i, unsigned j) {
194         i = find(i);
195         j = find(j);
196         if (i == j) {
197             return;
198         }
199         if (m_size[i] > m_size[j]) {
200             std::swap(i, j);
201         }
202         m_roots[i] = j;
203         m_size[j] += m_size[i];
204     }
process_term(expr * e)205     unsigned farkas_util::process_term(expr* e) {
206         unsigned r = e->get_id();
207         ptr_vector<expr> todo;
208         ast_mark mark;
209         todo.push_back(e);
210         while (!todo.empty()) {
211             e = todo.back();
212             todo.pop_back();
213             if (mark.is_marked(e)) {
214                 continue;
215             }
216             mark.mark(e, true);
217             if (is_uninterp(e)) {
218                 merge(r, e->get_id());
219             }
220             if (is_app(e)) {
221                 app* a = to_app(e);
222                 for (unsigned i = 0; i < a->get_num_args(); ++i) {
223                     todo.push_back(a->get_arg(i));
224                 }
225             }
226         }
227         return r;
228     }
extract_consequence(unsigned lo,unsigned hi)229     expr_ref farkas_util::extract_consequence(unsigned lo, unsigned hi) {
230         bool is_int = is_int_sort();
231         app_ref zero(a.mk_numeral(rational::zero(), is_int), m);
232         expr_ref res(m);
233         res = zero;
234         bool is_strict = false;
235         bool is_eq     = true;
236         expr* x, *y;
237         for (unsigned i = lo; i < hi; ++i) {
238             app* c = m_ineqs[i].get();
239             if (m.is_eq(c, x, y)) {
240                 mul(m_coeffs[i],  x, res);
241                 mul(-m_coeffs[i], y, res);
242             }
243             if (a.is_lt(c, x, y) || a.is_gt(c, y, x)) {
244                 mul(m_coeffs[i],  x, res);
245                 mul(-m_coeffs[i], y, res);
246                 is_strict = true;
247                 is_eq = false;
248             }
249             if (a.is_le(c, x, y) || a.is_ge(c, y, x)) {
250                 mul(m_coeffs[i],  x, res);
251                 mul(-m_coeffs[i], y, res);
252                 is_eq = false;
253             }
254         }
255 
256         zero = a.mk_numeral(rational::zero(), a.is_int(res));
257         if (is_eq) {
258             res = m.mk_eq(res, zero);
259         }
260         else if (is_strict) {
261             res = mk_lt(res, zero);
262         }
263         else {
264             res = mk_le(res, zero);
265         }
266         res = m.mk_not(res);
267         th_rewriter rw(m);
268         params_ref params;
269         params.set_bool("gcd_rounding", true);
270         rw.updt_params(params);
271         proof_ref pr(m);
272         expr_ref result(m);
273         rw(res, result, pr);
274         fix_dl(result);
275         return result;
276     }
277 
fix_dl(expr_ref & r)278     void farkas_util::fix_dl(expr_ref& r) {
279         expr* e;
280         if (m.is_not(r, e)) {
281             r = e;
282             fix_dl(r);
283             r = m.mk_not(r);
284             return;
285         }
286         expr* e1, *e2, *e3, *e4;
287         if ((m.is_eq(r, e1, e2) || a.is_lt(r, e1, e2) || a.is_gt(r, e1, e2) ||
288              a.is_le(r, e1, e2) || a.is_ge(r, e1, e2))) {
289             if (a.is_add(e1, e3, e4) && a.is_mul(e3)) {
290                 r = m.mk_app(to_app(r)->get_decl(), a.mk_add(e4,e3), e2);
291             }
292         }
293     }
294 
reset()295     void farkas_util::reset() {
296         m_ineqs.reset();
297         m_coeffs.reset();
298     }
299 
add(rational const & coef,app * c)300     bool farkas_util::add(rational const & coef, app * c) {
301         bool is_pos = true;
302         expr* e;
303         while (m.is_not(c, e)) {
304             is_pos = !is_pos;
305             c = to_app(e);
306         }
307 
308         if (!coef.is_zero() && !m.is_true(c)) {
309             if (m.is_eq(c) || a.is_le(c) || a.is_lt(c) || a.is_gt(c) || a.is_ge(c)) {
310                 m_coeffs.push_back(coef);
311                 m_ineqs.push_back(fix_sign(is_pos, c));
312             }
313             else {
314                 return false;
315             }
316         }
317         return true;
318     }
319 
get()320     expr_ref farkas_util::get() {
321         TRACE("arith",
322               for (unsigned i = 0; i < m_coeffs.size(); ++i) {
323                   tout << m_coeffs[i] << " * (" << mk_pp(m_ineqs[i].get(), m) << ") ";
324               }
325               tout << "\n";
326               );
327 
328         m_normalize_factor = rational::one();
329         expr_ref res(m);
330         if (m_coeffs.empty()) {
331             res = m.mk_false();
332             return res;
333         }
334         bool is_int = is_int_sort();
335         if (is_int) {
336             normalize_coeffs();
337         }
338 
339         if (m_split_literals) {
340             // partition equalities into variable disjoint sets.
341             // take the conjunction of these instead of the
342             // linear combination.
343             partition_ineqs();
344             expr_ref_vector lits(m);
345             unsigned lo = 0;
346             for (unsigned hi : m_his) {
347                 lits.push_back(extract_consequence(lo, hi));
348                 lo = hi;
349             }
350             bool_rewriter(m).mk_or(lits.size(), lits.data(), res);
351             IF_VERBOSE(2, { if (lits.size() > 1) { verbose_stream() << "combined lemma: " << res << "\n"; } });
352         }
353         else {
354             res = extract_consequence(0, m_coeffs.size());
355         }
356 
357         TRACE("arith",
358               for (unsigned i = 0; i < m_coeffs.size(); ++i) {
359                   tout << m_coeffs[i] << " * (" << mk_pp(m_ineqs[i].get(), m) << ") ";
360               }
361               tout << "\n";
362               tout << res << "\n";
363               );
364 
365         return res;
366     }
367 }
368 
369