1 /*++ Copyright (c) 2006 Microsoft Corporation
2 
3 Module Name:
4 
5     euf_etable.h
6 
7 Author:
8 
9     Leonardo de Moura (leonardo) 2008-02-19.
10 
11 Revision History:
12 
13     copied from smt_cg_table
14 
15 --*/
16 
17 #pragma once
18 
19 #include "ast/euf/euf_enode.h"
20 #include "util/hashtable.h"
21 #include "util/chashtable.h"
22 
23 namespace euf {
24 
25     // one table per function symbol
26 
27 
28     /**
29        \brief Congruence table.
30     */
31     class etable {
get_root(enode * n,unsigned idx)32         static enode* get_root(enode* n, unsigned idx) { return n->get_arg(idx)->get_root(); }
33 
34         struct cg_unary_hash {
operatorcg_unary_hash35             unsigned operator()(enode * n) const {
36                 SASSERT(n->num_args() == 1);
37                 return get_root(n, 0)->hash();
38             }
39         };
40 
41         struct cg_unary_eq {
42 
operatorcg_unary_eq43             bool operator()(enode * n1, enode * n2) const {
44                 SASSERT(n1->num_args() == 1);
45                 SASSERT(n2->num_args() == 1);
46                 SASSERT(n1->get_decl() == n2->get_decl());
47                 return get_root(n1, 0) == get_root(n2, 0);
48             }
49         };
50 
51         typedef chashtable<enode *, cg_unary_hash, cg_unary_eq> unary_table;
52 
53         struct cg_binary_hash {
operatorcg_binary_hash54             unsigned operator()(enode * n) const {
55                 SASSERT(n->num_args() == 2);
56                 return combine_hash(get_root(n, 0)->hash(), get_root(n, 1)->hash());
57             }
58         };
59 
60         struct cg_binary_eq {
operatorcg_binary_eq61             bool operator()(enode * n1, enode * n2) const {
62                 SASSERT(n1->num_args() == 2);
63                 SASSERT(n2->num_args() == 2);
64                 SASSERT(n1->get_decl() == n2->get_decl());
65                 return
66                     get_root(n1, 0) == get_root(n2, 0) &&
67                     get_root(n1, 1) == get_root(n2, 1);
68             }
69         };
70 
71         typedef chashtable<enode*, cg_binary_hash, cg_binary_eq> binary_table;
72 
73         struct cg_comm_hash {
operatorcg_comm_hash74             unsigned operator()(enode * n) const {
75                 SASSERT(n->num_args() == 2);
76                 unsigned h1 = get_root(n, 0)->hash();
77                 unsigned h2 = get_root(n, 1)->hash();
78                 if (h1 > h2)
79                     std::swap(h1, h2);
80                 return hash_u((h1 << 16) | (h2 & 0xFFFF));
81             }
82         };
83 
84         struct cg_comm_eq {
85             bool & m_commutativity;
cg_comm_eqcg_comm_eq86             cg_comm_eq(bool & c): m_commutativity(c) {}
operatorcg_comm_eq87             bool operator()(enode * n1, enode * n2) const {
88                 SASSERT(n1->num_args() == 2);
89                 SASSERT(n2->num_args() == 2);
90                 if (n1->get_decl() != n2->get_decl())
91                     return false;
92                 enode* c1_1 = get_root(n1, 0);
93                 enode* c1_2 = get_root(n1, 1);
94                 enode* c2_1 = get_root(n2, 0);
95                 enode* c2_2 = get_root(n2, 1);
96                 if (c1_1 == c2_1 && c1_2 == c2_2) {
97                     return true;
98                 }
99                 if (c1_1 == c2_2 && c1_2 == c2_1) {
100                     m_commutativity = true;
101                     return true;
102                 }
103                 return false;
104             }
105         };
106 
107         typedef chashtable<enode*, cg_comm_hash, cg_comm_eq> comm_table;
108 
109         struct cg_hash {
110             unsigned operator()(enode * n) const;
111         };
112 
113         struct cg_eq {
114             bool operator()(enode * n1, enode * n2) const;
115         };
116 
117         typedef chashtable<enode*, cg_hash, cg_eq> table;
118         typedef std::pair<func_decl*, unsigned> decl_info;
119         struct decl_hash {
operatordecl_hash120             unsigned operator()(decl_info const& d) const { return d.first->hash(); }
121         };
122         struct decl_eq {
operatordecl_eq123             bool operator()(decl_info const& a, decl_info const& b) const { return a == b; }
124         };
125 
126 
127         ast_manager &                 m_manager;
128         bool                          m_commutativity = false; //!< true if the last found congruence used commutativity
129         ptr_vector<void>              m_tables;
130         map<decl_info, unsigned, decl_hash, decl_eq>  m_func_decl2id;
131 
132         enum table_kind {
133             UNARY,
134             BINARY,
135             BINARY_COMM,
136             NARY
137         };
138 
139         void * mk_table_for(unsigned n, func_decl * d);
140         unsigned set_table_id(enode * n);
141 
get_table(enode * n)142         void * get_table(enode * n) {
143             unsigned tid = n->get_table_id();
144             if (tid == UINT_MAX)
145                 tid = set_table_id(n);
146             SASSERT(tid < m_tables.size());
147             return m_tables[tid];
148         }
149 
150         void display_binary(std::ostream& out, void* t) const;
151         void display_binary_comm(std::ostream& out, void* t) const;
152         void display_unary(std::ostream& out, void* t) const;
153         void display_nary(std::ostream& out, void* t) const;
154 
155     public:
156         etable(ast_manager & m);
157 
158         ~etable();
159 
160         /**
161            \brief Try to insert n into the table. If the table already
162            contains an element n' congruent to n, then do nothing and
163            return n' and a boolean indicating whether n and n' are congruence
164            modulo commutativity, otherwise insert n and return (n,false).
165         */
166         enode_bool_pair insert(enode * n);
167 
168         void erase(enode * n);
169 
170         bool contains(enode* n) const;
171 
172         enode* find(enode* n) const;
173 
174         bool contains_ptr(enode* n) const;
175 
176         void reset();
177 
178         void display(std::ostream & out) const;
179 
180     };
181 
182 };
183 
184 
185 
186 
187