1 
2 /*++
3 Copyright (c) 2015 Microsoft Corporation
4 
5 --*/
6 
7 #include "util/trace.h"
8 #include "util/vector.h"
9 #include "util/sorting_network.h"
10 #include "ast/ast.h"
11 #include "ast/ast_pp.h"
12 #include "ast/reg_decl_plugins.h"
13 #include "ast/ast_util.h"
14 #include "model/model_smt2_pp.h"
15 #include "smt/smt_kernel.h"
16 #include "smt/params/smt_params.h"
17 
18 struct ast_ext {
19     ast_manager& m;
ast_extast_ext20     ast_ext(ast_manager& m):m(m) {}
21     typedef expr* T;
22     typedef expr_ref_vector vector;
mk_iteast_ext23     T mk_ite(T a, T b, T c) {
24         return m.mk_ite(a, b, c);
25     }
mk_least_ext26     T mk_le(T a, T b) {
27         if (m.is_bool(a)) {
28                 return m.mk_implies(a, b);
29         }
30         UNREACHABLE();
31         return nullptr;
32     }
mk_defaultast_ext33     T mk_default() {
34         return m.mk_false();
35     }
36 };
37 
38 
39 
40 struct unsigned_ext {
unsigned_extunsigned_ext41     unsigned_ext() {}
42     typedef unsigned T;
43     typedef svector<unsigned> vector;
mk_iteunsigned_ext44     T mk_ite(T a, T b, T c) {
45         return (a==1)?b:c;
46     }
mk_leunsigned_ext47     T mk_le(T a, T b) {
48         return (a <= b)?1:0;
49     }
mk_defaultunsigned_ext50     T mk_default() {
51         return 0;
52     }
53 };
54 
55 
is_sorted(svector<unsigned> const & v)56 static void is_sorted(svector<unsigned> const& v) {
57     for (unsigned i = 0; i + 1 < v.size(); ++i) {
58         ENSURE(v[i] <= v[i+1]);
59     }
60 }
61 
test_sorting1()62 static void test_sorting1() {
63     svector<unsigned> in, out;
64     unsigned_ext uext;
65     sorting_network<unsigned_ext> sn(uext);
66 
67     in.push_back(0);
68     in.push_back(1);
69     in.push_back(0);
70     in.push_back(1);
71     in.push_back(1);
72     in.push_back(0);
73 
74     sn(in, out);
75 
76     is_sorted(out);
77     for (unsigned i = 0; i < out.size(); ++i) {
78         std::cout << out[i];
79     }
80     std::cout << "\n";
81 }
82 
test_sorting2()83 static void test_sorting2() {
84     svector<unsigned> in, out;
85     unsigned_ext uext;
86     sorting_network<unsigned_ext> sn(uext);
87 
88     in.push_back(0);
89     in.push_back(1);
90     in.push_back(2);
91     in.push_back(1);
92     in.push_back(1);
93     in.push_back(3);
94 
95     sn(in, out);
96 
97     is_sorted(out);
98 
99     for (unsigned i = 0; i < out.size(); ++i) {
100         std::cout << out[i];
101     }
102     std::cout << "\n";
103 }
104 
test_sorting4_r(unsigned i,svector<unsigned> & in)105 static void test_sorting4_r(unsigned i, svector<unsigned>& in) {
106     if (i == in.size()) {
107         svector<unsigned> out;
108         unsigned_ext uext;
109         sorting_network<unsigned_ext> sn(uext);
110         sn(in, out);
111         is_sorted(out);
112         std::cout << "sorted\n";
113     }
114     else {
115         in[i] = 0;
116         test_sorting4_r(i+1, in);
117         in[i] = 1;
118         test_sorting4_r(i+1, in);
119     }
120 }
121 
test_sorting4()122 static void test_sorting4() {
123     svector<unsigned> in;
124     in.resize(5);
125     test_sorting4_r(0, in);
126     in.resize(8);
127     test_sorting4_r(0, in);
128 }
129 
test_sorting3()130 void test_sorting3() {
131     ast_manager m;
132     reg_decl_plugins(m);
133     expr_ref_vector in(m), out(m);
134     for (unsigned i = 0; i < 7; ++i) {
135         in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
136     }
137     for (expr* e : in) std::cout << mk_pp(e, m) << "\n";
138     ast_ext aext(m);
139     sorting_network<ast_ext> sn(aext);
140     sn(in, out);
141     std::cout << "size: " << out.size() << "\n";
142     for (expr* e : out) std::cout << mk_pp(e, m) << "\n";
143 }
144 
145 
146 struct ast_ext2 {
147     ast_manager& m;
148     expr_ref_vector m_clauses;
149     expr_ref_vector m_trail;
ast_ext2ast_ext2150     ast_ext2(ast_manager& m):m(m), m_clauses(m), m_trail(m) {}
151     typedef expr* pliteral;
152     typedef ptr_vector<expr> pliteral_vector;
153 
trailast_ext2154     expr* trail(expr* e) {
155         m_trail.push_back(e);
156         return e;
157     }
158 
mk_falseast_ext2159     pliteral mk_false() { return m.mk_false(); }
mk_trueast_ext2160     pliteral mk_true() { return m.mk_true(); }
mk_maxast_ext2161     pliteral mk_max(unsigned n, pliteral const* lits) {
162         return trail(m.mk_or(n, lits));
163     }
mk_minast_ext2164     pliteral mk_min(unsigned n, pliteral const* lits) {
165         return trail(m.mk_and(n, lits));
166     }
mk_notast_ext2167     pliteral mk_not(pliteral a) { if (m.is_not(a,a)) return a;
168         return trail(m.mk_not(a));
169     }
ppast_ext2170     std::ostream& pp(std::ostream& out, pliteral lit) {
171         return out << mk_pp(lit, m);
172     }
freshast_ext2173     pliteral fresh(char const* n) {
174         return trail(m.mk_fresh_const(n, m.mk_bool_sort()));
175     }
mk_clauseast_ext2176     void mk_clause(unsigned n, pliteral const* lits) {
177         m_clauses.push_back(mk_or(m, n, lits));
178     }
179 };
180 
test_eq1(unsigned n,sorting_network_encoding enc)181 static void test_eq1(unsigned n, sorting_network_encoding enc) {
182     //std::cout << "test eq1 " << n << " for encoding: " << enc << "\n";
183     ast_manager m;
184     reg_decl_plugins(m);
185     ast_ext2 ext(m);
186     expr_ref_vector in(m), out(m);
187     for (unsigned i = 0; i < n; ++i) {
188         in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
189     }
190     smt_params fp;
191     smt::kernel solver(m, fp);
192     psort_nw<ast_ext2> sn(ext);
193     sn.cfg().m_encoding = enc;
194 
195     expr_ref result1(m), result2(m);
196 
197     // equality:
198     solver.push();
199     result1 = sn.eq(true, 1, in.size(), in.c_ptr());
200     for (expr* cls : ext.m_clauses) {
201         solver.assert_expr(cls);
202     }
203     expr_ref_vector ors(m);
204     for (unsigned i = 0; i < n; ++i) {
205         expr_ref_vector ands(m);
206         for (unsigned j = 0; j < n; ++j) {
207             ands.push_back(j == i ? in[j].get() : m.mk_not(in[j].get()));
208         }
209         ors.push_back(mk_and(ands));
210     }
211     result2 = mk_or(ors);
212     solver.assert_expr(m.mk_not(m.mk_eq(result1, result2)));
213     //std::cout << ext.m_clauses << "\n";
214     //std::cout << result1 << "\n";
215     //std::cout << result2 << "\n";
216     lbool res = solver.check();
217     if (res == l_true) {
218         model_ref model;
219         solver.get_model(model);
220         model_smt2_pp(std::cout, m, *model, 0);
221         TRACE("pb", model_smt2_pp(tout, m, *model, 0););
222     }
223     ENSURE(l_false == res);
224     ext.m_clauses.reset();
225 }
226 
test_sorting_eq(unsigned n,unsigned k,sorting_network_encoding enc)227 static void test_sorting_eq(unsigned n, unsigned k, sorting_network_encoding enc) {
228     ENSURE(k < n);
229     ast_manager m;
230     reg_decl_plugins(m);
231     ast_ext2 ext(m);
232     expr_ref_vector in(m), out(m);
233     for (unsigned i = 0; i < n; ++i) {
234         in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
235     }
236     smt_params fp;
237     smt::kernel solver(m, fp);
238     psort_nw<ast_ext2> sn(ext);
239     sn.cfg().m_encoding = enc;
240     expr_ref result(m);
241 
242     // equality:
243     std::cout << "eq " << k << " out of " << n << " for encoding " << enc << "\n";
244     solver.push();
245     result = sn.eq(false, k, in.size(), in.c_ptr());
246     solver.assert_expr(result);
247     for (expr* cl : ext.m_clauses) {
248         solver.assert_expr(cl);
249     }
250     lbool res = solver.check();
251     if (res != l_true) {
252         std::cout << res << "\n";
253         solver.display(std::cout);
254     }
255     ENSURE(res == l_true);
256 
257     solver.push();
258     for (unsigned i = 0; i < k; ++i) {
259         solver.assert_expr(in[i].get());
260     }
261     res = solver.check();
262     if (res != l_true) {
263         std::cout << result << "\n" << ext.m_clauses << "\n";
264     }
265     ENSURE(res == l_true);
266     solver.assert_expr(in[k].get());
267     res = solver.check();
268     if (res == l_true) {
269         TRACE("pb",
270               unsigned sz = solver.size();
271               for (unsigned i = 0; i < sz; ++i) {
272                   tout << mk_pp(solver.get_formula(i), m) << "\n";
273               });
274         model_ref model;
275         solver.get_model(model);
276         model_smt2_pp(std::cout, m, *model, 0);
277         TRACE("pb", model_smt2_pp(tout, m, *model, 0););
278     }
279     ENSURE(res == l_false);
280     solver.pop(1);
281     ext.m_clauses.reset();
282 }
283 
test_sorting_le(unsigned n,unsigned k,sorting_network_encoding enc)284 static void test_sorting_le(unsigned n, unsigned k, sorting_network_encoding enc) {
285     ast_manager m;
286     reg_decl_plugins(m);
287     ast_ext2 ext(m);
288     expr_ref_vector in(m), out(m);
289     for (unsigned i = 0; i < n; ++i) {
290         in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
291     }
292     smt_params fp;
293     smt::kernel solver(m, fp);
294     psort_nw<ast_ext2> sn(ext);
295     sn.cfg().m_encoding = enc;
296     expr_ref result(m);
297     // B <= k
298     std::cout << "le " << k << "\n";
299     solver.push();
300     result = sn.le(false, k, in.size(), in.c_ptr());
301     solver.assert_expr(result);
302     for (expr* cls : ext.m_clauses) {
303         solver.assert_expr(cls);
304     }
305     lbool res = solver.check();
306     if (res != l_true) {
307         std::cout << res << "\n";
308         solver.display(std::cout);
309         std::cout << "clauses: " << ext.m_clauses << "\n";
310         std::cout << "result: " << result << "\n";
311     }
312     ENSURE(res == l_true);
313 
314     for (unsigned i = 0; i < k; ++i) {
315         solver.assert_expr(in[i].get());
316     }
317     res = solver.check();
318     if (res != l_true) {
319         std::cout << res << "\n";
320         solver.display(std::cout);
321     }
322     ENSURE(res == l_true);
323     solver.assert_expr(in[k].get());
324     res = solver.check();
325     if (res == l_true) {
326         TRACE("pb",
327               unsigned sz = solver.size();
328               for (unsigned i = 0; i < sz; ++i) {
329                   tout << mk_pp(solver.get_formula(i), m) << "\n";
330               });
331         model_ref model;
332         solver.get_model(model);
333         model_smt2_pp(std::cout, m, *model, 0);
334         TRACE("pb", model_smt2_pp(tout, m, *model, 0););
335     }
336     ENSURE(res == l_false);
337     solver.pop(1);
338     ext.m_clauses.reset();
339 }
340 
341 
test_sorting_ge(unsigned n,unsigned k,sorting_network_encoding enc)342 void test_sorting_ge(unsigned n, unsigned k, sorting_network_encoding enc) {
343     ast_manager m;
344     reg_decl_plugins(m);
345     ast_ext2 ext(m);
346     expr_ref_vector in(m), out(m);
347     for (unsigned i = 0; i < n; ++i) {
348         in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
349     }
350     smt_params fp;
351     smt::kernel solver(m, fp);
352     psort_nw<ast_ext2> sn(ext);
353     sn.cfg().m_encoding = enc;
354     expr_ref result(m);
355     // k <= B
356     std::cout << "ge " << k << "\n";
357     solver.push();
358     result = sn.ge(false, k, in.size(), in.c_ptr());
359     solver.assert_expr(result);
360     for (expr* cls : ext.m_clauses) {
361         solver.assert_expr(cls);
362     }
363     lbool res = solver.check();
364     ENSURE(res == l_true);
365 
366     solver.push();
367     for (unsigned i = 0; i < n - k; ++i) {
368         solver.assert_expr(m.mk_not(in[i].get()));
369     }
370     res = solver.check();
371     ENSURE(res == l_true);
372     solver.assert_expr(m.mk_not(in[n - k].get()));
373     res = solver.check();
374     if (res == l_true) {
375         TRACE("pb",
376               unsigned sz = solver.size();
377               for (unsigned i = 0; i < sz; ++i) {
378                   tout << mk_pp(solver.get_formula(i), m) << "\n";
379               });
380         model_ref model;
381         solver.get_model(model);
382         model_smt2_pp(std::cout, m, *model, 0);
383         TRACE("pb", model_smt2_pp(tout, m, *model, 0););
384     }
385     ENSURE(res == l_false);
386     solver.pop(1);
387 }
388 
test_sorting5(unsigned n,unsigned k,sorting_network_encoding enc)389 void test_sorting5(unsigned n, unsigned k, sorting_network_encoding enc) {
390     std::cout << "n: " << n << " k: " << k << "\n";
391     test_sorting_le(n, k, enc);
392     test_sorting_eq(n, k, enc);
393     test_sorting_ge(n, k, enc);
394 }
395 
naive_at_most1(expr_ref_vector const & xs)396 expr_ref naive_at_most1(expr_ref_vector const& xs) {
397     ast_manager& m = xs.get_manager();
398     expr_ref_vector clauses(m);
399     for (unsigned i = 0; i < xs.size(); ++i) {
400         for (unsigned j = i + 1; j < xs.size(); ++j) {
401             clauses.push_back(m.mk_not(m.mk_and(xs[i], xs[j])));
402         }
403     }
404     return mk_and(clauses);
405 }
406 
test_at_most_1(unsigned n,bool full,sorting_network_encoding enc)407 void test_at_most_1(unsigned n, bool full, sorting_network_encoding enc) {
408     ast_manager m;
409     reg_decl_plugins(m);
410     expr_ref_vector in(m), out(m);
411     for (unsigned i = 0; i < n; ++i) {
412         in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
413     }
414 
415     ast_ext2 ext(m);
416     psort_nw<ast_ext2> sn(ext);
417     sn.cfg().m_encoding = enc;
418     expr_ref result1(m), result2(m);
419     result1 = sn.le(full, 1, in.size(), in.c_ptr());
420     result2 = naive_at_most1(in);
421 
422 
423     std::cout << "clauses: " << ext.m_clauses << "\n-----\n";
424     //std::cout << "encoded: " << result1 << "\n";
425     //std::cout << "naive: " << result2 << "\n";
426 
427     smt_params fp;
428     smt::kernel solver(m, fp);
429     for (expr* cls : ext.m_clauses) {
430         solver.assert_expr(cls);
431     }
432     if (full) {
433         solver.push();
434         solver.assert_expr(m.mk_not(m.mk_eq(result1, result2)));
435 
436         std::cout << result1 << "\n";
437         lbool res = solver.check();
438         if (res == l_true) {
439             model_ref model;
440             solver.get_model(model);
441             model_smt2_pp(std::cout, m, *model, 0);
442         }
443 
444         VERIFY(l_false == res);
445 
446         solver.pop(1);
447     }
448 
449     if (n >= 9) return;
450     if (n <= 1) return;
451     for (unsigned i = 0; i < static_cast<unsigned>(1 << n); ++i) {
452         std::cout << "checking n: " << n << " bits: ";
453         for (unsigned j = 0; j < n; ++j) {
454             bool is_true = (i & (1 << j)) != 0;
455             std::cout << (is_true?"1":"0");
456         }
457         std::cout << "\n";
458         solver.push();
459         unsigned k = 0;
460         for (unsigned j = 0; j < n; ++j) {
461             bool is_true = (i & (1 << j)) != 0;
462             expr_ref atom(m);
463             atom = is_true ? in[j].get() : m.mk_not(in[j].get());
464             solver.assert_expr(atom);
465             std::cout << atom << "\n";
466             if (is_true) ++k;
467         }
468         if (k > 1) {
469             solver.assert_expr(result1);
470         }
471         else if (!full) {
472             solver.pop(1);
473             continue;
474         }
475         else {
476             solver.assert_expr(m.mk_not(result1));
477         }
478         VERIFY(l_false == solver.check());
479         solver.pop(1);
480     }
481 }
482 
483 
test_at_most1(sorting_network_encoding enc)484 static void test_at_most1(sorting_network_encoding enc) {
485     ast_manager m;
486     reg_decl_plugins(m);
487     expr_ref_vector in(m), out(m);
488     for (unsigned i = 0; i < 5; ++i) {
489         in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
490     }
491     in[4] = in[3].get();
492 
493     ast_ext2 ext(m);
494     psort_nw<ast_ext2> sn(ext);
495     sn.cfg().m_encoding = enc;
496     expr_ref result(m);
497     result = sn.le(true, 1, in.size(), in.c_ptr());
498     //std::cout << result << "\n";
499     //std::cout << ext.m_clauses << "\n";
500 }
501 
test_sorting5(sorting_network_encoding enc)502 static void test_sorting5(sorting_network_encoding enc) {
503     test_sorting_eq(11,7, enc);
504     for (unsigned n = 3; n < 20; n += 2) {
505         for (unsigned k = 1; k < n; ++k) {
506             test_sorting5(n, k, enc);
507         }
508     }
509 }
510 
tst_sorting_network(sorting_network_encoding enc)511 static void tst_sorting_network(sorting_network_encoding enc) {
512     for (unsigned i = 1; i < 17; ++i) {
513         test_at_most_1(i, true, enc);
514         test_at_most_1(i, false, enc);
515     }
516     for (unsigned n = 2; n < 20; ++n) {
517         std::cout << "verify eq-1 out of " << n << "\n";
518         test_sorting_eq(n, 1, enc);
519         test_eq1(n, enc);
520     }
521     test_at_most1(enc);
522     test_sorting5(enc);
523 }
524 
test_pb(unsigned max_w,unsigned sz,unsigned_vector & ws)525 static void test_pb(unsigned max_w, unsigned sz, unsigned_vector& ws) {
526     if (ws.empty()) {
527         for (unsigned w = 1; w <= max_w; ++w) {
528             ws.push_back(w);
529             test_pb(max_w, sz, ws);
530             ws.pop_back();
531         }
532     }
533     else if (ws.size() < sz) {
534         for (unsigned w = ws.back(); w <= max_w; ++w) {
535             ws.push_back(w);
536             test_pb(max_w, sz, ws);
537             ws.pop_back();
538         }
539     }
540     else {
541         SASSERT(ws.size() == sz);
542         ast_manager m;
543         reg_decl_plugins(m);
544         expr_ref_vector xs(m), nxs(m);
545         expr_ref ge(m), eq(m);
546         smt_params fp;
547         smt::kernel solver(m, fp);
548         for (unsigned i = 0; i < sz; ++i) {
549             xs.push_back(m.mk_const(symbol(i), m.mk_bool_sort()));
550             nxs.push_back(m.mk_not(xs.back()));
551         }
552         std::cout << ws << " " << "\n";
553         for (unsigned k = max_w + 1; k < ws.size()*max_w; ++k) {
554 
555             ast_ext2 ext(m);
556             psort_nw<ast_ext2> sn(ext);
557             solver.push();
558             //std::cout << "bound: " << k << "\n";
559             //std::cout << ws << " " << xs << "\n";
560             ge = sn.ge(k, sz, ws.c_ptr(), xs.c_ptr());
561             //std::cout << "ge: " << ge << "\n";
562             for (expr* cls : ext.m_clauses) {
563                 solver.assert_expr(cls);
564             }
565             // solver.display(std::cout);
566             // for each truth assignment to xs, validate
567             // that circuit computes the right value for ge
568             for (unsigned i = 0; i < (1ul << sz); ++i) {
569                 solver.push();
570                 unsigned sum = 0;
571                 for (unsigned j = 0; j < sz; ++j) {
572                     if (0 == ((1 << j) & i)) {
573                         solver.assert_expr(xs.get(j));
574                         sum += ws[j];
575                     }
576                     else {
577                         solver.assert_expr(nxs.get(j));
578                     }
579                 }
580                 // std::cout << "bound: " << k << "\n";
581                 // std::cout << ws << " " << xs << "\n";
582                 // std::cout << sum << " >= " << k << " : " << (sum >= k) << " ";
583                 solver.push();
584                 if (sum < k) {
585                     solver.assert_expr(m.mk_not(ge));
586                 }
587                 else {
588                     solver.assert_expr(ge);
589                 }
590                 // solver.display(std::cout) << "\n";
591                 VERIFY(solver.check() == l_true);
592                 solver.pop(1);
593 
594                 solver.push();
595                 if (sum >= k) {
596                     solver.assert_expr(m.mk_not(ge));
597                 }
598                 else {
599                     solver.assert_expr(ge);
600                 }
601                 // solver.display(std::cout) << "\n";
602                 VERIFY(l_false == solver.check());
603                 solver.pop(1);
604                 solver.pop(1);
605             }
606             solver.pop(1);
607 
608             solver.push();
609             eq = sn.eq(k, sz, ws.c_ptr(), xs.c_ptr());
610 
611             for (expr* cls : ext.m_clauses) {
612                 solver.assert_expr(cls);
613             }
614             // for each truth assignment to xs, validate
615             // that circuit computes the right value for ge
616             for (unsigned i = 0; i < (1ul << sz); ++i) {
617                 solver.push();
618                 unsigned sum = 0;
619                 for (unsigned j = 0; j < sz; ++j) {
620                     if (0 == ((1 << j) & i)) {
621                         solver.assert_expr(xs.get(j));
622                         sum += ws[j];
623                     }
624                     else {
625                         solver.assert_expr(nxs.get(j));
626                     }
627                 }
628                 solver.push();
629                 if (sum != k) {
630                     solver.assert_expr(m.mk_not(eq));
631                 }
632                 else {
633                     solver.assert_expr(eq);
634                 }
635                 // solver.display(std::cout) << "\n";
636                 VERIFY(solver.check() == l_true);
637                 solver.pop(1);
638 
639                 solver.push();
640                 if (sum == k) {
641                     solver.assert_expr(m.mk_not(eq));
642                 }
643                 else {
644                     solver.assert_expr(eq);
645                 }
646                 VERIFY(l_false == solver.check());
647                 solver.pop(1);
648                 solver.pop(1);
649             }
650 
651             solver.pop(1);
652         }
653     }
654 }
655 
tst_pb()656 static void tst_pb() {
657     unsigned_vector ws;
658     test_pb(3, 3, ws);
659 }
660 
tst_sorting_network()661 void tst_sorting_network() {
662     tst_pb();
663     tst_sorting_network(sorting_network_encoding::unate_at_most);
664     tst_sorting_network(sorting_network_encoding::circuit_at_most);
665     tst_sorting_network(sorting_network_encoding::ordered_at_most);
666     tst_sorting_network(sorting_network_encoding::grouped_at_most);
667     tst_sorting_network(sorting_network_encoding::bimander_at_most);
668     test_sorting1();
669     test_sorting2();
670     test_sorting3();
671     test_sorting4();
672 }
673 
674