1 /*++ 2 Copyright (c) 2017 Microsoft Corporation 3 4 Module Name: 5 6 <name> 7 8 Abstract: 9 10 <abstract> 11 12 Author: 13 Nikolaj Bjorner (nbjorner) 14 Lev Nachmanson (levnach) 15 16 Revision History: 17 18 19 --*/ 20 21 #pragma once 22 #include "util/union_find.h" 23 #include "math/lp/nla_defs.h" 24 #include "util/rational.h" 25 #include "math/lp/explanation.h" 26 #include "math/lp/incremental_vector.h" 27 28 namespace nla { 29 30 class eq_justification { 31 lpci m_cs[4]; 32 public: eq_justification(std::initializer_list<lpci> cs)33 eq_justification(std::initializer_list<lpci> cs) { 34 int i = 0; 35 for (lpci c: cs) { 36 m_cs[i++] = c; 37 } 38 for (; i < 4; i++) { 39 m_cs[i] = -1; 40 } 41 } 42 explain(lp::explanation & e)43 void explain(lp::explanation& e) const { 44 for (lpci c : m_cs) 45 if (c + 1 != 0) // c != -1 46 e.push_back(c); 47 } 48 }; 49 50 template <typename T> 51 class var_eqs { 52 struct eq_edge { 53 signed_var m_var; 54 eq_justification m_just; eq_edgeeq_edge55 eq_edge(signed_var v, eq_justification const& j): m_var(v), m_just(j) {} 56 }; 57 58 struct var_frame { 59 signed_var m_var; 60 unsigned m_index; var_framevar_frame61 var_frame(signed_var v, unsigned i): m_var(v), m_index(i) {} 62 }; 63 struct stats { 64 unsigned m_num_explain_calls; 65 unsigned m_num_explains; statsstats66 stats() { memset(this, 0, sizeof(*this)); } 67 }; 68 69 T* m_merge_handler; 70 union_find<var_eqs> m_uf; 71 lp::incremental_vector<std::pair<signed_var, signed_var>> 72 m_trail; 73 vector<svector<eq_edge>> m_eqs; // signed_var.index() -> the edges adjacent to signed_var.index() 74 75 trail_stack m_stack; 76 mutable svector<var_frame> m_todo; 77 mutable bool_vector m_marked; 78 mutable unsigned_vector m_marked_trail; 79 mutable svector<eq_justification> m_justtrail; 80 81 mutable stats m_stats; 82 public: var_eqs()83 var_eqs(): m_merge_handler(nullptr), m_uf(*this), m_stack() {} 84 /** 85 \brief push a scope */ push()86 void push() { 87 m_trail.push_scope(); 88 m_stack.push_scope(); 89 } 90 91 /** 92 \brief pop n scopes 93 */ pop(unsigned n)94 void pop(unsigned n) { 95 unsigned old_sz = m_trail.peek_size(n); 96 for (unsigned i = m_trail.size(); i-- > old_sz; ) { 97 auto const& sv = m_trail[i]; 98 m_eqs[sv.first.index()].pop_back(); 99 m_eqs[sv.second.index()].pop_back(); 100 m_eqs[(~sv.first).index()].pop_back(); 101 m_eqs[(~sv.second).index()].pop_back(); 102 } 103 m_trail.pop_scope(n); 104 m_stack.pop_scope(n); // this cass takes care of unmerging through union_find m_uf 105 } 106 107 /** 108 \brief merge equivalence classes for v1, v2 with justification j 109 */ merge(signed_var v1,signed_var v2,eq_justification const & j)110 void merge(signed_var v1, signed_var v2, eq_justification const& j) { 111 if (v1 == v2) 112 return; 113 if (find(v1).var() == find(v2).var()) 114 return; 115 unsigned max_i = std::max(v1.index(), v2.index()) + 2; 116 m_eqs.reserve(max_i); 117 while (m_uf.get_num_vars() <= max_i) m_uf.mk_var(); 118 TRACE("nla_solver_mons", tout << v1 << " == " << v2 << " " << m_uf.find(v1.index()) << " == " << m_uf.find(v2.index()) << "\n";); 119 m_trail.push_back(std::make_pair(v1, v2)); 120 m_uf.merge(v1.index(), v2.index()); 121 m_uf.merge((~v1).index(), (~v2).index()); 122 m_eqs[v1.index()].push_back(eq_edge(v2, j)); 123 m_eqs[v2.index()].push_back(eq_edge(v1, j)); 124 m_eqs[(~v1).index()].push_back(eq_edge(~v2, j)); 125 m_eqs[(~v2).index()].push_back(eq_edge(~v1, j)); 126 } 127 merge_plus(lpvar v1,lpvar v2,eq_justification const & j)128 void merge_plus(lpvar v1, lpvar v2, eq_justification const& j) { merge(signed_var(v1, false), signed_var(v2, false), j); } merge_minus(lpvar v1,lpvar v2,eq_justification const & j)129 void merge_minus(lpvar v1, lpvar v2, eq_justification const& j) { merge(signed_var(v1, false), signed_var(v2, true), j); } 130 131 /** 132 \brief find equivalence class representative for v 133 */ find(signed_var v)134 signed_var find(signed_var v) const { 135 if (v.index() >= m_uf.get_num_vars()) { 136 return v; 137 } 138 unsigned idx = m_uf.find(v.index()); 139 return signed_var(idx); 140 } 141 find(lpvar j)142 inline signed_var find(lpvar j) const { 143 return find(signed_var(j, false)); 144 } 145 is_root(lpvar j)146 inline bool is_root(lpvar j) const { 147 signed_var sv = find(signed_var(j, false)); 148 return sv.var() == j; 149 } is_root(svector<lpvar> v)150 inline bool is_root(svector<lpvar> v) const { 151 for (lpvar j : v) 152 if (!is_root(j)) 153 return false; 154 return true; 155 } 156 vars_are_equiv(lpvar j,lpvar k)157 bool vars_are_equiv(lpvar j, lpvar k) const { 158 signed_var sj = find(signed_var(j, false)); 159 signed_var sk = find(signed_var(k, false)); 160 return sj.var() == sk.var(); 161 } 162 /** 163 \brief Returns eq_justifications for 164 \pre find(v1) == find(v2) 165 */ explain_dfs(signed_var v1,signed_var v2,lp::explanation & e)166 void explain_dfs(signed_var v1, signed_var v2, lp::explanation& e) const { 167 SASSERT(find(v1) == find(v2)); 168 if (v1 == v2) { 169 return; 170 } 171 m_todo.push_back(var_frame(v1, 0)); 172 m_justtrail.reset(); 173 m_marked.reserve(m_eqs.size(), false); 174 SASSERT(m_marked_trail.empty()); 175 m_marked[v1.index()] = true; 176 m_marked_trail.push_back(v1.index()); 177 while (true) { 178 SASSERT(!m_todo.empty()); 179 var_frame& f = m_todo.back(); 180 signed_var v = f.m_var; 181 if (v == v2) { 182 break; 183 } 184 auto const& next = m_eqs[v.index()]; 185 bool seen_all = true; 186 unsigned sz = next.size(); 187 for (unsigned i = f.m_index; seen_all && i < sz; ++i) { 188 eq_edge const& jv = next[i]; 189 signed_var v3 = jv.m_var; 190 if (!m_marked[v3.index()]) { 191 seen_all = false; 192 f.m_index = i + 1; 193 m_todo.push_back(var_frame(v3, 0)); 194 m_justtrail.push_back(jv.m_just); 195 m_marked_trail.push_back(v3.index()); 196 m_marked[v3.index()] = true; 197 } 198 } 199 if (seen_all) { 200 m_todo.pop_back(); 201 m_justtrail.pop_back(); 202 } 203 } 204 205 for (eq_justification const& j : m_justtrail) { 206 j.explain(e); 207 } 208 m_stats.m_num_explains += m_justtrail.size(); 209 m_stats.m_num_explain_calls++; 210 m_todo.reset(); 211 m_justtrail.reset(); 212 for (unsigned idx : m_marked_trail) { 213 m_marked[idx] = false; 214 } 215 m_marked_trail.reset(); 216 217 // IF_VERBOSE(2, verbose_stream() << (double)m_stats.m_num_explains / m_stats.m_num_explain_calls << "\n"); 218 } 219 explain_bfs(signed_var v1,signed_var v2,lp::explanation & e)220 void explain_bfs(signed_var v1, signed_var v2, lp::explanation& e) const { 221 SASSERT(find(v1) == find(v2)); 222 if (v1 == v2) { 223 return; 224 } 225 m_todo.push_back(var_frame(v1, 0)); 226 m_justtrail.push_back(eq_justification({})); 227 m_marked.reserve(m_eqs.size(), false); 228 SASSERT(m_marked_trail.empty()); 229 m_marked[v1.index()] = true; 230 m_marked_trail.push_back(v1.index()); 231 unsigned head = 0; 232 for (; ; ++head) { 233 var_frame& f = m_todo[head]; 234 signed_var v = f.m_var; 235 if (v == v2) { 236 break; 237 } 238 auto const& next = m_eqs[v.index()]; 239 unsigned sz = next.size(); 240 for (unsigned i = sz; i-- > 0; ) { 241 eq_edge const& jv = next[i]; 242 signed_var v3 = jv.m_var; 243 if (!m_marked[v3.index()]) { 244 m_todo.push_back(var_frame(v3, head)); 245 m_justtrail.push_back(jv.m_just); 246 m_marked_trail.push_back(v3.index()); 247 m_marked[v3.index()] = true; 248 } 249 } 250 } 251 252 while (head != 0) { 253 m_justtrail[head].explain(e); 254 head = m_todo[head].m_index; 255 ++m_stats.m_num_explains; 256 } 257 ++m_stats.m_num_explain_calls; 258 259 m_todo.reset(); 260 m_justtrail.reset(); 261 for (unsigned idx : m_marked_trail) { 262 m_marked[idx] = false; 263 } 264 m_marked_trail.reset(); 265 266 // IF_VERBOSE(2, verbose_stream() << (double)m_stats.m_num_explains / m_stats.m_num_explain_calls << "\n"); 267 } 268 269 explain(signed_var v1,signed_var v2,lp::explanation & e)270 inline void explain(signed_var v1, signed_var v2, lp::explanation& e) const { 271 explain_bfs(v1, v2, e); 272 } explain(lpvar v1,lpvar v2,lp::explanation & e)273 inline void explain(lpvar v1, lpvar v2, lp::explanation & e) const { 274 return explain(signed_var(v1, false), signed_var(v2, false), e); 275 } 276 explain(lpvar j,lp::explanation & e)277 inline void explain(lpvar j, lp::explanation& e) const { 278 signed_var s(j, false); 279 return explain(find(s), s, e); 280 } 281 282 // iterates over the class of lpvar(m_idx) 283 class iterator { 284 var_eqs& m_ve; // context. 285 unsigned m_idx; // index into a signed variable, same as union-find index 286 bool m_touched; // toggle between initial and final state 287 public: iterator(var_eqs & ve,unsigned idx,bool t)288 iterator(var_eqs& ve, unsigned idx, bool t) : m_ve(ve), m_idx(idx), m_touched(t) {} 289 signed_var operator*() const { 290 return signed_var(m_idx); 291 } 292 iterator& operator++() { m_idx = m_ve.m_uf.next(m_idx); m_touched = true; return *this; } 293 bool operator==(iterator const& other) const { return m_idx == other.m_idx && m_touched == other.m_touched; } 294 bool operator!=(iterator const& other) const { return m_idx != other.m_idx || m_touched != other.m_touched; } 295 }; 296 297 class eq_class { 298 var_eqs& m_ve; 299 signed_var m_v; 300 public: eq_class(var_eqs & ve,signed_var v)301 eq_class(var_eqs& ve, signed_var v) : m_ve(ve), m_v(v) {} begin()302 iterator begin() { return iterator(m_ve, m_v.index(), false); } end()303 iterator end() { return iterator(m_ve, m_v.index(), true); } 304 }; 305 equiv_class(signed_var v)306 eq_class equiv_class(signed_var v) { return eq_class(*this, v); } 307 equiv_class(lpvar v)308 eq_class equiv_class(lpvar v) { return equiv_class(signed_var(v, false)); } 309 310 display(std::ostream & out)311 std::ostream& display(std::ostream& out) const { 312 m_uf.display(out); 313 unsigned idx = 0; 314 for (auto const& edges : m_eqs) { 315 if (!edges.empty()) { 316 auto v = signed_var(idx); 317 out << v << " root: " << find(v) << " : "; 318 for (auto const& jv : edges) { 319 out << jv.m_var << " "; 320 } 321 out << "\n"; 322 } 323 ++idx; 324 } 325 return out; 326 } 327 328 // union find event handlers set_merge_handler(T * mh)329 void set_merge_handler(T* mh) { m_merge_handler = mh; } 330 // this method is required by union_find get_trail_stack()331 trail_stack & get_trail_stack() { return m_stack; } 332 unmerge_eh(unsigned i,unsigned j)333 void unmerge_eh(unsigned i, unsigned j) { 334 if (m_merge_handler) { 335 m_merge_handler->unmerge_eh(signed_var(i), signed_var(j)); 336 } 337 } merge_eh(unsigned r2,unsigned r1,unsigned v2,unsigned v1)338 void merge_eh(unsigned r2, unsigned r1, unsigned v2, unsigned v1) { 339 if (m_merge_handler) { 340 m_merge_handler->merge_eh(signed_var(r2), signed_var(r1), 341 signed_var(v2), signed_var(v1)); 342 } 343 } 344 after_merge_eh(unsigned r2,unsigned r1,unsigned v2,unsigned v1)345 void after_merge_eh(unsigned r2, unsigned r1, unsigned v2, unsigned v1) { 346 if (m_merge_handler) { 347 m_merge_handler->after_merge_eh(signed_var(r2), signed_var(r1), 348 signed_var(v2), signed_var(v1)); 349 } 350 } 351 }; // end of var_eqs 352 } 353