1 /*++ 2 Copyright (c) 2020 Microsoft Corporation 3 4 Module Name: 5 6 sat_xor_finder.cpp 7 8 Abstract: 9 10 xor finder 11 12 Author: 13 14 Nikolaj Bjorner 2020-01-02 15 16 Notes: 17 18 19 --*/ 20 21 #include "sat/sat_xor_finder.h" 22 #include "sat/sat_solver.h" 23 24 namespace sat { 25 26 operator ()(clause_vector & clauses)27 void xor_finder::operator()(clause_vector& clauses) { 28 m_removed_clauses.reset(); 29 unsigned max_size = m_max_xor_size; 30 // we better have enough bits in the combination mask to 31 // handle clauses up to max_size. 32 // max_size = 5 -> 32 bits 33 // max_size = 6 -> 64 bits 34 SASSERT(sizeof(m_combination)*8 <= (1ull << static_cast<uint64_t>(max_size))); 35 init_clause_filter(); 36 m_var_position.resize(s.num_vars()); 37 for (clause* cp : clauses) { 38 cp->unmark_used(); 39 } 40 for (; max_size > 2; --max_size) { 41 for (clause* cp : clauses) { 42 clause& c = *cp; 43 if (c.size() == max_size && !c.was_removed() && !c.is_learned() && !c.was_used()) { 44 extract_xor(c); 45 } 46 } 47 } 48 m_clause_filters.clear(); 49 50 for (clause* cp : clauses) cp->unmark_used(); 51 for (clause* cp : m_removed_clauses) cp->mark_used(); 52 std::function<bool(clause*)> not_used = [](clause* cp) { return !cp->was_used(); }; 53 clauses.filter_update(not_used); 54 } 55 extract_xor(clause & c)56 void xor_finder::extract_xor(clause& c) { 57 SASSERT(c.size() > 2); 58 unsigned filter = get_clause_filter(c); 59 s.init_visited(); 60 TRACE("sat_xor", tout << c << "\n";); 61 bool parity = false; 62 unsigned mask = 0, i = 0; 63 for (literal l : c) { 64 m_var_position[l.var()] = i; 65 s.mark_visited(l.var()); 66 parity ^= !l.sign(); 67 mask |= (!l.sign() << (i++)); 68 } 69 // parity is number of true literals in clause. 70 m_clauses_to_remove.reset(); 71 m_clauses_to_remove.push_back(&c); 72 m_clause.resize(c.size()); 73 m_combination = 0; 74 set_combination(mask); 75 c.mark_used(); 76 for (literal l : c) { 77 for (auto const& cf : m_clause_filters[l.var()]) { 78 if ((filter == (filter | cf.m_filter)) && 79 !cf.m_clause->was_used() && 80 extract_xor(parity, c, *cf.m_clause)) { 81 add_xor(parity, c); 82 return; 83 } 84 } 85 // loop over binary clauses in watch list 86 for (watched const & w : s.get_wlist(l)) { 87 if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { 88 if (extract_xor(parity, c, ~l, w.get_literal())) { 89 add_xor(parity, c); 90 return; 91 } 92 } 93 } 94 l.neg(); 95 for (watched const & w : s.get_wlist(l)) { 96 if (w.is_binary_clause() && s.is_visited(w.get_literal().var()) && w.get_literal().index() < l.index()) { 97 if (extract_xor(parity, c, ~l, w.get_literal())) { 98 add_xor(parity, c); 99 return; 100 } 101 } 102 } 103 } 104 } 105 set_combination(unsigned mask)106 void xor_finder::set_combination(unsigned mask) { 107 m_combination |= (1 << mask); 108 } 109 110 add_xor(bool parity,clause & c)111 void xor_finder::add_xor(bool parity, clause& c) { 112 DEBUG_CODE(for (clause* cp : m_clauses_to_remove) VERIFY(cp->was_used());); 113 m_removed_clauses.append(m_clauses_to_remove); 114 literal_vector lits; 115 for (literal l : c) { 116 lits.push_back(literal(l.var(), false)); 117 s.set_external(l.var()); 118 } 119 if (parity == (lits.size() % 2 == 0)) lits[0].neg(); 120 TRACE("sat_xor", tout << parity << ": " << lits << "\n";); 121 m_on_xor(lits); 122 } 123 extract_xor(bool parity,clause & c,literal l1,literal l2)124 bool xor_finder::extract_xor(bool parity, clause& c, literal l1, literal l2) { 125 SASSERT(s.is_visited(l1.var())); 126 SASSERT(s.is_visited(l2.var())); 127 m_missing.reset(); 128 unsigned mask = 0; 129 for (unsigned i = 0; i < c.size(); ++i) { 130 if (c[i].var() == l1.var()) { 131 mask |= (!l1.sign() << i); 132 } 133 else if (c[i].var() == l2.var()) { 134 mask |= (!l2.sign() << i); 135 } 136 else { 137 m_missing.push_back(i); 138 } 139 } 140 TRACE("sat_xor", tout << l1 << " " << l2 << "\n";); 141 return update_combinations(c, parity, mask); 142 } 143 extract_xor(bool parity,clause & c,clause & c2)144 bool xor_finder::extract_xor(bool parity, clause& c, clause& c2) { 145 bool parity2 = false; 146 for (literal l : c2) { 147 if (!s.is_visited(l.var())) return false; 148 parity2 ^= !l.sign(); 149 } 150 if (c2.size() == c.size() && parity2 != parity) { 151 return false; 152 } 153 if (c2.size() == c.size()) { 154 m_clauses_to_remove.push_back(&c2); 155 c2.mark_used(); 156 } 157 TRACE("sat_xor", tout << c2 << "\n";); 158 // insert missing 159 unsigned mask = 0; 160 m_missing.reset(); 161 SASSERT(c2.size() <= c.size()); 162 for (unsigned i = 0; i < c.size(); ++i) { 163 m_clause[i] = null_literal; 164 } 165 for (literal l : c2) { 166 unsigned pos = m_var_position[l.var()]; 167 m_clause[pos] = l; 168 } 169 for (unsigned j = 0; j < c.size(); ++j) { 170 literal lit = m_clause[j]; 171 if (lit == null_literal) { 172 m_missing.push_back(j); 173 } 174 else { 175 mask |= (!m_clause[j].sign() << j); 176 } 177 } 178 return update_combinations(c, parity, mask); 179 } 180 update_combinations(clause & c,bool parity,unsigned mask)181 bool xor_finder::update_combinations(clause& c, bool parity, unsigned mask) { 182 unsigned num_missing = m_missing.size(); 183 for (unsigned k = 0; k < (1ul << num_missing); ++k) { 184 unsigned mask2 = mask; 185 for (unsigned i = 0; i < num_missing; ++i) { 186 if ((k & (1 << i)) != 0) { 187 mask2 |= 1ul << m_missing[i]; 188 } 189 } 190 set_combination(mask2); 191 } 192 // return true if xor clause is covered. 193 unsigned sz = c.size(); 194 for (unsigned i = 0; i < (1ul << sz); ++i) { 195 TRACE("sat_xor", tout << i << ": " << parity << " " << m_parity[sz][i] << " " << get_combination(i) << "\n";); 196 if (parity == m_parity[sz][i] && !get_combination(i)) { 197 return false; 198 } 199 } 200 return true; 201 } 202 init_parity()203 void xor_finder::init_parity() { 204 for (unsigned i = m_parity.size(); i <= m_max_xor_size; ++i) { 205 bool_vector bv; 206 for (unsigned j = 0; j < (1ul << i); ++j) { 207 bool parity = false; 208 for (unsigned k = 0; k < i; ++k) { 209 parity ^= ((j & (1 << k)) != 0); 210 } 211 bv.push_back(parity); 212 } 213 m_parity.push_back(bv); 214 } 215 } 216 init_clause_filter()217 void xor_finder::init_clause_filter() { 218 m_clause_filters.reset(); 219 m_clause_filters.resize(s.num_vars()); 220 init_clause_filter(s.m_clauses); 221 init_clause_filter(s.m_learned); 222 } 223 init_clause_filter(clause_vector & clauses)224 void xor_finder::init_clause_filter(clause_vector& clauses) { 225 for (clause* cp : clauses) { 226 clause& c = *cp; 227 if (c.size() <= m_max_xor_size && s.all_distinct(c)) { 228 clause_filter cf(get_clause_filter(c), cp); 229 for (literal l : c) { 230 m_clause_filters[l.var()].push_back(cf); 231 } 232 } 233 } 234 } 235 get_clause_filter(clause & c)236 unsigned xor_finder::get_clause_filter(clause& c) { 237 unsigned filter = 0; 238 for (literal l : c) { 239 filter |= 1 << ((l.var() % 32)); 240 } 241 return filter; 242 } 243 244 245 } 246