1 /*++ 2 Copyright (c) 2006 Microsoft Corporation 3 4 Module Name: 5 6 smt_cg_table.h 7 8 Abstract: 9 10 <abstract> 11 12 Author: 13 14 Leonardo de Moura (leonardo) 2008-02-19. 15 16 Revision History: 17 18 --*/ 19 #pragma once 20 21 #include "smt/smt_enode.h" 22 #include "util/hashtable.h" 23 #include "util/chashtable.h" 24 25 namespace smt { 26 27 typedef std::pair<enode *, bool> enode_bool_pair; 28 29 // one table per function symbol 30 31 /** 32 \brief Congruence table. 33 */ 34 class cg_table { 35 struct cg_unary_hash { operatorcg_unary_hash36 unsigned operator()(enode * n) const { 37 SASSERT(n->get_num_args() == 1); 38 return n->get_arg(0)->get_root()->hash(); 39 } 40 }; 41 42 struct cg_unary_eq { operatorcg_unary_eq43 bool operator()(enode * n1, enode * n2) const { 44 SASSERT(n1->get_num_args() == 1); 45 SASSERT(n2->get_num_args() == 1); 46 SASSERT(n1->get_decl() == n2->get_decl()); 47 return n1->get_arg(0)->get_root() == n2->get_arg(0)->get_root(); 48 } 49 }; 50 51 typedef chashtable<enode *, cg_unary_hash, cg_unary_eq> unary_table; 52 53 struct cg_binary_hash { operatorcg_binary_hash54 unsigned operator()(enode * n) const { 55 SASSERT(n->get_num_args() == 2); 56 return combine_hash(n->get_arg(0)->get_root()->hash(), n->get_arg(1)->get_root()->hash()); 57 } 58 }; 59 60 struct cg_binary_eq { operatorcg_binary_eq61 bool operator()(enode * n1, enode * n2) const { 62 SASSERT(n1->get_num_args() == 2); 63 SASSERT(n2->get_num_args() == 2); 64 SASSERT(n1->get_decl() == n2->get_decl()); 65 return 66 n1->get_arg(0)->get_root() == n2->get_arg(0)->get_root() && 67 n1->get_arg(1)->get_root() == n2->get_arg(1)->get_root(); 68 } 69 }; 70 71 typedef chashtable<enode*, cg_binary_hash, cg_binary_eq> binary_table; 72 73 struct cg_comm_hash { operatorcg_comm_hash74 unsigned operator()(enode * n) const { 75 SASSERT(n->get_num_args() == 2); 76 unsigned h1 = n->get_arg(0)->get_root()->hash(); 77 unsigned h2 = n->get_arg(1)->get_root()->hash(); 78 if (h1 > h2) 79 std::swap(h1, h2); 80 return hash_u((h1 << 16) | (h2 & 0xFFFF)); 81 } 82 }; 83 84 struct cg_comm_eq { 85 bool & m_commutativity; cg_comm_eqcg_comm_eq86 cg_comm_eq(bool & c):m_commutativity(c) {} operatorcg_comm_eq87 bool operator()(enode * n1, enode * n2) const { 88 SASSERT(n1->get_num_args() == 2); 89 SASSERT(n2->get_num_args() == 2); 90 SASSERT(n1->get_decl() == n2->get_decl()); 91 enode * c1_1 = n1->get_arg(0)->get_root(); 92 enode * c1_2 = n1->get_arg(1)->get_root(); 93 enode * c2_1 = n2->get_arg(0)->get_root(); 94 enode * c2_2 = n2->get_arg(1)->get_root(); 95 if (c1_1 == c2_1 && c1_2 == c2_2) { 96 return true; 97 } 98 if (c1_1 == c2_2 && c1_2 == c2_1) { 99 m_commutativity = true; 100 return true; 101 } 102 return false; 103 } 104 }; 105 106 typedef chashtable<enode*, cg_comm_hash, cg_comm_eq> comm_table; 107 108 struct cg_hash { 109 unsigned operator()(enode * n) const; 110 }; 111 112 struct cg_eq { 113 bool operator()(enode * n1, enode * n2) const; 114 }; 115 116 typedef chashtable<enode*, cg_hash, cg_eq> table; 117 118 ast_manager & m_manager; 119 bool m_commutativity; //!< true if the last found congruence used commutativity 120 ptr_vector<void> m_tables; 121 obj_map<func_decl, unsigned> m_func_decl2id; 122 123 enum table_kind { 124 UNARY, 125 BINARY, 126 BINARY_COMM, 127 NARY 128 }; 129 130 void * mk_table_for(func_decl * d); 131 unsigned set_func_decl_id(enode * n); 132 get_table(enode * n)133 void * get_table(enode * n) { 134 unsigned tid = n->get_func_decl_id(); 135 if (tid == UINT_MAX) 136 tid = set_func_decl_id(n); 137 SASSERT(tid < m_tables.size()); 138 return m_tables[tid]; 139 } 140 141 public: 142 cg_table(ast_manager & m); 143 ~cg_table(); 144 145 /** 146 \brief Try to insert n into the table. If the table already 147 contains an element n' congruent to n, then do nothing and 148 return n' and a boolean indicating whether n and n' are congruence 149 modulo commutativity, otherwise insert n and return (n,false). 150 */ 151 enode_bool_pair insert(enode * n); 152 153 void erase(enode * n); 154 contains(enode * n)155 bool contains(enode * n) const { 156 SASSERT(n->get_num_args() > 0); 157 void * t = const_cast<cg_table*>(this)->get_table(n); 158 switch (static_cast<table_kind>(GET_TAG(t))) { 159 case UNARY: 160 return UNTAG(unary_table*, t)->contains(n); 161 case BINARY: 162 return UNTAG(binary_table*, t)->contains(n); 163 case BINARY_COMM: 164 return UNTAG(comm_table*, t)->contains(n); 165 default: 166 return UNTAG(table*, t)->contains(n); 167 } 168 } 169 find(enode * n)170 enode * find(enode * n) const { 171 SASSERT(n->get_num_args() > 0); 172 enode * r = nullptr; 173 void * t = const_cast<cg_table*>(this)->get_table(n); 174 switch (static_cast<table_kind>(GET_TAG(t))) { 175 case UNARY: 176 return UNTAG(unary_table*, t)->find(n, r) ? r : nullptr; 177 case BINARY: 178 return UNTAG(binary_table*, t)->find(n, r) ? r : nullptr; 179 case BINARY_COMM: 180 return UNTAG(comm_table*, t)->find(n, r) ? r : nullptr; 181 default: 182 return UNTAG(table*, t)->find(n, r) ? r : nullptr; 183 } 184 } 185 contains_ptr(enode * n)186 bool contains_ptr(enode * n) const { 187 enode * r; 188 SASSERT(n->get_num_args() > 0); 189 void * t = const_cast<cg_table*>(this)->get_table(n); 190 switch (static_cast<table_kind>(GET_TAG(t))) { 191 case UNARY: 192 return UNTAG(unary_table*, t)->find(n, r) && n == r; 193 case BINARY: 194 return UNTAG(binary_table*, t)->find(n, r) && n == r; 195 case BINARY_COMM: 196 return UNTAG(comm_table*, t)->find(n, r) && n == r; 197 default: 198 return UNTAG(table*, t)->find(n, r) && n == r; 199 } 200 } 201 202 void reset(); 203 204 void display(std::ostream & out) const; 205 206 void display_binary(std::ostream& out, void* t) const; 207 208 void display_binary_comm(std::ostream& out, void* t) const; 209 210 void display_unary(std::ostream& out, void* t) const; 211 212 void display_nary(std::ostream& out, void* t) const; 213 214 void display_compact(std::ostream & out) const; 215 216 bool check_invariant() const; 217 }; 218 219 }; 220 221 222