1 /*++
2 Copyright (c) 2006 Microsoft Corporation
3 
4 Module Name:
5 
6     union_find.h
7 
8 Abstract:
9 
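    Union-find (disjoint-set) data structures over unsigned variable ids:
    union_find, which records its updates on a trail stack so they can be
    undone, and basic_union_find, a plain variant without a trail.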
11 
12 Author:
13 
14     Leonardo de Moura (leonardo) 2008-05-31.
15 
16 Revision History:
17 
18 --*/
19 #pragma once
20 
21 #include "util/trail.h"
22 #include "util/trace.h"
23 
24 class union_find_default_ctx {
25 public:
26     typedef trail_stack<union_find_default_ctx> _trail_stack;
union_find_default_ctx()27     union_find_default_ctx() : m_stack(*this) {}
28 
unmerge_eh(unsigned,unsigned)29     void unmerge_eh(unsigned, unsigned) {}
merge_eh(unsigned,unsigned,unsigned,unsigned)30     void merge_eh(unsigned, unsigned, unsigned, unsigned) {}
after_merge_eh(unsigned,unsigned,unsigned,unsigned)31     void after_merge_eh(unsigned, unsigned, unsigned, unsigned) {}
32 
get_trail_stack()33     _trail_stack& get_trail_stack() { return m_stack; }
34 
35 private:
36     _trail_stack m_stack;
37 };
38 
39 template<typename Ctx = union_find_default_ctx, typename StackCtx = Ctx>
40 class union_find {
41     Ctx &                         m_ctx;
42     trail_stack<StackCtx> &       m_trail_stack;
43     svector<unsigned>             m_find;
44     svector<unsigned>             m_size;
45     svector<unsigned>             m_next;
46 
47     class mk_var_trail;
48     friend class mk_var_trail;
49 
50     class mk_var_trail : public trail<StackCtx> {
51         union_find & m_owner;
52     public:
mk_var_trail(union_find & o)53         mk_var_trail(union_find & o):m_owner(o) {}
~mk_var_trail()54         ~mk_var_trail() override {}
undo(StackCtx & ctx)55         void undo(StackCtx& ctx) override {
56             m_owner.m_find.pop_back();
57             m_owner.m_size.pop_back();
58             m_owner.m_next.pop_back();
59         }
60     };
61 
62     mk_var_trail                  m_mk_var_trail;
63 
64     class merge_trail;
65     friend class merge_trail;
66 
67     class merge_trail : public trail<StackCtx> {
68         union_find & m_owner;
69         unsigned     m_r1;
70     public:
merge_trail(union_find & o,unsigned r1)71         merge_trail(union_find & o, unsigned r1):m_owner(o), m_r1(r1) {}
~merge_trail()72         ~merge_trail() override {}
undo(StackCtx & ctx)73         void undo(StackCtx& ctx) override { m_owner.unmerge(m_r1); }
74     };
75 
unmerge(unsigned r1)76     void unmerge(unsigned r1) {
77         unsigned r2 = m_find[r1];
78         TRACE("union_find", tout << "unmerging " << r1 << " " << r2 << "\n";);
79         SASSERT(find(r2) == r2);
80         m_size[r2] -= m_size[r1];
81         m_find[r1]  = r1;
82         std::swap(m_next[r1], m_next[r2]);
83         m_ctx.unmerge_eh(r2, r1);
84         CASSERT("union_find", check_invariant());
85     }
86 
87 public:
union_find(Ctx & ctx)88     union_find(Ctx & ctx):m_ctx(ctx), m_trail_stack(ctx.get_trail_stack()), m_mk_var_trail(*this) {}
89 
mk_var()90     unsigned mk_var() {
91         unsigned r = m_find.size();
92         m_find.push_back(r);
93         m_size.push_back(1);
94         m_next.push_back(r);
95         m_trail_stack.push_ptr(&m_mk_var_trail);
96         return r;
97     }
98 
get_num_vars()99     unsigned get_num_vars() const { return m_find.size(); }
100 
101 
find(unsigned v)102     unsigned find(unsigned v) const {
103         while (true) {
104             SASSERT(v < m_find.size());
105             unsigned new_v = m_find[v];
106             if (new_v == v)
107                 return v;
108             v = new_v;
109         }
110     }
111 
next(unsigned v)112     unsigned next(unsigned v) const { return m_next[v]; }
113 
size(unsigned v)114     unsigned size(unsigned v) const { return m_size[find(v)]; }
115 
is_root(unsigned v)116     bool is_root(unsigned v) const { return m_find[v] == v; }
117 
merge(unsigned v1,unsigned v2)118     void merge(unsigned v1, unsigned v2) {
119         unsigned r1 = find(v1);
120         unsigned r2 = find(v2);
121         TRACE("union_find", tout << "merging " << r1 << " " << r2 << "\n";);
122         if (r1 == r2)
123             return;
124         if (m_size[r1] > m_size[r2]) {
125             std::swap(r1, r2);
126             std::swap(v1, v2);
127         }
128         m_ctx.merge_eh(r2, r1, v2, v1);
129         m_find[r1] = r2;
130         m_size[r2] += m_size[r1];
131         std::swap(m_next[r1], m_next[r2]);
132         m_trail_stack.push(merge_trail(*this, r1));
133         m_ctx.after_merge_eh(r2, r1, v2, v1);
134         CASSERT("union_find", check_invariant());
135     }
136 
137     // dissolve equivalence class of v
138     // this method cannot be used with backtracking.
dissolve(unsigned v)139     void dissolve(unsigned v) {
140         unsigned w;
141         do {
142             w = next(v);
143             m_size[v] = 1;
144             m_find[v] = v;
145             m_next[v] = v;
146         }
147         while (w != v);
148     }
149 
display(std::ostream & out)150     void display(std::ostream & out) const {
151         unsigned num = get_num_vars();
152         for (unsigned v = 0; v < num; v++) {
153             out << "v" << v << " --> v" << m_find[v] << " (" << size(v) << ")\n";
154         }
155     }
156 
157 #ifdef Z3DEBUG
check_invariant()158     bool check_invariant() const {
159         unsigned num = get_num_vars();
160         for (unsigned v = 0; v < num; v++) {
161             if (is_root(v)) {
162                 unsigned curr = v;
163                 unsigned sz   = 0;
164                 do {
165                     SASSERT(find(curr) == v);
166                     sz++;
167                     curr = next(curr);
168                 }
169                 while (curr != v);
170                 SASSERT(m_size[v] == sz);
171             }
172         }
173         return true;
174     }
175 #endif
176 };
177 
178 
179 class basic_union_find {
180     unsigned_vector   m_find;
181     unsigned_vector   m_size;
182     unsigned_vector   m_next;
183 
ensure_size(unsigned v)184     void ensure_size(unsigned v) {
185         while (v >= get_num_vars()) {
186             mk_var();
187         }
188     }
189  public:
mk_var()190     unsigned mk_var() {
191         unsigned r = m_find.size();
192         m_find.push_back(r);
193         m_size.push_back(1);
194         m_next.push_back(r);
195         return r;
196     }
get_num_vars()197     unsigned get_num_vars() const { return m_find.size(); }
198 
find(unsigned v)199     unsigned find(unsigned v) const {
200         if (v >= get_num_vars()) {
201             return v;
202         }
203         while (true) {
204             unsigned new_v = m_find[v];
205             if (new_v == v)
206                 return v;
207             v = new_v;
208         }
209     }
210 
next(unsigned v)211     unsigned next(unsigned v) const {
212         if (v >= get_num_vars()) {
213             return v;
214         }
215         return m_next[v];
216     }
217 
is_root(unsigned v)218     bool is_root(unsigned v) const {
219         return v >= get_num_vars() || m_find[v] == v;
220     }
221 
merge(unsigned v1,unsigned v2)222     void merge(unsigned v1, unsigned v2) {
223         unsigned r1 = find(v1);
224         unsigned r2 = find(v2);
225         if (r1 == r2)
226             return;
227         ensure_size(v1);
228         ensure_size(v2);
229         if (m_size[r1] > m_size[r2])
230             std::swap(r1, r2);
231         m_find[r1] = r2;
232         m_size[r2] += m_size[r1];
233         std::swap(m_next[r1], m_next[r2]);
234     }
235 
reset()236     void reset() {
237         m_find.reset();
238         m_next.reset();
239         m_size.reset();
240     }
241 };
242 
243 
244 
245