1 /*++
2 Copyright (c) 2020 Microsoft Corporation
3 
4 --*/
5 
6 #include "util/util.h"
7 #include "util/timer.h"
8 #include "ast/euf/euf_egraph.h"
9 #include "ast/reg_decl_plugins.h"
10 #include "ast/ast_pp.h"
11 #include "ast/arith_decl_plugin.h"
12 
mk_const(ast_manager & m,char const * name,sort * s)13 static expr_ref mk_const(ast_manager& m, char const* name, sort* s) {
14     return expr_ref(m.mk_const(symbol(name), s), m);
15 }
16 
mk_app(char const * name,expr_ref const & arg,sort * s)17 static expr_ref mk_app(char const* name, expr_ref const& arg, sort* s) {
18     ast_manager& m = arg.m();
19     func_decl_ref f(m.mk_func_decl(symbol(name), arg->get_sort(), s), m);
20     return expr_ref(m.mk_app(f, arg.get()), m);
21 }
22 
#if 0
// Binary counterpart of mk_app: declares `name` with domain
// sort(arg1) x sort(arg2) and range `s`, and returns its application to
// (arg1, arg2). Not used by any test in this file, so it stays disabled.
// Sorts are read via `expr->get_sort()`, matching the live unary mk_app,
// so the helper compiles as-is if it is ever re-enabled.
static expr_ref mk_app(char const* name, expr_ref const& arg1, expr_ref const& arg2, sort* s) {
    ast_manager& m = arg1.m();
    func_decl_ref f(m.mk_func_decl(symbol(name), arg1->get_sort(), arg2->get_sort(), s), m);
    return expr_ref(m.mk_app(f, arg1.get(), arg2.get()), m);
}
#endif
30 
test1()31 static void test1() {
32     ast_manager m;
33     reg_decl_plugins(m);
34     euf::egraph g(m);
35     sort_ref S(m.mk_uninterpreted_sort(symbol("S")), m);
36     expr_ref a = mk_const(m, "a", S);
37     expr_ref fa = mk_app("f", a, S);
38     expr_ref ffa = mk_app("f", fa, S);
39     expr_ref fffa = mk_app("f", ffa, S);
40     euf::enode* na = g.mk(a, 0, 0, nullptr);
41     euf::enode* nfa = g.mk(fa, 0, 1, &na);
42     euf::enode* nffa = g.mk(ffa, 0, 1, &nfa);
43     euf::enode* nfffa = g.mk(fffa, 0, 1, &nffa);
44     std::cout << g << "\n";
45     g.merge(na, nffa, nullptr);
46     g.merge(na, nfffa, nullptr);
47     std::cout << g << "\n";
48     g.propagate();
49     std::cout << g << "\n";
50 }
51 
test2()52 static void test2() {
53     ast_manager m;
54     reg_decl_plugins(m);
55     euf::egraph g(m);
56     sort_ref S(m.mk_uninterpreted_sort(symbol("S")), m);
57     unsigned d = 100, w = 100;
58     euf::enode_vector nodes, top_nodes;
59     expr_ref_vector pinned(m);
60     for (unsigned i = 0; i < w; ++i) {
61         std::string xn("x");
62         xn += std::to_string(i);
63         expr_ref x = mk_const(m, xn.c_str(), S);
64         euf::enode* n = g.mk(x, 0, 0, nullptr);
65         nodes.push_back(n);
66         for (unsigned j = 0; j < d; ++j) {
67             std::string f("f");
68             f += std::to_string(j);
69             x = mk_app(f.c_str(), x, S);
70             n = g.mk(x, 0, 1, &n);
71         }
72         top_nodes.push_back(n);
73         pinned.push_back(x);
74     }
75     std::cout << "merge\n";
76     timer t;
77     for (euf::enode* n : nodes) {
78         g.merge(n, nodes[0], nullptr);
79     }
80     std::cout << "merged " << t.get_seconds() << "\n";
81     g.propagate();
82     std::cout << "propagated " << t.get_seconds() << "\n";
83     for (euf::enode* n : top_nodes) {
84         VERIFY(n->get_root() == top_nodes[0]->get_root());
85     }
86 }
87 
88 
89 
test3()90 static void test3() {
91     ast_manager m;
92     reg_decl_plugins(m);
93     arith_util a(m);
94     euf::egraph g(m);
95     sort_ref I(a.mk_int(), m);
96     expr_ref zero(a.mk_int(0), m);
97     expr_ref one(a.mk_int(1), m);
98     expr_ref x = mk_const(m, "x", I);
99     expr_ref y = mk_const(m, "y", I);
100     expr_ref z = mk_const(m, "z", I);
101     expr_ref u = mk_const(m, "u", I);
102     expr_ref fx = mk_app("f", x, I);
103     expr_ref fy = mk_app("f", y, I);
104     euf::enode* nx = g.mk(x, 0, 0, nullptr);
105     euf::enode* ny = g.mk(y, 0, 0, nullptr);
106     euf::enode* nz = g.mk(z, 0, 0, nullptr);
107     euf::enode* nu = g.mk(u, 0, 0, nullptr);
108     euf::enode* n0 = g.mk(zero, 0, 0, nullptr);
109     euf::enode* n1 = g.mk(one, 0, 0, nullptr);
110     euf::enode* nfx = g.mk(fx, 0, 1, &nx);
111     euf::enode* nfy = g.mk(fy, 0, 1, &ny);
112     int justifications[5] = { 1, 2, 3, 4, 5 };
113     g.merge(nfx, n0, justifications + 0);
114     g.merge(nfy, n1, justifications + 1);
115     g.merge(nx,  nz, justifications + 2);
116     g.merge(nx,  nu, justifications + 3);
117     g.propagate();
118     SASSERT(!g.inconsistent());
119     g.merge(nx, ny, justifications + 4);
120     std::cout << g << "\n";
121     g.propagate();
122     std::cout << g << "\n";
123     SASSERT(g.inconsistent());
124     ptr_vector<int> js;
125     g.begin_explain();
126     g.explain<int>(js);
127     g.end_explain();
128     for (int* j : js)
129         std::cout << "conflict: " << *j << "\n";
130 }
131 
tst_egraph()132 void tst_egraph() {
133     enable_trace("euf");
134     test3();
135     test1();
136     test2();
137 }
138