1 /*++ 2 Copyright (c) 2006 Microsoft Corporation 3 4 Module Name: 5 6 union_find.h 7 8 Abstract: 9 10 <abstract> 11 12 Author: 13 14 Leonardo de Moura (leonardo) 2008-05-31. 15 16 Revision History: 17 18 --*/ 19 #pragma once 20 21 #include "util/trail.h" 22 #include "util/trace.h" 23 24 class union_find_default_ctx { 25 public: 26 typedef trail_stack<union_find_default_ctx> _trail_stack; union_find_default_ctx()27 union_find_default_ctx() : m_stack(*this) {} 28 unmerge_eh(unsigned,unsigned)29 void unmerge_eh(unsigned, unsigned) {} merge_eh(unsigned,unsigned,unsigned,unsigned)30 void merge_eh(unsigned, unsigned, unsigned, unsigned) {} after_merge_eh(unsigned,unsigned,unsigned,unsigned)31 void after_merge_eh(unsigned, unsigned, unsigned, unsigned) {} 32 get_trail_stack()33 _trail_stack& get_trail_stack() { return m_stack; } 34 35 private: 36 _trail_stack m_stack; 37 }; 38 39 template<typename Ctx = union_find_default_ctx, typename StackCtx = Ctx> 40 class union_find { 41 Ctx & m_ctx; 42 trail_stack<StackCtx> & m_trail_stack; 43 svector<unsigned> m_find; 44 svector<unsigned> m_size; 45 svector<unsigned> m_next; 46 47 class mk_var_trail; 48 friend class mk_var_trail; 49 50 class mk_var_trail : public trail<StackCtx> { 51 union_find & m_owner; 52 public: mk_var_trail(union_find & o)53 mk_var_trail(union_find & o):m_owner(o) {} ~mk_var_trail()54 ~mk_var_trail() override {} undo(StackCtx & ctx)55 void undo(StackCtx& ctx) override { 56 m_owner.m_find.pop_back(); 57 m_owner.m_size.pop_back(); 58 m_owner.m_next.pop_back(); 59 } 60 }; 61 62 mk_var_trail m_mk_var_trail; 63 64 class merge_trail; 65 friend class merge_trail; 66 67 class merge_trail : public trail<StackCtx> { 68 union_find & m_owner; 69 unsigned m_r1; 70 public: merge_trail(union_find & o,unsigned r1)71 merge_trail(union_find & o, unsigned r1):m_owner(o), m_r1(r1) {} ~merge_trail()72 ~merge_trail() override {} undo(StackCtx & ctx)73 void undo(StackCtx& ctx) override { m_owner.unmerge(m_r1); } 74 }; 75 unmerge(unsigned r1)76 void unmerge(unsigned r1) { 77 unsigned r2 = m_find[r1]; 78 TRACE("union_find", tout << "unmerging " << r1 << " " << r2 << "\n";); 79 SASSERT(find(r2) == r2); 80 m_size[r2] -= m_size[r1]; 81 m_find[r1] = r1; 82 std::swap(m_next[r1], m_next[r2]); 83 m_ctx.unmerge_eh(r2, r1); 84 CASSERT("union_find", check_invariant()); 85 } 86 87 public: union_find(Ctx & ctx)88 union_find(Ctx & ctx):m_ctx(ctx), m_trail_stack(ctx.get_trail_stack()), m_mk_var_trail(*this) {} 89 mk_var()90 unsigned mk_var() { 91 unsigned r = m_find.size(); 92 m_find.push_back(r); 93 m_size.push_back(1); 94 m_next.push_back(r); 95 m_trail_stack.push_ptr(&m_mk_var_trail); 96 return r; 97 } 98 get_num_vars()99 unsigned get_num_vars() const { return m_find.size(); } 100 101 find(unsigned v)102 unsigned find(unsigned v) const { 103 while (true) { 104 SASSERT(v < m_find.size()); 105 unsigned new_v = m_find[v]; 106 if (new_v == v) 107 return v; 108 v = new_v; 109 } 110 } 111 next(unsigned v)112 unsigned next(unsigned v) const { return m_next[v]; } 113 size(unsigned v)114 unsigned size(unsigned v) const { return m_size[find(v)]; } 115 is_root(unsigned v)116 bool is_root(unsigned v) const { return m_find[v] == v; } 117 merge(unsigned v1,unsigned v2)118 void merge(unsigned v1, unsigned v2) { 119 unsigned r1 = find(v1); 120 unsigned r2 = find(v2); 121 TRACE("union_find", tout << "merging " << r1 << " " << r2 << "\n";); 122 if (r1 == r2) 123 return; 124 if (m_size[r1] > m_size[r2]) { 125 std::swap(r1, r2); 126 std::swap(v1, v2); 127 } 128 m_ctx.merge_eh(r2, r1, v2, v1); 129 m_find[r1] = r2; 130 m_size[r2] += m_size[r1]; 131 std::swap(m_next[r1], m_next[r2]); 132 m_trail_stack.push(merge_trail(*this, r1)); 133 m_ctx.after_merge_eh(r2, r1, v2, v1); 134 CASSERT("union_find", check_invariant()); 135 } 136 137 // dissolve equivalence class of v 138 // this method cannot be used with backtracking. dissolve(unsigned v)139 void dissolve(unsigned v) { 140 unsigned w; 141 do { 142 w = next(v); 143 m_size[v] = 1; 144 m_find[v] = v; 145 m_next[v] = v; 146 } 147 while (w != v); 148 } 149 display(std::ostream & out)150 void display(std::ostream & out) const { 151 unsigned num = get_num_vars(); 152 for (unsigned v = 0; v < num; v++) { 153 out << "v" << v << " --> v" << m_find[v] << " (" << size(v) << ")\n"; 154 } 155 } 156 157 #ifdef Z3DEBUG check_invariant()158 bool check_invariant() const { 159 unsigned num = get_num_vars(); 160 for (unsigned v = 0; v < num; v++) { 161 if (is_root(v)) { 162 unsigned curr = v; 163 unsigned sz = 0; 164 do { 165 SASSERT(find(curr) == v); 166 sz++; 167 curr = next(curr); 168 } 169 while (curr != v); 170 SASSERT(m_size[v] == sz); 171 } 172 } 173 return true; 174 } 175 #endif 176 }; 177 178 179 class basic_union_find { 180 unsigned_vector m_find; 181 unsigned_vector m_size; 182 unsigned_vector m_next; 183 ensure_size(unsigned v)184 void ensure_size(unsigned v) { 185 while (v >= get_num_vars()) { 186 mk_var(); 187 } 188 } 189 public: mk_var()190 unsigned mk_var() { 191 unsigned r = m_find.size(); 192 m_find.push_back(r); 193 m_size.push_back(1); 194 m_next.push_back(r); 195 return r; 196 } get_num_vars()197 unsigned get_num_vars() const { return m_find.size(); } 198 find(unsigned v)199 unsigned find(unsigned v) const { 200 if (v >= get_num_vars()) { 201 return v; 202 } 203 while (true) { 204 unsigned new_v = m_find[v]; 205 if (new_v == v) 206 return v; 207 v = new_v; 208 } 209 } 210 next(unsigned v)211 unsigned next(unsigned v) const { 212 if (v >= get_num_vars()) { 213 return v; 214 } 215 return m_next[v]; 216 } 217 is_root(unsigned v)218 bool is_root(unsigned v) const { 219 return v >= get_num_vars() || m_find[v] == v; 220 } 221 merge(unsigned v1,unsigned v2)222 void merge(unsigned v1, unsigned v2) { 223 unsigned r1 = find(v1); 224 unsigned r2 = find(v2); 225 if (r1 == r2) 226 return; 227 ensure_size(v1); 228 ensure_size(v2); 229 if (m_size[r1] > m_size[r2]) 230 std::swap(r1, r2); 231 m_find[r1] = r2; 232 m_size[r2] += m_size[r1]; 233 std::swap(m_next[r1], m_next[r2]); 234 } 235 reset()236 void reset() { 237 m_find.reset(); 238 m_next.reset(); 239 m_size.reset(); 240 } 241 }; 242 243 244 245