1 /*++
2   Copyright (c) 2017 Microsoft Corporation
3 
4   Module Name:
5 
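  var_eqs.h

  Abstract:

  Equivalence classes of signed lp variables, maintained with a
  backtrackable union-find. Each merged equality keeps its justification
  so that the equality of two variables in the same class can be explained
  by a chain of justified equalities.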
11 
12   Author:
13   Nikolaj Bjorner (nbjorner)
14   Lev Nachmanson (levnach)
15 
16   Revision History:
17 
18 
19   --*/
20 
21 #pragma once
22 #include "util/union_find.h"
23 #include "math/lp/nla_defs.h"
24 #include "util/rational.h"
25 #include "math/lp/explanation.h"
26 #include "math/lp/incremental_vector.h"
27 
28 namespace nla {
29 
30 class eq_justification {
31     lpci m_cs[4];
32 public:
eq_justification(std::initializer_list<lpci> cs)33     eq_justification(std::initializer_list<lpci> cs) {
34         int i = 0;
35         for (lpci c: cs) {
36             m_cs[i++] = c;
37         }
38         for (; i < 4; i++) {
39             m_cs[i] = -1;
40         }
41     }
42 
explain(lp::explanation & e)43     void explain(lp::explanation& e) const {
44         for (lpci c : m_cs)
45             if (c + 1 != 0) // c != -1
46                 e.push_back(c);
47     }
48 };
49 
50 template <typename T>
51 class var_eqs {
52     struct eq_edge {
53         signed_var       m_var;
54         eq_justification m_just;
eq_edgeeq_edge55         eq_edge(signed_var v, eq_justification const& j): m_var(v), m_just(j) {}
56     };
57 
58     struct var_frame {
59         signed_var m_var;
60         unsigned   m_index;
var_framevar_frame61         var_frame(signed_var v, unsigned i): m_var(v), m_index(i) {}
62     };
63     struct stats {
64         unsigned m_num_explain_calls;
65         unsigned m_num_explains;
statsstats66         stats() { memset(this, 0, sizeof(*this)); }
67     };
68 
69     T*                                m_merge_handler;
70     union_find<var_eqs>               m_uf;
71     lp::incremental_vector<std::pair<signed_var, signed_var>>
72 	                                  m_trail;
73     vector<svector<eq_edge>>          m_eqs;    // signed_var.index() -> the edges adjacent to signed_var.index()
74 
75     trail_stack                       m_stack;
76     mutable svector<var_frame>        m_todo;
77     mutable bool_vector             m_marked;
78     mutable unsigned_vector           m_marked_trail;
79     mutable svector<eq_justification> m_justtrail;
80 
81     mutable stats m_stats;
82 public:
var_eqs()83     var_eqs(): m_merge_handler(nullptr), m_uf(*this), m_stack() {}
84     /**
85        \brief push a scope    */
push()86     void push() {
87         m_trail.push_scope();
88         m_stack.push_scope();
89     }
90 
91     /**
92        \brief pop n scopes
93     */
pop(unsigned n)94     void pop(unsigned n)  {
95         unsigned old_sz = m_trail.peek_size(n);
96         for (unsigned i = m_trail.size(); i-- > old_sz; ) {
97             auto const& sv = m_trail[i];
98             m_eqs[sv.first.index()].pop_back();
99             m_eqs[sv.second.index()].pop_back();
100             m_eqs[(~sv.first).index()].pop_back();
101             m_eqs[(~sv.second).index()].pop_back();
102         }
103         m_trail.pop_scope(n);
104         m_stack.pop_scope(n); // this cass takes care of unmerging through union_find m_uf
105     }
106 
107     /**
108        \brief merge equivalence classes for v1, v2 with justification j
109     */
merge(signed_var v1,signed_var v2,eq_justification const & j)110     void merge(signed_var v1, signed_var v2, eq_justification const& j)  {
111         if (v1 == v2)
112             return;
113         if (find(v1).var() == find(v2).var())
114             return;
115         unsigned max_i = std::max(v1.index(), v2.index()) + 2;
116         m_eqs.reserve(max_i);
117         while (m_uf.get_num_vars() <= max_i) m_uf.mk_var();
118         TRACE("nla_solver_mons", tout << v1 << " == " << v2 << " " << m_uf.find(v1.index()) << " == " << m_uf.find(v2.index()) << "\n";);
119         m_trail.push_back(std::make_pair(v1, v2));
120         m_uf.merge(v1.index(), v2.index());
121         m_uf.merge((~v1).index(), (~v2).index());
122         m_eqs[v1.index()].push_back(eq_edge(v2, j));
123         m_eqs[v2.index()].push_back(eq_edge(v1, j));
124         m_eqs[(~v1).index()].push_back(eq_edge(~v2, j));
125         m_eqs[(~v2).index()].push_back(eq_edge(~v1, j));
126     }
127 
merge_plus(lpvar v1,lpvar v2,eq_justification const & j)128     void merge_plus(lpvar v1, lpvar v2, eq_justification const& j)  { merge(signed_var(v1, false), signed_var(v2, false), j); }
merge_minus(lpvar v1,lpvar v2,eq_justification const & j)129     void merge_minus(lpvar v1, lpvar v2, eq_justification const& j) { merge(signed_var(v1, false), signed_var(v2, true),  j); }
130 
131     /**
132        \brief find equivalence class representative for v
133     */
find(signed_var v)134     signed_var find(signed_var v) const {
135         if (v.index() >= m_uf.get_num_vars()) {
136             return v;
137         }
138         unsigned idx = m_uf.find(v.index());
139         return signed_var(idx);
140     }
141 
find(lpvar j)142     inline signed_var find(lpvar j) const {
143         return find(signed_var(j, false));
144     }
145 
is_root(lpvar j)146     inline bool is_root(lpvar j) const {
147         signed_var sv = find(signed_var(j, false));
148         return sv.var() == j;
149     }
is_root(svector<lpvar> v)150     inline bool is_root(svector<lpvar> v) const {
151         for (lpvar j : v)
152             if (!is_root(j))
153                 return false;
154         return true;
155     }
156 
vars_are_equiv(lpvar j,lpvar k)157     bool vars_are_equiv(lpvar j, lpvar k) const {
158         signed_var sj = find(signed_var(j, false));
159         signed_var sk = find(signed_var(k, false));
160         return sj.var() == sk.var();
161     }
162     /**
163        \brief Returns eq_justifications for
164        \pre find(v1) == find(v2)
165     */
explain_dfs(signed_var v1,signed_var v2,lp::explanation & e)166     void explain_dfs(signed_var v1, signed_var v2, lp::explanation& e) const {
167         SASSERT(find(v1) == find(v2));
168         if (v1 == v2) {
169             return;
170         }
171         m_todo.push_back(var_frame(v1, 0));
172         m_justtrail.reset();
173         m_marked.reserve(m_eqs.size(), false);
174         SASSERT(m_marked_trail.empty());
175         m_marked[v1.index()] = true;
176         m_marked_trail.push_back(v1.index());
177         while (true) {
178             SASSERT(!m_todo.empty());
179             var_frame& f = m_todo.back();
180             signed_var v = f.m_var;
181             if (v == v2) {
182                 break;
183             }
184             auto const& next = m_eqs[v.index()];
185             bool seen_all = true;
186             unsigned sz = next.size();
187             for (unsigned i = f.m_index; seen_all && i < sz; ++i) {
188                 eq_edge const& jv = next[i];
189                 signed_var v3 = jv.m_var;
190                 if (!m_marked[v3.index()]) {
191                     seen_all = false;
192                     f.m_index = i + 1;
193                     m_todo.push_back(var_frame(v3, 0));
194                     m_justtrail.push_back(jv.m_just);
195                     m_marked_trail.push_back(v3.index());
196                     m_marked[v3.index()] = true;
197                 }
198             }
199             if (seen_all) {
200                 m_todo.pop_back();
201                 m_justtrail.pop_back();
202             }
203         }
204 
205         for (eq_justification const& j : m_justtrail) {
206             j.explain(e);
207         }
208         m_stats.m_num_explains += m_justtrail.size();
209         m_stats.m_num_explain_calls++;
210         m_todo.reset();
211         m_justtrail.reset();
212         for (unsigned idx : m_marked_trail) {
213             m_marked[idx] = false;
214         }
215         m_marked_trail.reset();
216 
217         // IF_VERBOSE(2, verbose_stream() << (double)m_stats.m_num_explains / m_stats.m_num_explain_calls << "\n");
218     }
219 
explain_bfs(signed_var v1,signed_var v2,lp::explanation & e)220     void explain_bfs(signed_var v1, signed_var v2, lp::explanation& e) const {
221         SASSERT(find(v1) == find(v2));
222         if (v1 == v2) {
223             return;
224         }
225         m_todo.push_back(var_frame(v1, 0));
226         m_justtrail.push_back(eq_justification({}));
227         m_marked.reserve(m_eqs.size(), false);
228         SASSERT(m_marked_trail.empty());
229         m_marked[v1.index()] = true;
230         m_marked_trail.push_back(v1.index());
231         unsigned head = 0;
232         for (; ; ++head) {
233             var_frame& f = m_todo[head];
234             signed_var v = f.m_var;
235             if (v == v2) {
236                 break;
237             }
238             auto const& next = m_eqs[v.index()];
239             unsigned sz = next.size();
240             for (unsigned i = sz; i-- > 0; ) {
241                 eq_edge const& jv = next[i];
242                 signed_var v3 = jv.m_var;
243                 if (!m_marked[v3.index()]) {
244                     m_todo.push_back(var_frame(v3, head));
245                     m_justtrail.push_back(jv.m_just);
246                     m_marked_trail.push_back(v3.index());
247                     m_marked[v3.index()] = true;
248                 }
249             }
250         }
251 
252         while (head != 0) {
253             m_justtrail[head].explain(e);
254             head = m_todo[head].m_index;
255             ++m_stats.m_num_explains;
256         }
257         ++m_stats.m_num_explain_calls;
258 
259         m_todo.reset();
260         m_justtrail.reset();
261         for (unsigned idx : m_marked_trail) {
262             m_marked[idx] = false;
263         }
264         m_marked_trail.reset();
265 
266         // IF_VERBOSE(2, verbose_stream() << (double)m_stats.m_num_explains / m_stats.m_num_explain_calls << "\n");
267     }
268 
269 
explain(signed_var v1,signed_var v2,lp::explanation & e)270     inline void explain(signed_var v1, signed_var v2, lp::explanation& e) const {
271         explain_bfs(v1, v2, e);
272     }
explain(lpvar v1,lpvar v2,lp::explanation & e)273     inline void explain(lpvar v1, lpvar v2, lp::explanation & e) const {
274         return explain(signed_var(v1, false), signed_var(v2, false), e);
275     }
276 
explain(lpvar j,lp::explanation & e)277     inline void explain(lpvar j, lp::explanation& e) const {
278         signed_var s(j, false);
279         return explain(find(s), s, e);
280     }
281 
282     // iterates over the class of lpvar(m_idx)
283     class iterator {
284         var_eqs& m_ve;        // context.
285         unsigned m_idx;       // index into a signed variable, same as union-find index
286         bool     m_touched;   // toggle between initial and final state
287     public:
iterator(var_eqs & ve,unsigned idx,bool t)288         iterator(var_eqs& ve, unsigned idx, bool t) : m_ve(ve), m_idx(idx), m_touched(t) {}
289         signed_var operator*() const {
290             return signed_var(m_idx);
291         }
292         iterator& operator++() { m_idx = m_ve.m_uf.next(m_idx); m_touched = true; return *this; }
293         bool operator==(iterator const& other) const { return m_idx == other.m_idx && m_touched == other.m_touched; }
294         bool operator!=(iterator const& other) const { return m_idx != other.m_idx || m_touched != other.m_touched; }
295     };
296 
297     class eq_class {
298         var_eqs& m_ve;
299         signed_var m_v;
300     public:
eq_class(var_eqs & ve,signed_var v)301         eq_class(var_eqs& ve, signed_var v) : m_ve(ve), m_v(v) {}
begin()302         iterator begin() { return iterator(m_ve, m_v.index(), false); }
end()303         iterator end() { return iterator(m_ve, m_v.index(), true); }
304     };
305 
equiv_class(signed_var v)306     eq_class equiv_class(signed_var v) { return eq_class(*this, v); }
307 
equiv_class(lpvar v)308     eq_class equiv_class(lpvar v) { return equiv_class(signed_var(v, false)); }
309 
310 
display(std::ostream & out)311     std::ostream& display(std::ostream& out) const {
312         m_uf.display(out);
313         unsigned idx = 0;
314         for (auto const& edges : m_eqs) {
315             if (!edges.empty()) {
316                 auto v = signed_var(idx);
317                 out << v << " root: " << find(v) << " : ";
318                 for (auto const& jv : edges) {
319                     out << jv.m_var << " ";
320                 }
321                 out << "\n";
322             }
323             ++idx;
324         }
325         return out;
326     }
327 
328     // union find event handlers
set_merge_handler(T * mh)329     void set_merge_handler(T* mh) { m_merge_handler = mh; }
330     // this method is required by union_find
get_trail_stack()331     trail_stack & get_trail_stack() { return m_stack; }
332 
unmerge_eh(unsigned i,unsigned j)333     void unmerge_eh(unsigned i, unsigned j) {
334         if (m_merge_handler) {
335             m_merge_handler->unmerge_eh(signed_var(i), signed_var(j));
336         }
337     }
merge_eh(unsigned r2,unsigned r1,unsigned v2,unsigned v1)338     void merge_eh(unsigned r2, unsigned r1, unsigned v2, unsigned v1) {
339         if (m_merge_handler) {
340             m_merge_handler->merge_eh(signed_var(r2), signed_var(r1),
341                                       signed_var(v2), signed_var(v1));
342         }
343     }
344 
after_merge_eh(unsigned r2,unsigned r1,unsigned v2,unsigned v1)345     void after_merge_eh(unsigned r2, unsigned r1, unsigned v2, unsigned v1) {
346         if (m_merge_handler) {
347             m_merge_handler->after_merge_eh(signed_var(r2), signed_var(r1),
348                                             signed_var(v2), signed_var(v1));
349         }
350     }
351 };  // end of var_eqs
352 }
353