1 /*++
2   Copyright (c) 2020 Microsoft Corporation
3 
4   Module Name:
5 
6    sat_xor_finder.cpp
7 
8   Abstract:
9 
10     xor finder
11 
12   Author:
13 
14     Nikolaj Bjorner 2020-01-02
15 
16   Notes:
17 
18 
19   --*/
20 
21 #include "sat/sat_xor_finder.h"
22 #include "sat/sat_solver.h"
23 
24 namespace sat {
25 
26 
operator ()(clause_vector & clauses)27     void xor_finder::operator()(clause_vector& clauses) {
28         m_removed_clauses.reset();
29         unsigned max_size = m_max_xor_size;
30         // we better have enough bits in the combination mask to
31         // handle clauses up to max_size.
32         // max_size = 5 -> 32 bits
33         // max_size = 6 -> 64 bits
34         SASSERT(sizeof(m_combination)*8 <= (1ull << static_cast<uint64_t>(max_size)));
35         init_clause_filter();
36         m_var_position.resize(s.num_vars());
37         for (clause* cp : clauses) {
38             cp->unmark_used();
39         }
40         for (; max_size > 2; --max_size) {
41             for (clause* cp : clauses) {
42                 clause& c = *cp;
43                 if (c.size() == max_size && !c.was_removed() && !c.is_learned() && !c.was_used()) {
44                     extract_xor(c);
45                 }
46             }
47         }
48         m_clause_filters.clear();
49 
50         for (clause* cp : clauses) cp->unmark_used();
51         for (clause* cp : m_removed_clauses) cp->mark_used();
52         std::function<bool(clause*)> not_used = [](clause* cp) { return !cp->was_used(); };
53         clauses.filter_update(not_used);
54     }
55 
extract_xor(clause & c)56     void xor_finder::extract_xor(clause& c) {
57         SASSERT(c.size() > 2);
58         unsigned filter = get_clause_filter(c);
59         s.init_visited();
60         TRACE("sat_xor", tout << c << "\n";);
61         bool parity = false;
62         unsigned mask = 0, i = 0;
63         for (literal l : c) {
64             m_var_position[l.var()] = i;
65             s.mark_visited(l.var());
66             parity ^= !l.sign();
67             mask |= (!l.sign() << (i++));
68         }
69         // parity is number of true literals in clause.
70         m_clauses_to_remove.reset();
71         m_clauses_to_remove.push_back(&c);
72         m_clause.resize(c.size());
73         m_combination = 0;
74         set_combination(mask);
75         c.mark_used();
76         for (literal l : c) {
77             for (auto const& cf : m_clause_filters[l.var()]) {
78                 if ((filter == (filter | cf.m_filter)) &&
79                     !cf.m_clause->was_used() &&
80                     extract_xor(parity, c, *cf.m_clause)) {
81                     add_xor(parity, c);
82                     return;
83                 }
84             }
85             // loop over binary clauses in watch list
86             for (watched const & w : s.get_wlist(l)) {
87                 if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) {
88                     if (extract_xor(parity, c, ~l, w.get_literal())) {
89                         add_xor(parity, c);
90                         return;
91                     }
92                 }
93             }
94             l.neg();
95             for (watched const & w : s.get_wlist(l)) {
96                 if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) {
97                     if (extract_xor(parity, c, ~l, w.get_literal())) {
98                         add_xor(parity, c);
99                         return;
100                     }
101                 }
102             }
103         }
104     }
105 
set_combination(unsigned mask)106     void xor_finder::set_combination(unsigned mask) {
107         m_combination |= (1 << mask);
108     }
109 
110 
add_xor(bool parity,clause & c)111     void xor_finder::add_xor(bool parity, clause& c) {
112         DEBUG_CODE(for (clause* cp : m_clauses_to_remove) VERIFY(cp->was_used()););
113         m_removed_clauses.append(m_clauses_to_remove);
114         literal_vector lits;
115         for (literal l : c) {
116             lits.push_back(literal(l.var(), false));
117             s.set_external(l.var());
118         }
119         if (parity == (lits.size() % 2 == 0)) lits[0].neg();
120         TRACE("sat_xor", tout << parity << ": " << lits << "\n";);
121         m_on_xor(lits);
122     }
123 
extract_xor(bool parity,clause & c,literal l1,literal l2)124     bool xor_finder::extract_xor(bool parity, clause& c, literal l1, literal l2) {
125         SASSERT(s.is_visited(l1.var()));
126         SASSERT(s.is_visited(l2.var()));
127         m_missing.reset();
128         unsigned mask = 0;
129         for (unsigned i = 0; i < c.size(); ++i) {
130             if (c[i].var() == l1.var()) {
131                 mask |= (!l1.sign() << i);
132             }
133             else if (c[i].var() == l2.var()) {
134                 mask |= (!l2.sign() << i);
135             }
136             else {
137                 m_missing.push_back(i);
138             }
139         }
140         TRACE("sat_xor", tout << l1 << " " << l2 << "\n";);
141         return update_combinations(c, parity, mask);
142     }
143 
extract_xor(bool parity,clause & c,clause & c2)144     bool xor_finder::extract_xor(bool parity, clause& c, clause& c2) {
145         bool parity2 = false;
146         for (literal l : c2) {
147             if (!s.is_visited(l.var())) return false;
148             parity2 ^= !l.sign();
149         }
150         if (c2.size() == c.size() && parity2 != parity) {
151             return false;
152         }
153         if (c2.size() == c.size()) {
154             m_clauses_to_remove.push_back(&c2);
155             c2.mark_used();
156         }
157         TRACE("sat_xor", tout << c2 << "\n";);
158         // insert missing
159         unsigned mask = 0;
160         m_missing.reset();
161         SASSERT(c2.size() <= c.size());
162         for (unsigned i = 0; i < c.size(); ++i) {
163             m_clause[i] = null_literal;
164         }
165         for (literal l : c2) {
166             unsigned pos = m_var_position[l.var()];
167             m_clause[pos] = l;
168         }
169         for (unsigned j = 0; j < c.size(); ++j) {
170             literal lit = m_clause[j];
171             if (lit == null_literal) {
172                 m_missing.push_back(j);
173             }
174             else {
175                 mask |= (!m_clause[j].sign() << j);
176             }
177         }
178         return update_combinations(c, parity, mask);
179     }
180 
update_combinations(clause & c,bool parity,unsigned mask)181     bool xor_finder::update_combinations(clause& c, bool parity, unsigned mask) {
182         unsigned num_missing = m_missing.size();
183         for (unsigned k = 0; k < (1ul << num_missing); ++k) {
184             unsigned mask2 = mask;
185             for (unsigned i = 0; i < num_missing; ++i) {
186                 if ((k & (1 << i)) != 0) {
187                     mask2 |= 1ul << m_missing[i];
188                 }
189             }
190             set_combination(mask2);
191         }
192         // return true if xor clause is covered.
193         unsigned sz = c.size();
194         for (unsigned i = 0; i < (1ul << sz); ++i) {
195             TRACE("sat_xor", tout << i << ": " << parity << " " << m_parity[sz][i] << " " << get_combination(i) << "\n";);
196             if (parity == m_parity[sz][i] && !get_combination(i)) {
197                 return false;
198             }
199         }
200         return true;
201     }
202 
init_parity()203     void xor_finder::init_parity() {
204         for (unsigned i = m_parity.size(); i <= m_max_xor_size; ++i) {
205             bool_vector bv;
206             for (unsigned j = 0; j < (1ul << i); ++j) {
207                 bool parity = false;
208                 for (unsigned k = 0; k < i; ++k) {
209                     parity ^= ((j & (1 << k)) != 0);
210                 }
211                 bv.push_back(parity);
212             }
213             m_parity.push_back(bv);
214         }
215     }
216 
init_clause_filter()217     void xor_finder::init_clause_filter() {
218         m_clause_filters.reset();
219         m_clause_filters.resize(s.num_vars());
220         init_clause_filter(s.m_clauses);
221         init_clause_filter(s.m_learned);
222     }
223 
init_clause_filter(clause_vector & clauses)224     void xor_finder::init_clause_filter(clause_vector& clauses) {
225         for (clause* cp : clauses) {
226             clause& c = *cp;
227             if (c.size() <= m_max_xor_size && s.all_distinct(c)) {
228                 clause_filter cf(get_clause_filter(c), cp);
229                 for (literal l : c) {
230                     m_clause_filters[l.var()].push_back(cf);
231                 }
232             }
233         }
234     }
235 
get_clause_filter(clause & c)236     unsigned xor_finder::get_clause_filter(clause& c) {
237         unsigned filter = 0;
238         for (literal l : c) {
239             filter |= 1 << ((l.var() % 32));
240         }
241         return filter;
242     }
243 
244 
245 }
246