1 /*++
2   Copyright (c) 2020 Microsoft Corporation
3 
4   Module Name:
5 
6    sat_anf_simplifier.cpp
7 
8   Abstract:
9 
10     Simplification based on ANF format.
11 
12   Author:
13 
14     Nikolaj Bjorner 2020-01-02
15 
16   --*/
17 
18 #include "util/union_find.h"
19 #include "sat/sat_anf_simplifier.h"
20 #include "sat/sat_solver.h"
21 #include "sat/sat_elim_eqs.h"
22 #include "sat/sat_xor_finder.h"
23 #include "sat/sat_aig_finder.h"
24 #include "math/grobner/pdd_solver.h"
25 
26 namespace sat {
27 
28     struct anf_simplifier::report {
29         anf_simplifier& s;
30         stopwatch       m_watch;
reportsat::anf_simplifier::report31         report(anf_simplifier& s): s(s) { m_watch.start(); }
~reportsat::anf_simplifier::report32         ~report() {
33             m_watch.stop();
34             IF_VERBOSE(2,
35                        verbose_stream() << " (sat.anf.simplifier"
36                        << " :num-units " << s.m_stats.m_num_units
37                        << " :num-eqs " << s.m_stats.m_num_eqs
38                        << " :mb " << mem_stat()
39                        << m_watch << ")\n");
40         }
41     };
42 
operator ()()43     void anf_simplifier::operator()() {
44         dd::pdd_manager m(20, dd::pdd_manager::semantics::mod2_e);
45         pdd_solver solver(s.rlimit(), m);
46         report _report(*this);
47         configure_solver(solver);
48         clauses2anf(solver);
49         TRACE("anf_simplifier", solver.display(tout); s.display(tout););
50         solver.simplify();
51         TRACE("anf_simplifier", solver.display(tout););
52         anf2clauses(solver);
53         anf2phase(solver);
54         save_statistics(solver);
55         IF_VERBOSE(10, m_st.display(verbose_stream() << "(sat.anf.simplifier\n"); verbose_stream() << ")\n");
56     }
57 
58     /**
59        \brief extract learned units and equivalences from processed anf.
60 
61        TBD: could learn binary clauses
62        TBD: could try simplify equations using BIG subsumption similar to asymm_branch
63      */
anf2clauses(pdd_solver & solver)64     void anf_simplifier::anf2clauses(pdd_solver& solver) {
65 
66         union_find_default_ctx ctx;
67         union_find<> uf(ctx);
68         for (unsigned i = 2*s.num_vars(); i--> 0; ) uf.mk_var();
69         auto add_eq = [&](literal l1, literal l2) {
70             uf.merge(l1.index(), l2.index());
71             uf.merge((~l1).index(), (~l2).index());
72         };
73 
74         unsigned old_num_eqs = m_stats.m_num_eqs;
75         for (auto* e : solver.equations()) {
76             auto const& p = e->poly();
77             if (p.is_one()) {
78                 s.set_conflict();
79                 break;
80             }
81             else if (p.is_unary()) {
82                 // unit
83                 SASSERT(!p.is_val() && p.lo().is_val() && p.hi().is_val());
84                 literal lit(p.var(), p.lo().is_zero());
85                 s.assign_unit(lit);
86                 ++m_stats.m_num_units;
87                 TRACE("anf_simplifier", tout << "unit " << p << " : " << lit << "\n";);
88             }
89             else if (p.is_binary()) {
90                 // equivalence
91                 // x + y + c = 0
92                 SASSERT(!p.is_val() && p.hi().is_one() && !p.lo().is_val() && p.lo().hi().is_one() && p.lo().lo().is_val());
93                 literal x(p.var(), false);
94                 literal y(p.lo().var(), p.lo().lo().is_one());
95                 add_eq(x, y);
96                 ++m_stats.m_num_eqs;
97                 TRACE("anf_simplifier", tout << "equivalence " << p << " : " << x << " == " << y << "\n";);
98             }
99         }
100 
101         if (old_num_eqs < m_stats.m_num_eqs) {
102             elim_eqs elim(s);
103             elim(uf);
104         }
105     }
106 
107     /**
108        \brief update best phase using solved equations
109        polynomials that are not satisfied evaluate to true.
110        In a satisfying assignment, all polynomials should evaluate to false.
111        assume that solutions are provided in reverse order.
112 
113        As a simplifying assumption it relies on the property
114        that if an equation is of the form v + p, where v does not occur in p,
115        then all equations that come after it do not contain p.
116        In this way we can flip the assignment to v without
117        invalidating the evaluation cache.
118      */
anf2phase(pdd_solver & solver)119     void anf_simplifier::anf2phase(pdd_solver& solver) {
120         if (!m_config.m_anf2phase)
121             return;
122         reset_eval();
123         auto const& eqs = solver.equations();
124         for (unsigned i = eqs.size(); i-- > 0; ) {
125             dd::pdd const& p = eqs[i]->poly();
126             if (!p.is_val() && p.hi().is_one() && s.m_best_phase[p.var()] != eval(p.lo())) {
127                 s.m_best_phase[p.var()] = !s.m_best_phase[p.var()];
128                 ++m_stats.m_num_phase_flips;
129             }
130         }
131     }
132 
eval(dd::pdd const & p)133     bool anf_simplifier::eval(dd::pdd const& p) {
134         if (p.is_one()) return true;
135         if (p.is_zero()) return false;
136         unsigned index = p.index();
137         if (index < m_eval_cache.size()) {
138             if (m_eval_cache[index] == m_eval_ts) return false;
139             if (m_eval_cache[index] == m_eval_ts + 1) return true;
140         }
141         SASSERT(!p.is_val());
142         bool hi = eval(p.hi());
143         bool lo = eval(p.lo());
144         bool v = (hi && s.m_best_phase[p.var()]) ^ lo;
145         m_eval_cache.reserve(index + 1, 0);
146         m_eval_cache[index] = m_eval_ts + v;
147         return v;
148     }
149 
reset_eval()150     void anf_simplifier::reset_eval() {
151         if (m_eval_ts + 2 < m_eval_ts) {
152             m_eval_cache.reset();
153             m_eval_ts = 0;
154         }
155         m_eval_ts += 2;
156     }
157 
clauses2anf(pdd_solver & solver)158     void anf_simplifier::clauses2anf(pdd_solver& solver) {
159         svector<solver::bin_clause> bins;
160         m_relevant.reset();
161         m_relevant.resize(s.num_vars(), false);
162         clause_vector clauses(s.clauses());
163         s.collect_bin_clauses(bins, false, false);
164         collect_clauses(clauses, bins);
165         try {
166             compile_xors(clauses, solver);
167             compile_aigs(clauses, bins, solver);
168 
169             for (auto const& b : bins) {
170                 add_bin(b, solver);
171             }
172             for (clause* cp : clauses) {
173                 add_clause(*cp, solver);
174             }
175         }
176         catch (dd::pdd_manager::mem_out) {
177             IF_VERBOSE(1, verbose_stream() << "(sat.anf memout)\n");
178         }
179     }
180 
collect_clauses(clause_vector & clauses,svector<solver::bin_clause> & bins)181     void anf_simplifier::collect_clauses(clause_vector & clauses, svector<solver::bin_clause>& bins) {
182         clause_vector oclauses;
183         svector<solver::bin_clause> obins;
184 
185         unsigned j = 0;
186         for (clause* cp : clauses) {
187             clause const& c = *cp;
188             if (is_too_large(c))
189                 continue;
190             else if (is_pre_satisfied(c)) {
191                 oclauses.push_back(cp);
192             }
193             else {
194                 clauses[j++] = cp;
195             }
196         }
197         clauses.shrink(j);
198 
199         j = 0;
200         for (auto const& b : bins) {
201             if (is_pre_satisfied(b)) {
202                 obins.push_back(b);
203             }
204             else {
205                 bins[j++] = b;
206             }
207         }
208         bins.shrink(j);
209 
210         unsigned rounds = 0, max_rounds = 3;
211         bool added = true;
212         while (bins.size() + clauses.size() < m_config.m_max_clauses &&
213                (!obins.empty() || !oclauses.empty()) &&
214                added &&
215                rounds < max_rounds) {
216 
217             added = false;
218             for (auto const& b : bins) set_relevant(b);
219             for (clause* cp : clauses) set_relevant(*cp);
220 
221             j = 0;
222             for (auto const& b : obins) {
223                 if (has_relevant_var(b)) {
224                     added = true;
225                     bins.push_back(b);
226                 }
227                 else {
228                     obins[j++] = b;
229                 }
230             }
231             obins.shrink(j);
232 
233             if (bins.size() + clauses.size() >= m_config.m_max_clauses) {
234                 break;
235             }
236 
237             j = 0;
238             for (clause* cp : oclauses) {
239                 if (has_relevant_var(*cp)) {
240                     added = true;
241                     clauses.push_back(cp);
242                 }
243                 else {
244                     oclauses[j++] = cp;
245                 }
246             }
247             oclauses.shrink(j);
248         }
249 
250         TRACE("anf_simplifier",
251               tout << "kept:\n";
252               for (clause* cp : clauses) tout << *cp << "\n";
253               for (auto b : bins) tout << b.first << " " << b.second << "\n";
254               tout << "removed:\n";
255               for (clause* cp : oclauses) tout << *cp << "\n";
256               for (auto b : obins) tout << b.first << " " << b.second << "\n";);
257     }
258 
set_relevant(solver::bin_clause const & b)259     void anf_simplifier::set_relevant(solver::bin_clause const& b) {
260         set_relevant(b.first);
261         set_relevant(b.second);
262     }
263 
set_relevant(clause const & c)264     void anf_simplifier::set_relevant(clause const& c) {
265         for (literal l : c) set_relevant(l);
266     }
267 
is_pre_satisfied(clause const & c)268     bool anf_simplifier::is_pre_satisfied(clause const& c) {
269         for (literal l : c) if (phase_is_true(l)) return true;
270         return false;
271     }
272 
is_pre_satisfied(solver::bin_clause const & b)273     bool anf_simplifier::is_pre_satisfied(solver::bin_clause const& b) {
274         return phase_is_true(b.first) || phase_is_true(b.second);
275     }
276 
phase_is_true(literal l)277     bool anf_simplifier::phase_is_true(literal l) {
278         bool ph = (s.m_best_phase_size > 0) ? s.m_best_phase[l.var()] : s.m_phase[l.var()];
279         return l.sign() ? !ph : ph;
280     }
281 
has_relevant_var(clause const & c)282     bool anf_simplifier::has_relevant_var(clause const& c) {
283         for (literal l : c) if (is_relevant(l)) return true;
284         return false;
285     }
286 
has_relevant_var(solver::bin_clause const & b)287     bool anf_simplifier::has_relevant_var(solver::bin_clause const& b) {
288         return is_relevant(b.first) || is_relevant(b.second);
289     }
290 
291     /**
292        \brief extract xors from all s.clauses()
293        (could be just filtered clauses, or clauses with relevant variables).
294        Add the extracted xors to pdd_solver.
295        Remove clauses from list that correspond to extracted xors
296      */
compile_xors(clause_vector & clauses,pdd_solver & ps)297     void anf_simplifier::compile_xors(clause_vector& clauses, pdd_solver& ps) {
298         if (!m_config.m_compile_xor) {
299             return;
300         }
301         std::function<void(literal_vector const&)> f =
302             [&,this](literal_vector const& x) {
303             add_xor(x, ps);
304             m_stats.m_num_xors++;
305         };
306         xor_finder xf(s);
307         xf.set(f);
308         xf(clauses);
309     }
310 
normalize(solver::bin_clause const & b)311     static solver::bin_clause normalize(solver::bin_clause const& b) {
312         if (b.first.index() > b.second.index()) {
313             return solver::bin_clause(b.second, b.first);
314         }
315         else {
316             return b;
317         }
318     }
319     /**
320        \brief extract AIGs from clauses.
321        Add the extracted AIGs to pdd_solver.
322        Remove clauses from list that correspond to extracted AIGs
323        Remove binary clauses that correspond to extracted AIGs.
324      */
compile_aigs(clause_vector & clauses,svector<solver::bin_clause> & bins,pdd_solver & ps)325     void anf_simplifier::compile_aigs(clause_vector& clauses, svector<solver::bin_clause>& bins, pdd_solver& ps) {
326         if (!m_config.m_compile_aig) {
327             return;
328         }
329         hashtable<solver::bin_clause, solver::bin_clause_hash, default_eq<solver::bin_clause>> seen_bin;
330 
331         std::function<void(literal head, literal_vector const& tail)> on_aig =
332             [&,this](literal head, literal_vector const& tail) {
333             add_aig(head, tail, ps);
334             for (literal l : tail) {
335                 seen_bin.insert(normalize(solver::bin_clause(~l, head)));
336             }
337             m_stats.m_num_aigs++;
338         };
339         std::function<void(literal head, literal c, literal th, literal el)> on_if =
340             [&,this](literal head, literal c, literal th, literal el) {
341             add_if(head, c, th, el, ps);
342             m_stats.m_num_ifs++;
343         };
344         aig_finder af(s);
345         af.set(on_aig);
346         af.set(on_if);
347         af(clauses);
348 
349         std::function<bool(solver::bin_clause b)> not_seen =
350             [&](solver::bin_clause b) { return !seen_bin.contains(normalize(b)); };
351         bins.filter_update(not_seen);
352     }
353 
354     /**
355        assign levels to variables.
356        use variable id as a primary source for the level of a variable.
357        secondarily, sort variables randomly (each variable is assigned
358        a random, unique, id).
359     */
configure_solver(pdd_solver & ps)360     void anf_simplifier::configure_solver(pdd_solver& ps) {
361         unsigned nv = s.num_vars();
362         unsigned_vector l2v(nv), var2id(nv), id2var(nv);
363         svector<std::pair<unsigned, unsigned>> vl(nv);
364 
365         for (unsigned i = 0; i < nv; ++i) var2id[i] = i;
366         shuffle(var2id.size(), var2id.data(), s.rand());
367         for (unsigned i = 0; i < nv; ++i) id2var[var2id[i]] = i;
368         for (unsigned i = 0; i < nv; ++i) vl[i] = std::make_pair(i, var2id[i]);
369         std::sort(vl.begin(), vl.end());
370         for (unsigned i = 0; i < nv; ++i) l2v[i] = id2var[vl[i].second];
371 
372         ps.get_manager().reset(l2v);
373 
374         // set configuration parameters.
375         dd::solver::config cfg;
376         cfg.m_expr_size_limit = 1000;
377         cfg.m_max_steps = 1000;
378         cfg.m_random_seed = s.rand()();
379         cfg.m_enable_exlin = m_config.m_enable_exlin;
380 
381         unsigned max_num_nodes = 1 << 18;
382         ps.get_manager().set_max_num_nodes(max_num_nodes);
383         ps.set(cfg);
384     }
385 
// Translate a SAT literal into a pdd: the variable for a positive literal,
// its complement (~) for a negative one.
// NOTE: this is a macro, not a function. It requires a pdd manager named `m`
// in the enclosing scope, and it evaluates its argument more than once
// (once for sign(), once for var()).
#define lit2pdd(_l_) (_l_.sign() ? ~m.mk_var(_l_.var()) : m.mk_var(_l_.var()))
387 
add_bin(solver::bin_clause const & b,pdd_solver & ps)388     void anf_simplifier::add_bin(solver::bin_clause const& b, pdd_solver& ps) {
389         auto& m = ps.get_manager();
390         dd::pdd p = (lit2pdd(b.first) | lit2pdd(b.second)) ^ true;
391         ps.add(p);
392         TRACE("anf_simplifier", tout << "bin: " << b.first << " " << b.second << " : " << p << "\n";);
393     }
394 
add_clause(clause const & c,pdd_solver & ps)395     void anf_simplifier::add_clause(clause const& c, pdd_solver& ps) {
396         if (c.size() > m_config.m_max_clause_size) return;
397         auto& m = ps.get_manager();
398         dd::pdd p = m.zero();
399         for (literal l : c) p |= lit2pdd(l);
400         p = p ^ true;
401         ps.add(p);
402         TRACE("anf_simplifier", tout << "clause: " << c << " : " << p << "\n";);
403     }
404 
add_xor(literal_vector const & x,pdd_solver & ps)405     void anf_simplifier::add_xor(literal_vector const& x, pdd_solver& ps) {
406         auto& m = ps.get_manager();
407         dd::pdd p = m.one();
408         for (literal l : x) p ^= lit2pdd(l);
409         ps.add(p);
410         TRACE("anf_simplifier", tout << "xor: " << x << " : " << p << "\n";);
411     }
412 
add_aig(literal head,literal_vector const & ands,pdd_solver & ps)413     void anf_simplifier::add_aig(literal head, literal_vector const& ands, pdd_solver& ps) {
414         auto& m = ps.get_manager();
415         dd::pdd q = m.one();
416         for (literal l : ands) q &= lit2pdd(l);
417         dd::pdd p = lit2pdd(head) ^ q;
418         ps.add(p);
419         TRACE("anf_simplifier", tout << "aig: " << head << " == " << ands << " poly : " << p << "\n";);
420     }
421 
add_if(literal head,literal c,literal th,literal el,pdd_solver & ps)422     void anf_simplifier::add_if(literal head, literal c, literal th, literal el, pdd_solver& ps) {
423         auto& m = ps.get_manager();
424         dd::pdd cond = lit2pdd(c);
425         dd::pdd p = lit2pdd(head) ^ (cond & lit2pdd(th)) ^ (~cond & lit2pdd(el));
426         ps.add(p);
427         TRACE("anf_simplifier", tout << "ite: " << head << " == " << c << "?" << th << ":" << el << " poly : " << p << "\n";);
428     }
429 
save_statistics(pdd_solver & solver)430     void anf_simplifier::save_statistics(pdd_solver& solver) {
431         solver.collect_statistics(m_st);
432         m_st.update("sat-anf.units", m_stats.m_num_units);
433         m_st.update("sat-anf.eqs",   m_stats.m_num_eqs);
434         m_st.update("sat-anf.ands",  m_stats.m_num_aigs);
435         m_st.update("sat-anf.ites",  m_stats.m_num_ifs);
436         m_st.update("sat-anf.xors",  m_stats.m_num_xors);
437         m_st.update("sat-anf.phase_flips", m_stats.m_num_phase_flips);
438     }
439 
440 }
441