1 /*++
2 Copyright (c) 2006 Microsoft Corporation
3 
4 Module Name:
5 
6     smt_cg_table.h
7 
8 Abstract:
9 
    Congruence table for enodes (one hash table per function symbol).
11 
12 Author:
13 
14     Leonardo de Moura (leonardo) 2008-02-19.
15 
16 Revision History:
17 
18 --*/
19 #pragma once
20 
21 #include "smt/smt_enode.h"
22 #include "util/hashtable.h"
23 #include "util/chashtable.h"
24 
25 namespace smt {
26 
27     typedef std::pair<enode *, bool> enode_bool_pair;
28 
29     // one table per function symbol
30 
31     /**
32        \brief Congruence table.
33     */
34     class cg_table {
35         struct cg_unary_hash {
operatorcg_unary_hash36             unsigned operator()(enode * n) const {
37                 SASSERT(n->get_num_args() == 1);
38                 return n->get_arg(0)->get_root()->hash();
39             }
40         };
41 
42         struct cg_unary_eq {
operatorcg_unary_eq43             bool operator()(enode * n1, enode * n2) const {
44                 SASSERT(n1->get_num_args() == 1);
45                 SASSERT(n2->get_num_args() == 1);
46                 SASSERT(n1->get_decl() == n2->get_decl());
47                 return n1->get_arg(0)->get_root() == n2->get_arg(0)->get_root();
48             }
49         };
50 
51         typedef chashtable<enode *, cg_unary_hash, cg_unary_eq> unary_table;
52 
53         struct cg_binary_hash {
operatorcg_binary_hash54             unsigned operator()(enode * n) const {
55                 SASSERT(n->get_num_args() == 2);
56                 return combine_hash(n->get_arg(0)->get_root()->hash(), n->get_arg(1)->get_root()->hash());
57             }
58         };
59 
60         struct cg_binary_eq {
operatorcg_binary_eq61             bool operator()(enode * n1, enode * n2) const {
62                 SASSERT(n1->get_num_args() == 2);
63                 SASSERT(n2->get_num_args() == 2);
64                 SASSERT(n1->get_decl() == n2->get_decl());
65                 return
66                     n1->get_arg(0)->get_root() == n2->get_arg(0)->get_root() &&
67                     n1->get_arg(1)->get_root() == n2->get_arg(1)->get_root();
68             }
69         };
70 
71         typedef chashtable<enode*, cg_binary_hash, cg_binary_eq> binary_table;
72 
73         struct cg_comm_hash {
operatorcg_comm_hash74             unsigned operator()(enode * n) const {
75                 SASSERT(n->get_num_args() == 2);
76                 unsigned h1 = n->get_arg(0)->get_root()->hash();
77                 unsigned h2 = n->get_arg(1)->get_root()->hash();
78                 if (h1 > h2)
79                     std::swap(h1, h2);
80                 return hash_u((h1 << 16) | (h2 & 0xFFFF));
81             }
82         };
83 
84         struct cg_comm_eq {
85             bool & m_commutativity;
cg_comm_eqcg_comm_eq86             cg_comm_eq(bool & c):m_commutativity(c) {}
operatorcg_comm_eq87             bool operator()(enode * n1, enode * n2) const {
88                 SASSERT(n1->get_num_args() == 2);
89                 SASSERT(n2->get_num_args() == 2);
90                 SASSERT(n1->get_decl() == n2->get_decl());
91                 enode * c1_1 = n1->get_arg(0)->get_root();
92                 enode * c1_2 = n1->get_arg(1)->get_root();
93                 enode * c2_1 = n2->get_arg(0)->get_root();
94                 enode * c2_2 = n2->get_arg(1)->get_root();
95                 if (c1_1 == c2_1 && c1_2 == c2_2) {
96                     return true;
97                 }
98                 if (c1_1 == c2_2 && c1_2 == c2_1) {
99                     m_commutativity = true;
100                     return true;
101                 }
102                 return false;
103             }
104         };
105 
106         typedef chashtable<enode*, cg_comm_hash, cg_comm_eq> comm_table;
107 
108         struct cg_hash {
109             unsigned operator()(enode * n) const;
110         };
111 
112         struct cg_eq {
113             bool operator()(enode * n1, enode * n2) const;
114         };
115 
116         typedef chashtable<enode*, cg_hash, cg_eq> table;
117 
118         ast_manager &                 m_manager;
119         bool                          m_commutativity; //!< true if the last found congruence used commutativity
120         ptr_vector<void>              m_tables;
121         obj_map<func_decl, unsigned>  m_func_decl2id;
122 
123         enum table_kind {
124             UNARY,
125             BINARY,
126             BINARY_COMM,
127             NARY
128         };
129 
130         void * mk_table_for(func_decl * d);
131         unsigned set_func_decl_id(enode * n);
132 
get_table(enode * n)133         void * get_table(enode * n) {
134             unsigned tid = n->get_func_decl_id();
135             if (tid == UINT_MAX)
136                 tid = set_func_decl_id(n);
137             SASSERT(tid < m_tables.size());
138             return m_tables[tid];
139         }
140 
141     public:
142         cg_table(ast_manager & m);
143         ~cg_table();
144 
145         /**
146            \brief Try to insert n into the table. If the table already
147            contains an element n' congruent to n, then do nothing and
148            return n' and a boolean indicating whether n and n' are congruence
149            modulo commutativity, otherwise insert n and return (n,false).
150         */
151         enode_bool_pair insert(enode * n);
152 
153         void erase(enode * n);
154 
contains(enode * n)155         bool contains(enode * n) const {
156             SASSERT(n->get_num_args() > 0);
157             void * t = const_cast<cg_table*>(this)->get_table(n);
158             switch (static_cast<table_kind>(GET_TAG(t))) {
159             case UNARY:
160                 return UNTAG(unary_table*, t)->contains(n);
161             case BINARY:
162                 return UNTAG(binary_table*, t)->contains(n);
163             case BINARY_COMM:
164                 return UNTAG(comm_table*, t)->contains(n);
165             default:
166                 return UNTAG(table*, t)->contains(n);
167             }
168         }
169 
find(enode * n)170         enode * find(enode * n) const {
171             SASSERT(n->get_num_args() > 0);
172             enode * r = nullptr;
173             void * t = const_cast<cg_table*>(this)->get_table(n);
174             switch (static_cast<table_kind>(GET_TAG(t))) {
175             case UNARY:
176                 return UNTAG(unary_table*, t)->find(n, r) ? r : nullptr;
177             case BINARY:
178                 return UNTAG(binary_table*, t)->find(n, r) ? r : nullptr;
179             case BINARY_COMM:
180                 return UNTAG(comm_table*, t)->find(n, r) ? r : nullptr;
181             default:
182                 return UNTAG(table*, t)->find(n, r) ? r : nullptr;
183             }
184         }
185 
contains_ptr(enode * n)186         bool contains_ptr(enode * n) const {
187             enode * r;
188             SASSERT(n->get_num_args() > 0);
189             void * t = const_cast<cg_table*>(this)->get_table(n);
190             switch (static_cast<table_kind>(GET_TAG(t))) {
191             case UNARY:
192                 return UNTAG(unary_table*, t)->find(n, r) && n == r;
193             case BINARY:
194                 return UNTAG(binary_table*, t)->find(n, r) && n == r;
195             case BINARY_COMM:
196                 return UNTAG(comm_table*, t)->find(n, r) && n == r;
197             default:
198                 return UNTAG(table*, t)->find(n, r) && n == r;
199             }
200         }
201 
202         void reset();
203 
204         void display(std::ostream & out) const;
205 
206         void display_binary(std::ostream& out, void* t) const;
207 
208         void display_binary_comm(std::ostream& out, void* t) const;
209 
210         void display_unary(std::ostream& out, void* t) const;
211 
212         void display_nary(std::ostream& out, void* t) const;
213 
214         void display_compact(std::ostream & out) const;
215 
216         bool check_invariant() const;
217     };
218 
219 };
220 
221 
222