1
2 /*++
3 Copyright (c) 2015 Microsoft Corporation
4
5 --*/
6
7 #include "util/trace.h"
8 #include "util/vector.h"
9 #include "util/sorting_network.h"
10 #include "ast/ast.h"
11 #include "ast/ast_pp.h"
12 #include "ast/reg_decl_plugins.h"
13 #include "ast/ast_util.h"
14 #include "model/model_smt2_pp.h"
15 #include "smt/smt_kernel.h"
16 #include "smt/params/smt_params.h"
17
18 struct ast_ext {
19 ast_manager& m;
ast_extast_ext20 ast_ext(ast_manager& m):m(m) {}
21 typedef expr* T;
22 typedef expr_ref_vector vector;
mk_iteast_ext23 T mk_ite(T a, T b, T c) {
24 return m.mk_ite(a, b, c);
25 }
mk_least_ext26 T mk_le(T a, T b) {
27 if (m.is_bool(a)) {
28 return m.mk_implies(a, b);
29 }
30 UNREACHABLE();
31 return nullptr;
32 }
mk_defaultast_ext33 T mk_default() {
34 return m.mk_false();
35 }
36 };
37
38
39
40 struct unsigned_ext {
unsigned_extunsigned_ext41 unsigned_ext() {}
42 typedef unsigned T;
43 typedef svector<unsigned> vector;
mk_iteunsigned_ext44 T mk_ite(T a, T b, T c) {
45 return (a==1)?b:c;
46 }
mk_leunsigned_ext47 T mk_le(T a, T b) {
48 return (a <= b)?1:0;
49 }
mk_defaultunsigned_ext50 T mk_default() {
51 return 0;
52 }
53 };
54
55
is_sorted(svector<unsigned> const & v)56 static void is_sorted(svector<unsigned> const& v) {
57 for (unsigned i = 0; i + 1 < v.size(); ++i) {
58 ENSURE(v[i] <= v[i+1]);
59 }
60 }
61
test_sorting1()62 static void test_sorting1() {
63 svector<unsigned> in, out;
64 unsigned_ext uext;
65 sorting_network<unsigned_ext> sn(uext);
66
67 in.push_back(0);
68 in.push_back(1);
69 in.push_back(0);
70 in.push_back(1);
71 in.push_back(1);
72 in.push_back(0);
73
74 sn(in, out);
75
76 is_sorted(out);
77 for (unsigned i = 0; i < out.size(); ++i) {
78 std::cout << out[i];
79 }
80 std::cout << "\n";
81 }
82
test_sorting2()83 static void test_sorting2() {
84 svector<unsigned> in, out;
85 unsigned_ext uext;
86 sorting_network<unsigned_ext> sn(uext);
87
88 in.push_back(0);
89 in.push_back(1);
90 in.push_back(2);
91 in.push_back(1);
92 in.push_back(1);
93 in.push_back(3);
94
95 sn(in, out);
96
97 is_sorted(out);
98
99 for (unsigned i = 0; i < out.size(); ++i) {
100 std::cout << out[i];
101 }
102 std::cout << "\n";
103 }
104
test_sorting4_r(unsigned i,svector<unsigned> & in)105 static void test_sorting4_r(unsigned i, svector<unsigned>& in) {
106 if (i == in.size()) {
107 svector<unsigned> out;
108 unsigned_ext uext;
109 sorting_network<unsigned_ext> sn(uext);
110 sn(in, out);
111 is_sorted(out);
112 std::cout << "sorted\n";
113 }
114 else {
115 in[i] = 0;
116 test_sorting4_r(i+1, in);
117 in[i] = 1;
118 test_sorting4_r(i+1, in);
119 }
120 }
121
test_sorting4()122 static void test_sorting4() {
123 svector<unsigned> in;
124 in.resize(5);
125 test_sorting4_r(0, in);
126 in.resize(8);
127 test_sorting4_r(0, in);
128 }
129
test_sorting3()130 void test_sorting3() {
131 ast_manager m;
132 reg_decl_plugins(m);
133 expr_ref_vector in(m), out(m);
134 for (unsigned i = 0; i < 7; ++i) {
135 in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
136 }
137 for (expr* e : in) std::cout << mk_pp(e, m) << "\n";
138 ast_ext aext(m);
139 sorting_network<ast_ext> sn(aext);
140 sn(in, out);
141 std::cout << "size: " << out.size() << "\n";
142 for (expr* e : out) std::cout << mk_pp(e, m) << "\n";
143 }
144
145
146 struct ast_ext2 {
147 ast_manager& m;
148 expr_ref_vector m_clauses;
149 expr_ref_vector m_trail;
ast_ext2ast_ext2150 ast_ext2(ast_manager& m):m(m), m_clauses(m), m_trail(m) {}
151 typedef expr* pliteral;
152 typedef ptr_vector<expr> pliteral_vector;
153
trailast_ext2154 expr* trail(expr* e) {
155 m_trail.push_back(e);
156 return e;
157 }
158
mk_falseast_ext2159 pliteral mk_false() { return m.mk_false(); }
mk_trueast_ext2160 pliteral mk_true() { return m.mk_true(); }
mk_maxast_ext2161 pliteral mk_max(unsigned n, pliteral const* lits) {
162 return trail(m.mk_or(n, lits));
163 }
mk_minast_ext2164 pliteral mk_min(unsigned n, pliteral const* lits) {
165 return trail(m.mk_and(n, lits));
166 }
mk_notast_ext2167 pliteral mk_not(pliteral a) { if (m.is_not(a,a)) return a;
168 return trail(m.mk_not(a));
169 }
ppast_ext2170 std::ostream& pp(std::ostream& out, pliteral lit) {
171 return out << mk_pp(lit, m);
172 }
freshast_ext2173 pliteral fresh(char const* n) {
174 return trail(m.mk_fresh_const(n, m.mk_bool_sort()));
175 }
mk_clauseast_ext2176 void mk_clause(unsigned n, pliteral const* lits) {
177 m_clauses.push_back(mk_or(m, n, lits));
178 }
179 };
180
test_eq1(unsigned n,sorting_network_encoding enc)181 static void test_eq1(unsigned n, sorting_network_encoding enc) {
182 //std::cout << "test eq1 " << n << " for encoding: " << enc << "\n";
183 ast_manager m;
184 reg_decl_plugins(m);
185 ast_ext2 ext(m);
186 expr_ref_vector in(m), out(m);
187 for (unsigned i = 0; i < n; ++i) {
188 in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
189 }
190 smt_params fp;
191 smt::kernel solver(m, fp);
192 psort_nw<ast_ext2> sn(ext);
193 sn.cfg().m_encoding = enc;
194
195 expr_ref result1(m), result2(m);
196
197 // equality:
198 solver.push();
199 result1 = sn.eq(true, 1, in.size(), in.c_ptr());
200 for (expr* cls : ext.m_clauses) {
201 solver.assert_expr(cls);
202 }
203 expr_ref_vector ors(m);
204 for (unsigned i = 0; i < n; ++i) {
205 expr_ref_vector ands(m);
206 for (unsigned j = 0; j < n; ++j) {
207 ands.push_back(j == i ? in[j].get() : m.mk_not(in[j].get()));
208 }
209 ors.push_back(mk_and(ands));
210 }
211 result2 = mk_or(ors);
212 solver.assert_expr(m.mk_not(m.mk_eq(result1, result2)));
213 //std::cout << ext.m_clauses << "\n";
214 //std::cout << result1 << "\n";
215 //std::cout << result2 << "\n";
216 lbool res = solver.check();
217 if (res == l_true) {
218 model_ref model;
219 solver.get_model(model);
220 model_smt2_pp(std::cout, m, *model, 0);
221 TRACE("pb", model_smt2_pp(tout, m, *model, 0););
222 }
223 ENSURE(l_false == res);
224 ext.m_clauses.reset();
225 }
226
test_sorting_eq(unsigned n,unsigned k,sorting_network_encoding enc)227 static void test_sorting_eq(unsigned n, unsigned k, sorting_network_encoding enc) {
228 ENSURE(k < n);
229 ast_manager m;
230 reg_decl_plugins(m);
231 ast_ext2 ext(m);
232 expr_ref_vector in(m), out(m);
233 for (unsigned i = 0; i < n; ++i) {
234 in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
235 }
236 smt_params fp;
237 smt::kernel solver(m, fp);
238 psort_nw<ast_ext2> sn(ext);
239 sn.cfg().m_encoding = enc;
240 expr_ref result(m);
241
242 // equality:
243 std::cout << "eq " << k << " out of " << n << " for encoding " << enc << "\n";
244 solver.push();
245 result = sn.eq(false, k, in.size(), in.c_ptr());
246 solver.assert_expr(result);
247 for (expr* cl : ext.m_clauses) {
248 solver.assert_expr(cl);
249 }
250 lbool res = solver.check();
251 if (res != l_true) {
252 std::cout << res << "\n";
253 solver.display(std::cout);
254 }
255 ENSURE(res == l_true);
256
257 solver.push();
258 for (unsigned i = 0; i < k; ++i) {
259 solver.assert_expr(in[i].get());
260 }
261 res = solver.check();
262 if (res != l_true) {
263 std::cout << result << "\n" << ext.m_clauses << "\n";
264 }
265 ENSURE(res == l_true);
266 solver.assert_expr(in[k].get());
267 res = solver.check();
268 if (res == l_true) {
269 TRACE("pb",
270 unsigned sz = solver.size();
271 for (unsigned i = 0; i < sz; ++i) {
272 tout << mk_pp(solver.get_formula(i), m) << "\n";
273 });
274 model_ref model;
275 solver.get_model(model);
276 model_smt2_pp(std::cout, m, *model, 0);
277 TRACE("pb", model_smt2_pp(tout, m, *model, 0););
278 }
279 ENSURE(res == l_false);
280 solver.pop(1);
281 ext.m_clauses.reset();
282 }
283
test_sorting_le(unsigned n,unsigned k,sorting_network_encoding enc)284 static void test_sorting_le(unsigned n, unsigned k, sorting_network_encoding enc) {
285 ast_manager m;
286 reg_decl_plugins(m);
287 ast_ext2 ext(m);
288 expr_ref_vector in(m), out(m);
289 for (unsigned i = 0; i < n; ++i) {
290 in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
291 }
292 smt_params fp;
293 smt::kernel solver(m, fp);
294 psort_nw<ast_ext2> sn(ext);
295 sn.cfg().m_encoding = enc;
296 expr_ref result(m);
297 // B <= k
298 std::cout << "le " << k << "\n";
299 solver.push();
300 result = sn.le(false, k, in.size(), in.c_ptr());
301 solver.assert_expr(result);
302 for (expr* cls : ext.m_clauses) {
303 solver.assert_expr(cls);
304 }
305 lbool res = solver.check();
306 if (res != l_true) {
307 std::cout << res << "\n";
308 solver.display(std::cout);
309 std::cout << "clauses: " << ext.m_clauses << "\n";
310 std::cout << "result: " << result << "\n";
311 }
312 ENSURE(res == l_true);
313
314 for (unsigned i = 0; i < k; ++i) {
315 solver.assert_expr(in[i].get());
316 }
317 res = solver.check();
318 if (res != l_true) {
319 std::cout << res << "\n";
320 solver.display(std::cout);
321 }
322 ENSURE(res == l_true);
323 solver.assert_expr(in[k].get());
324 res = solver.check();
325 if (res == l_true) {
326 TRACE("pb",
327 unsigned sz = solver.size();
328 for (unsigned i = 0; i < sz; ++i) {
329 tout << mk_pp(solver.get_formula(i), m) << "\n";
330 });
331 model_ref model;
332 solver.get_model(model);
333 model_smt2_pp(std::cout, m, *model, 0);
334 TRACE("pb", model_smt2_pp(tout, m, *model, 0););
335 }
336 ENSURE(res == l_false);
337 solver.pop(1);
338 ext.m_clauses.reset();
339 }
340
341
test_sorting_ge(unsigned n,unsigned k,sorting_network_encoding enc)342 void test_sorting_ge(unsigned n, unsigned k, sorting_network_encoding enc) {
343 ast_manager m;
344 reg_decl_plugins(m);
345 ast_ext2 ext(m);
346 expr_ref_vector in(m), out(m);
347 for (unsigned i = 0; i < n; ++i) {
348 in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
349 }
350 smt_params fp;
351 smt::kernel solver(m, fp);
352 psort_nw<ast_ext2> sn(ext);
353 sn.cfg().m_encoding = enc;
354 expr_ref result(m);
355 // k <= B
356 std::cout << "ge " << k << "\n";
357 solver.push();
358 result = sn.ge(false, k, in.size(), in.c_ptr());
359 solver.assert_expr(result);
360 for (expr* cls : ext.m_clauses) {
361 solver.assert_expr(cls);
362 }
363 lbool res = solver.check();
364 ENSURE(res == l_true);
365
366 solver.push();
367 for (unsigned i = 0; i < n - k; ++i) {
368 solver.assert_expr(m.mk_not(in[i].get()));
369 }
370 res = solver.check();
371 ENSURE(res == l_true);
372 solver.assert_expr(m.mk_not(in[n - k].get()));
373 res = solver.check();
374 if (res == l_true) {
375 TRACE("pb",
376 unsigned sz = solver.size();
377 for (unsigned i = 0; i < sz; ++i) {
378 tout << mk_pp(solver.get_formula(i), m) << "\n";
379 });
380 model_ref model;
381 solver.get_model(model);
382 model_smt2_pp(std::cout, m, *model, 0);
383 TRACE("pb", model_smt2_pp(tout, m, *model, 0););
384 }
385 ENSURE(res == l_false);
386 solver.pop(1);
387 }
388
test_sorting5(unsigned n,unsigned k,sorting_network_encoding enc)389 void test_sorting5(unsigned n, unsigned k, sorting_network_encoding enc) {
390 std::cout << "n: " << n << " k: " << k << "\n";
391 test_sorting_le(n, k, enc);
392 test_sorting_eq(n, k, enc);
393 test_sorting_ge(n, k, enc);
394 }
395
naive_at_most1(expr_ref_vector const & xs)396 expr_ref naive_at_most1(expr_ref_vector const& xs) {
397 ast_manager& m = xs.get_manager();
398 expr_ref_vector clauses(m);
399 for (unsigned i = 0; i < xs.size(); ++i) {
400 for (unsigned j = i + 1; j < xs.size(); ++j) {
401 clauses.push_back(m.mk_not(m.mk_and(xs[i], xs[j])));
402 }
403 }
404 return mk_and(clauses);
405 }
406
test_at_most_1(unsigned n,bool full,sorting_network_encoding enc)407 void test_at_most_1(unsigned n, bool full, sorting_network_encoding enc) {
408 ast_manager m;
409 reg_decl_plugins(m);
410 expr_ref_vector in(m), out(m);
411 for (unsigned i = 0; i < n; ++i) {
412 in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
413 }
414
415 ast_ext2 ext(m);
416 psort_nw<ast_ext2> sn(ext);
417 sn.cfg().m_encoding = enc;
418 expr_ref result1(m), result2(m);
419 result1 = sn.le(full, 1, in.size(), in.c_ptr());
420 result2 = naive_at_most1(in);
421
422
423 std::cout << "clauses: " << ext.m_clauses << "\n-----\n";
424 //std::cout << "encoded: " << result1 << "\n";
425 //std::cout << "naive: " << result2 << "\n";
426
427 smt_params fp;
428 smt::kernel solver(m, fp);
429 for (expr* cls : ext.m_clauses) {
430 solver.assert_expr(cls);
431 }
432 if (full) {
433 solver.push();
434 solver.assert_expr(m.mk_not(m.mk_eq(result1, result2)));
435
436 std::cout << result1 << "\n";
437 lbool res = solver.check();
438 if (res == l_true) {
439 model_ref model;
440 solver.get_model(model);
441 model_smt2_pp(std::cout, m, *model, 0);
442 }
443
444 VERIFY(l_false == res);
445
446 solver.pop(1);
447 }
448
449 if (n >= 9) return;
450 if (n <= 1) return;
451 for (unsigned i = 0; i < static_cast<unsigned>(1 << n); ++i) {
452 std::cout << "checking n: " << n << " bits: ";
453 for (unsigned j = 0; j < n; ++j) {
454 bool is_true = (i & (1 << j)) != 0;
455 std::cout << (is_true?"1":"0");
456 }
457 std::cout << "\n";
458 solver.push();
459 unsigned k = 0;
460 for (unsigned j = 0; j < n; ++j) {
461 bool is_true = (i & (1 << j)) != 0;
462 expr_ref atom(m);
463 atom = is_true ? in[j].get() : m.mk_not(in[j].get());
464 solver.assert_expr(atom);
465 std::cout << atom << "\n";
466 if (is_true) ++k;
467 }
468 if (k > 1) {
469 solver.assert_expr(result1);
470 }
471 else if (!full) {
472 solver.pop(1);
473 continue;
474 }
475 else {
476 solver.assert_expr(m.mk_not(result1));
477 }
478 VERIFY(l_false == solver.check());
479 solver.pop(1);
480 }
481 }
482
483
test_at_most1(sorting_network_encoding enc)484 static void test_at_most1(sorting_network_encoding enc) {
485 ast_manager m;
486 reg_decl_plugins(m);
487 expr_ref_vector in(m), out(m);
488 for (unsigned i = 0; i < 5; ++i) {
489 in.push_back(m.mk_fresh_const("a",m.mk_bool_sort()));
490 }
491 in[4] = in[3].get();
492
493 ast_ext2 ext(m);
494 psort_nw<ast_ext2> sn(ext);
495 sn.cfg().m_encoding = enc;
496 expr_ref result(m);
497 result = sn.le(true, 1, in.size(), in.c_ptr());
498 //std::cout << result << "\n";
499 //std::cout << ext.m_clauses << "\n";
500 }
501
test_sorting5(sorting_network_encoding enc)502 static void test_sorting5(sorting_network_encoding enc) {
503 test_sorting_eq(11,7, enc);
504 for (unsigned n = 3; n < 20; n += 2) {
505 for (unsigned k = 1; k < n; ++k) {
506 test_sorting5(n, k, enc);
507 }
508 }
509 }
510
tst_sorting_network(sorting_network_encoding enc)511 static void tst_sorting_network(sorting_network_encoding enc) {
512 for (unsigned i = 1; i < 17; ++i) {
513 test_at_most_1(i, true, enc);
514 test_at_most_1(i, false, enc);
515 }
516 for (unsigned n = 2; n < 20; ++n) {
517 std::cout << "verify eq-1 out of " << n << "\n";
518 test_sorting_eq(n, 1, enc);
519 test_eq1(n, enc);
520 }
521 test_at_most1(enc);
522 test_sorting5(enc);
523 }
524
test_pb(unsigned max_w,unsigned sz,unsigned_vector & ws)525 static void test_pb(unsigned max_w, unsigned sz, unsigned_vector& ws) {
526 if (ws.empty()) {
527 for (unsigned w = 1; w <= max_w; ++w) {
528 ws.push_back(w);
529 test_pb(max_w, sz, ws);
530 ws.pop_back();
531 }
532 }
533 else if (ws.size() < sz) {
534 for (unsigned w = ws.back(); w <= max_w; ++w) {
535 ws.push_back(w);
536 test_pb(max_w, sz, ws);
537 ws.pop_back();
538 }
539 }
540 else {
541 SASSERT(ws.size() == sz);
542 ast_manager m;
543 reg_decl_plugins(m);
544 expr_ref_vector xs(m), nxs(m);
545 expr_ref ge(m), eq(m);
546 smt_params fp;
547 smt::kernel solver(m, fp);
548 for (unsigned i = 0; i < sz; ++i) {
549 xs.push_back(m.mk_const(symbol(i), m.mk_bool_sort()));
550 nxs.push_back(m.mk_not(xs.back()));
551 }
552 std::cout << ws << " " << "\n";
553 for (unsigned k = max_w + 1; k < ws.size()*max_w; ++k) {
554
555 ast_ext2 ext(m);
556 psort_nw<ast_ext2> sn(ext);
557 solver.push();
558 //std::cout << "bound: " << k << "\n";
559 //std::cout << ws << " " << xs << "\n";
560 ge = sn.ge(k, sz, ws.c_ptr(), xs.c_ptr());
561 //std::cout << "ge: " << ge << "\n";
562 for (expr* cls : ext.m_clauses) {
563 solver.assert_expr(cls);
564 }
565 // solver.display(std::cout);
566 // for each truth assignment to xs, validate
567 // that circuit computes the right value for ge
568 for (unsigned i = 0; i < (1ul << sz); ++i) {
569 solver.push();
570 unsigned sum = 0;
571 for (unsigned j = 0; j < sz; ++j) {
572 if (0 == ((1 << j) & i)) {
573 solver.assert_expr(xs.get(j));
574 sum += ws[j];
575 }
576 else {
577 solver.assert_expr(nxs.get(j));
578 }
579 }
580 // std::cout << "bound: " << k << "\n";
581 // std::cout << ws << " " << xs << "\n";
582 // std::cout << sum << " >= " << k << " : " << (sum >= k) << " ";
583 solver.push();
584 if (sum < k) {
585 solver.assert_expr(m.mk_not(ge));
586 }
587 else {
588 solver.assert_expr(ge);
589 }
590 // solver.display(std::cout) << "\n";
591 VERIFY(solver.check() == l_true);
592 solver.pop(1);
593
594 solver.push();
595 if (sum >= k) {
596 solver.assert_expr(m.mk_not(ge));
597 }
598 else {
599 solver.assert_expr(ge);
600 }
601 // solver.display(std::cout) << "\n";
602 VERIFY(l_false == solver.check());
603 solver.pop(1);
604 solver.pop(1);
605 }
606 solver.pop(1);
607
608 solver.push();
609 eq = sn.eq(k, sz, ws.c_ptr(), xs.c_ptr());
610
611 for (expr* cls : ext.m_clauses) {
612 solver.assert_expr(cls);
613 }
614 // for each truth assignment to xs, validate
615 // that circuit computes the right value for ge
616 for (unsigned i = 0; i < (1ul << sz); ++i) {
617 solver.push();
618 unsigned sum = 0;
619 for (unsigned j = 0; j < sz; ++j) {
620 if (0 == ((1 << j) & i)) {
621 solver.assert_expr(xs.get(j));
622 sum += ws[j];
623 }
624 else {
625 solver.assert_expr(nxs.get(j));
626 }
627 }
628 solver.push();
629 if (sum != k) {
630 solver.assert_expr(m.mk_not(eq));
631 }
632 else {
633 solver.assert_expr(eq);
634 }
635 // solver.display(std::cout) << "\n";
636 VERIFY(solver.check() == l_true);
637 solver.pop(1);
638
639 solver.push();
640 if (sum == k) {
641 solver.assert_expr(m.mk_not(eq));
642 }
643 else {
644 solver.assert_expr(eq);
645 }
646 VERIFY(l_false == solver.check());
647 solver.pop(1);
648 solver.pop(1);
649 }
650
651 solver.pop(1);
652 }
653 }
654 }
655
tst_pb()656 static void tst_pb() {
657 unsigned_vector ws;
658 test_pb(3, 3, ws);
659 }
660
tst_sorting_network()661 void tst_sorting_network() {
662 tst_pb();
663 tst_sorting_network(sorting_network_encoding::unate_at_most);
664 tst_sorting_network(sorting_network_encoding::circuit_at_most);
665 tst_sorting_network(sorting_network_encoding::ordered_at_most);
666 tst_sorting_network(sorting_network_encoding::grouped_at_most);
667 tst_sorting_network(sorting_network_encoding::bimander_at_most);
668 test_sorting1();
669 test_sorting2();
670 test_sorting3();
671 test_sorting4();
672 }
673
674