1 /*++
2   Copyright (c) 2020 Microsoft Corporation
3 
4 
5   Abstract:
6 
7     simplification routines for pdd polys
8 
9   Author:
10     Nikolaj Bjorner (nbjorner)
11     Lev Nachmanson (levnach)
12 
13   Notes:
14 
15 
16         Linear Elimination:
17         - comprises of a simplification pass that puts linear equations in to_processed
18         - so before simplifying with respect to the variable ordering, eliminate linear equalities.
19 
20         Extended Linear Simplification (as exploited in Bosphorus AAAI 2019):
21         - multiply each polynomial by one variable from their orbits.
        - The orbit of a variable is the set of variables that occur in the same monomial as it in some polynomial.
23         - The extended set of polynomials is fed to a linear Gauss Jordan Eliminator that extracts
24           additional linear equalities.
25         - Bosphorus uses M4RI to perform efficient GJE to scale on large bit-matrices.
26 
27         Long distance vanishing polynomials (used by PolyCleaner ICCAD 2019):
28         - identify polynomials p, q, such that p*q = 0
29         - main case is half-adders and full adders (p := x + y, q := x * y) over GF2
30           because (x+y)*x*y = 0 over GF2
31           To work beyond GF2 we would need to rely on simplification with respect to asserted equalities.
32           The method seems rather specific to hardware multipliers so not clear it is useful to
33           generalize.
34         - find monomials that contain pairs of vanishing polynomials, transitively
          without actually inlining.
36           Then color polynomial variables w by p, resp, q if they occur in polynomial equalities
37           w - r = 0, such that all paths in r contain a node colored by p, resp q.
38           polynomial variables that get colored by both p and q can be set to 0.
39           When some variable gets colored, other variables can be colored.
40         - We can walk pdd nodes by level to perform coloring in a linear sweep.
41           PDD nodes that are equal to 0 using some equality are marked as definitions.
42           First walk definitions to search for vanishing polynomial pairs.
43           Given two definition polynomials d1, d2, it must be the case that
          level(lo(d1)) = level(lo(d2)) for the polynomial lo(d1)*lo(d2) to be vanishing.
45           Then starting from the lowest level examine pdd nodes.
46           Let the current node be called p, check if the pdd node p is used in an equation
47           w - r = 0. In which case, w inherits the labels from r.
48           Otherwise, label the node by the intersection of vanishing polynomials from lo(p) and hi(p).
49 
50        Eliminating multiplier variables, but not adders [Kaufmann et al FMCAD 2019 for GF2];
51        - Only apply GB saturation with respect to variables that are part of multipliers.
52        - Perhaps this amounts to figuring out whether a variable is used in an xor or more
53 
54   --*/
55 
56 #include "math/grobner/pdd_simplifier.h"
57 #include "math/simplex/bit_matrix.h"
58 
59 namespace dd {
60 
61 
operator ()()62     void simplifier::operator()() {
63         try {
64             while (!s.done() &&
65                    (simplify_linear_step(true) ||
66                     simplify_elim_pure_step() ||
67                     simplify_cc_step() ||
68                     simplify_leaf_step() ||
69                     simplify_linear_step(false) ||
70                     /*simplify_elim_dual_step() ||*/
71                     simplify_exlin() ||
72                     false)) {
73                 DEBUG_CODE(s.invariant(););
74                 TRACE("dd.solver", s.display(tout););
75             }
76         }
77         catch (pdd_manager::mem_out) {
78             IF_VERBOSE(2, verbose_stream() << "simplifier memout\n");
79             // done reduce
80             DEBUG_CODE(s.invariant(););
81         }
82     }
83 
84     struct simplifier::compare_top_var {
operator ()dd::simplifier::compare_top_var85         bool operator()(equation* a, equation* b) const {
86             return a->poly().var() < b->poly().var();
87         }
88     };
89 
simplify_linear_step(bool binary)90     bool simplifier::simplify_linear_step(bool binary) {
91         TRACE("dd.solver", tout << "binary " << binary << "\n";);
92         IF_VERBOSE(2, verbose_stream() << "binary " << binary << "\n");
93         equation_vector linear;
94         for (equation* e : s.m_to_simplify) {
95             pdd p = e->poly();
96             if (binary) {
97                 if (p.is_binary()) linear.push_back(e);
98             }
99             else if (p.is_linear()) {
100                 linear.push_back(e);
101             }
102         }
103         return simplify_linear_step(linear);
104     }
105 
    /**
       \brief simplify linear equations by using top variable as solution.
       The linear equation is moved to set of solved equations.

       \param linear candidate equations. The vector is sorted in place by top variable.
       When no conflict is found it is shrunk to the candidates for which every use
       of the top variable was rewritten, and those candidates are moved to the solved set.
       \return true if at least one candidate had all of its uses rewritten,
       or if a conflict was detected; false otherwise (including for an empty input).
    */
    bool simplifier::simplify_linear_step(equation_vector& linear) {
        if (linear.empty()) return false;
        // Snapshot of which equations mention which variable; maintained locally below.
        use_list_t use_list = get_use_list();
        compare_top_var ctv;
        std::stable_sort(linear.begin(), linear.end(), ctv);
        // Equations that became trivial; deletion is deferred to the end of the function.
        equation_vector trivial;
        // Number of candidates retained so far; they are compacted into linear[0..j).
        unsigned j = 0;
        bool has_conflict = false;
        for (equation* src : linear) {
            if (has_conflict) {
                break;
            }
            // src may have been simplified to a trivial equation by an earlier candidate.
            if (s.is_trivial(*src)) {
                continue;
            }
            unsigned v = src->poly().var();
            // Reference into use_list; the inner loop must not modify use_list[v].
            equation_vector const& uses = use_list[v];
            TRACE("dd.solver",
                  s.display(tout << "uses of: ", *src) << "\n";
                  for (equation* e : uses) {
                      s.display(tout, *e) << "\n";
                  });
            bool changed_leading_term;
            bool all_reduced = true;
            for (equation* dst : uses) {
                if (src == dst || s.is_trivial(*dst)) {
                    continue;
                }
                pdd q = dst->poly();
                // A non-binary src is only used to rewrite linear destinations;
                // otherwise dst is left alone and src is not retained.
                if (!src->poly().is_binary() && !q.is_linear()) {
                    all_reduced = false;
                    continue;
                }
                // Unregister dst from every use list except the one for v,
                // which is the list currently being iterated.
                remove_from_use(dst, use_list, v);
                s.simplify_using(*dst, *src, changed_leading_term);
                if (s.is_trivial(*dst)) {
                    trivial.push_back(dst);
                }
                else if (s.is_conflict(dst)) {
                    s.pop_equation(dst);
                    s.set_conflict(dst);
                    has_conflict = true;
                }
                else if (changed_leading_term) {
                    // Re-queue dst as to_simplify after its leading term changed.
                    s.pop_equation(dst);
                    s.push_equation(solver::to_simplify, dst);
                }
                // v has been eliminated.
                // SASSERT(!dst->poly().free_vars().contains(v));
                // Register dst again under the free variables of its rewritten polynomial.
                // NOTE(review): presumably v is no longer among them, so use_list[v]
                // is not appended to while iterating - the assertion above is disabled.
                add_to_use(dst, use_list);
            }
            if (all_reduced) {
                linear[j++] = src;
            }
        }
        // On conflict the candidates stay where they are; linear is left unshrunk.
        if (!has_conflict) {
            linear.shrink(j);
            for (equation* src : linear) {
                s.pop_equation(src);
                s.push_equation(solver::solved, src);
            }
        }
        for (equation* e : trivial) {
            s.del_equation(e);
        }
        DEBUG_CODE(s.invariant(););
        return j > 0 || has_conflict;
    }
178 
179     /**
180        \brief simplify using congruences
181        replace pair px + q and ry + q by
182        px + q, px - ry
183        since px = ry
184      */
simplify_cc_step()185     bool simplifier::simplify_cc_step() {
186         TRACE("dd.solver", tout << "cc\n";);
187         IF_VERBOSE(2, verbose_stream() << "cc\n");
188         u_map<equation*> los;
189         bool reduced = false;
190         unsigned j = 0;
191         for (equation* eq1 : s.m_to_simplify) {
192             SASSERT(eq1->state() == solver::to_simplify);
193             pdd p = eq1->poly();
194             equation* eq2 = los.insert_if_not_there(p.lo().index(), eq1);
195             pdd q = eq2->poly();
196             if (eq2 != eq1 && (p.hi().is_val() || q.hi().is_val()) && !p.lo().is_val()) {
197                 *eq1 = p - eq2->poly();
198                 *eq1 = s.m_dep_manager.mk_join(eq1->dep(), eq2->dep());
199                 reduced = true;
200                 if (s.is_trivial(*eq1)) {
201                     s.retire(eq1);
202                     continue;
203                 }
204                 else if (s.check_conflict(*eq1)) {
205                     continue;
206                 }
207             }
208             s.m_to_simplify[j] = eq1;
209             eq1->set_index(j++);
210         }
211         s.m_to_simplify.shrink(j);
212         return reduced;
213     }
214 
    /**
       \brief remove ax+b from p if x occurs as a leaf in p and a is a constant.

       For every equation e in s.m_to_simplify whose polynomial has a constant hi(),
       the equations that use the top variable of e as a leaf are simplified using e.
       \return true only when a simplified equation became a constant that is not trivial,
       in which case it is recorded as a conflict. In all other cases the result is false,
       even if some equations were rewritten.
    */
    bool simplifier::simplify_leaf_step() {
        TRACE("dd.solver", tout << "leaf\n";);
        IF_VERBOSE(2, verbose_stream() << "leaf\n");
        use_list_t use_list = get_use_list();
        equation_vector leaves;
        // Index-based loop with the size re-read on every iteration.
        // NOTE(review): presumably because pop_equation/push_equation below
        // change s.m_to_simplify while it is being traversed - confirm.
        for (unsigned i = 0; i < s.m_to_simplify.size(); ++i) {
            equation* e = s.m_to_simplify[i];
            pdd p = e->poly();
            // only equations of the form a*x + b with constant a are used
            if (!p.hi().is_val()) {
                continue;
            }
            // Collect the targets first: the use lists are modified while simplifying.
            leaves.reset();
            for (equation* e2 : use_list[p.var()]) {
                if (e != e2 && e2->poly().var_is_leaf(p.var())) {
                    leaves.push_back(e2);
                }
            }
            for (equation* e2 : leaves) {
                bool changed_leading_term;
                // Keep the use lists in sync with the polynomial of e2 across the rewrite.
                remove_from_use(e2, use_list);
                s.simplify_using(*e2, *e, changed_leading_term);
                add_to_use(e2, use_list);
                if (s.is_trivial(*e2)) {
                    s.pop_equation(e2);
                    s.retire(e2);
                }
                else if (e2->poly().is_val()) {
                    // a constant that is not trivial: report it as a conflict and stop
                    s.pop_equation(e2);
                    s.set_conflict(*e2);
                    return true;
                }
                else if (changed_leading_term) {
                    // re-queue e2 as to_simplify after its leading term changed
                    s.pop_equation(e2);
                    s.push_equation(solver::to_simplify, e2);
                }
            }
        }
        return false;
    }
257 
258     /**
259        \brief treat equations as processed if top variable occurs only once.
260     */
simplify_elim_pure_step()261     bool simplifier::simplify_elim_pure_step() {
262         TRACE("dd.solver", tout << "pure\n";);
263         IF_VERBOSE(2, verbose_stream() << "pure\n");
264         use_list_t use_list = get_use_list();
265         unsigned j = 0;
266         for (equation* e : s.m_to_simplify) {
267             pdd p = e->poly();
268             if (!p.is_val() && p.hi().is_val() && use_list[p.var()].size() == 1) {
269                 s.push_equation(solver::solved, e);
270             }
271             else {
272                 s.m_to_simplify[j] = e;
273                 e->set_index(j++);
274             }
275         }
276         if (j != s.m_to_simplify.size()) {
277             s.m_to_simplify.shrink(j);
278             return true;
279         }
280         return false;
281     }
282 
    /**
       \brief
       reduce equations where top variable occurs only twice and linear in one of the occurrences.

       An equation e qualifies when hi() of its polynomial is a constant and the use list
       of its top variable has exactly two entries. The other entry is simplified using e
       and e is pushed to the solved set.
       \return true if some equation was processed this way or was found to have
       left the to_simplify state.
     */
    bool simplifier::simplify_elim_dual_step() {
        use_list_t use_list = get_use_list();
        // write position for the equations kept in s.m_to_simplify
        unsigned j = 0;
        bool reduced = false;
        for (unsigned i = 0; i < s.m_to_simplify.size(); ++i) {
            equation* e = s.m_to_simplify[i];
            pdd p = e->poly();
            // check that e is linear in top variable.
            if (e->state() != solver::to_simplify) {
                // e is no longer in the to_simplify state: it is not copied to the
                // kept prefix, and the clean-up pass below is requested.
                reduced = true;
            }
            else if (!s.done() && !s.is_trivial(*e) && p.hi().is_val() && use_list[p.var()].size() == 2) {
                for (equation* e2 : use_list[p.var()]) {
                    if (e2 == e) continue;
                    bool changed_leading_term;

                    // keep the use lists in sync with the polynomial of e2 across the rewrite
                    remove_from_use(e2, use_list);
                    s.simplify_using(*e2, *e, changed_leading_term);
                    if (s.is_conflict(e2)) {
                        s.pop_equation(e2);
                        s.set_conflict(e2);
                    }
                    // when e2 is trivial, leading term is changed
                    SASSERT(!s.is_trivial(*e2) || changed_leading_term);
                    if (changed_leading_term) {
                        s.pop_equation(e2);
                        s.push_equation(solver::to_simplify, e2);
                    }
                    add_to_use(e2, use_list);
                    // only the single other use is rewritten; leaving the loop here also
                    // avoids continuing to iterate a use list that add_to_use may have changed
                    break;
                }
                reduced = true;
                s.push_equation(solver::solved, e);
            }
            else {
                s.m_to_simplify[j] = e;
                e->set_index(j++);
            }
        }
        if (reduced) {
            // clean up elements in s.m_to_simplify
            // they may have moved.
            s.m_to_simplify.shrink(j);
            j = 0;
            // second compaction: retire trivial equations and keep only those
            // that are still in the to_simplify state
            for (equation* e : s.m_to_simplify) {
                if (s.is_trivial(*e)) {
                    s.retire(e);
                }
                else if (e->state() == solver::to_simplify) {
                    s.m_to_simplify[j] = e;
                    e->set_index(j++);
                }
            }
            s.m_to_simplify.shrink(j);
            return true;
        }
        else {
            return false;
        }
    }
347 
add_to_use(equation * e,use_list_t & use_list)348     void simplifier::add_to_use(equation* e, use_list_t& use_list) {
349         unsigned_vector const& fv = e->poly().free_vars();
350         for (unsigned v : fv) {
351             use_list.reserve(v + 1);
352             use_list[v].push_back(e);
353         }
354     }
355 
remove_from_use(equation * e,use_list_t & use_list)356     void simplifier::remove_from_use(equation* e, use_list_t& use_list) {
357         unsigned_vector const& fv = e->poly().free_vars();
358         for (unsigned v : fv) {
359             use_list.reserve(v + 1);
360             use_list[v].erase(e);
361         }
362     }
363 
remove_from_use(equation * e,use_list_t & use_list,unsigned except_v)364     void simplifier::remove_from_use(equation* e, use_list_t& use_list, unsigned except_v) {
365         unsigned_vector const& fv = e->poly().free_vars();
366         for (unsigned v : fv) {
367             if (v != except_v) {
368                 use_list.reserve(v + 1);
369                 use_list[v].erase(e);
370             }
371         }
372     }
373 
get_use_list()374     simplifier::use_list_t simplifier::get_use_list() {
375         use_list_t use_list;
376         for (equation * e : s.m_to_simplify) {
377             add_to_use(e, use_list);
378         }
379         for (equation * e : s.m_processed) {
380             add_to_use(e, use_list);
381         }
382         return use_list;
383     }
384 
385 
386     /**
387        \brief use Gauss elimination to extract linear equalities.
388        So far just for GF(2) semantics.
389      */
390 
simplify_exlin()391     bool simplifier::simplify_exlin() {
392         if (s.m.get_semantics() != pdd_manager::mod2_e ||
393             !s.m_config.m_enable_exlin) {
394             return false;
395         }
396         vector<pdd> eqs, simp_eqs;
397         for (auto* e : s.m_to_simplify) if (!e->dep()) eqs.push_back(e->poly());
398         for (auto* e : s.m_processed) if (!e->dep()) eqs.push_back(e->poly());
399         vector<uint_set> orbits(s.m.num_vars());
400         init_orbits(eqs, orbits);
401         exlin_augment(orbits, eqs);
402         simplify_exlin(orbits, eqs, simp_eqs);
403         for (pdd const& p : simp_eqs) {
404             s.add(p);
405         }
406         IF_VERBOSE(10, verbose_stream() << "simp_linear " << simp_eqs.size() << "\n";);
407         return !simp_eqs.empty() && simplify_linear_step(false);
408     }
409 
init_orbits(vector<pdd> const & eqs,vector<uint_set> & orbits)410     void simplifier::init_orbits(vector<pdd> const& eqs, vector<uint_set>& orbits) {
411         for (pdd const& p : eqs) {
412             auto const& fv = p.free_vars();
413             for (unsigned i = fv.size(); i-- > 0; ) {
414                 orbits[fv[i]].insert(fv[i]); // if v is used, it is in its own orbit.
415                 for (unsigned j = i; j-- > 0; ) {
416                     orbits[fv[i]].insert(fv[j]);
417                     orbits[fv[j]].insert(fv[i]);
418                 }
419             }
420         }
421     }
422 
423 
424     /**
425        augment set of equations by multiplying with selected variables.
426        Uses orbits to prune which variables are multiplied.
427        TBD: could also prune added polynomials based on a maximal degree.
428        TBD: for large systems, extract cluster of polynomials based on sampling orbits
429      */
430 
exlin_augment(vector<uint_set> const & orbits,vector<pdd> & eqs)431     void simplifier::exlin_augment(vector<uint_set> const& orbits, vector<pdd>& eqs) {
432         IF_VERBOSE(10, verbose_stream() << "pdd-exlin augment\n";);
433         unsigned nv = s.m.num_vars();
434         random_gen rand(s.m_config.m_random_seed);
435         unsigned modest_num_eqs = std::max(eqs.size(), 500u);
436         unsigned max_xlin_eqs = modest_num_eqs;
437         unsigned max_degree = 5;
438         TRACE("dd.solver", tout << "augment " << nv << "\n";
439               for (auto const& o : orbits) tout << o.num_elems() << "\n";);
440         vector<pdd> n_eqs;
441         unsigned start = rand();
442         for (unsigned _v = 0; _v < nv; ++_v) {
443             unsigned v = (_v + start) % nv;
444             auto const& orbitv = orbits[v];
445             if (orbitv.empty()) continue;
446             pdd pv = s.m.mk_var(v);
447             for (pdd p : eqs) {
448                 if (p.degree() > max_degree) continue;
449                 for (unsigned w : p.free_vars()) {
450                     if (v != w && orbitv.contains(w)) {
451                         n_eqs.push_back(pv * p);
452                         if (n_eqs.size() > max_xlin_eqs) {
453                             goto end_of_new_eqs;
454                         }
455                         break;
456                     }
457                 }
458             }
459         }
460 
461         start = rand();
462         for (unsigned _v = 0; _v < nv; ++_v) {
463             unsigned v = (_v + start) % nv;
464             auto const& orbitv = orbits[v];
465             if (orbitv.empty()) continue;
466             pdd pv = s.m.mk_var(v);
467             for (unsigned w : orbitv) {
468                 if (v >= w) continue;
469                 pdd pw = s.m.mk_var(w);
470                 for (pdd p : eqs) {
471                     if (p.degree() + 1 > max_degree) continue;
472                     for (unsigned u : p.free_vars()) {
473                         if (orbits[w].contains(u) || orbits[v].contains(u)) {
474                             n_eqs.push_back(pw * pv * p);
475                             if (n_eqs.size() > max_xlin_eqs) {
476                                 goto end_of_new_eqs;
477                             }
478                             break;
479                         }
480                     }
481                 }
482             }
483         }
484     end_of_new_eqs:
485         s.m_config.m_random_seed = rand();
486         eqs.append(n_eqs);
487         TRACE("dd.solver", for (pdd const& p : eqs) tout << p << "\n";);
488     }
489 
simplify_exlin(vector<uint_set> const & orbits,vector<pdd> const & eqs,vector<pdd> & simp_eqs)490     void simplifier::simplify_exlin(vector<uint_set> const& orbits, vector<pdd> const& eqs, vector<pdd>& simp_eqs) {
491         IF_VERBOSE(10, verbose_stream() << "pdd simplify-exlin\n";);
492         // index monomials
493         unsigned_vector vars;
494         struct mon {
495             unsigned sz;
496             unsigned offset;
497             unsigned index;
498             mon(unsigned sz, unsigned offset): sz(sz), offset(offset), index(UINT_MAX) {}
499             mon(): sz(0), offset(0), index(UINT_MAX) {}
500             bool is_valid() const { return index != UINT_MAX; }
501             struct hash {
502                 unsigned_vector& vars;
503                 hash(unsigned_vector& vars):vars(vars) {}
504                 bool operator()(mon const& m) const {
505                     return unsigned_ptr_hash(vars.data() + m.offset, m.sz, 1);
506                 };
507             };
508             struct eq {
509                 unsigned_vector& vars;
510                 eq(unsigned_vector& vars):vars(vars) {}
511                 bool operator()(mon const& a, mon const& b) const {
512                     if (a.sz != b.sz) return false;
513                     for (unsigned i = 0; i < a.sz; ++i)
514                         if (vars[a.offset+i] != vars[b.offset+i])
515                             return false;
516                     return true;
517                 }
518             };
519         };
520         mon::hash mon_hash(vars);
521         mon::eq mon_eq(vars);
522         hashtable<mon, mon::hash, mon::eq> mon2idx(DEFAULT_HASHTABLE_INITIAL_CAPACITY, mon_hash, mon_eq);
523         svector<mon> idx2mon;
524 
525         auto insert_mon = [&](unsigned n, unsigned const* vs) {
526             mon mm(n, vars.size());
527             vars.append(n, vs);
528             auto* e = mon2idx.insert_if_not_there2(mm);
529             if (!e->get_data().is_valid()) {
530                 e->get_data().index = idx2mon.size();
531                 idx2mon.push_back(e->get_data());
532             }
533             else {
534                 vars.shrink(vars.size() - n);
535             }
536         };
537 
538         // insert monomials of degree > 1
539         for (pdd const& p : eqs) {
540             for (auto const& m : p) {
541                 if (m.vars.size() <= 1) continue;
542                 insert_mon(m.vars.size(), m.vars.data());
543             }
544         }
545 
546         // insert variables last.
547         unsigned nv = s.m.num_vars();
548         for (unsigned v = 0; v < nv; ++v) {
549             if (!orbits[v].empty()) { // not all variables are used.
550                 insert_mon(1, &v);
551             }
552         }
553 
554         IF_VERBOSE(10, verbose_stream() << "extracted monomials: " << idx2mon.size() << "\n";);
555 
556 
557         bit_matrix bm;
558         unsigned const_idx = idx2mon.size();
559         bm.reset(const_idx + 1);
560 
561         // populate rows
562         for (pdd const& p : eqs) {
563             if (p.is_zero()) {
564                 continue;
565             }
566             auto row = bm.add_row();
567             for (auto const& m : p) {
568                 SASSERT(m.coeff.is_one());
569                 if (m.vars.empty()) {
570                     row.set(const_idx);
571                     continue;
572                 }
573                 unsigned n = m.vars.size();
574                 mon mm(n, vars.size());
575                 vars.append(n, m.vars.data());
576                 VERIFY(mon2idx.find(mm, mm));
577                 vars.shrink(vars.size() - n);
578                 row.set(mm.index);
579             }
580         }
581 
582         TRACE("dd.solver", tout << bm << "\n";);
583         IF_VERBOSE(10, verbose_stream() << "bit-matrix solving\n");
584 
585         bm.solve();
586 
587         TRACE("dd.solver", tout << bm << "\n";);
588         IF_VERBOSE(10, verbose_stream() << "bit-matrix solved\n");
589 
590         for (auto const& r : bm) {
591             bool is_linear = true;
592             for (unsigned c : r) {
593                 SASSERT(r[c]);
594                 if (c == const_idx) {
595                     break;
596                 }
597                 if (idx2mon[c].sz != 1) {
598                     is_linear = false;
599                     break;
600                 }
601             }
602 
603             if (is_linear) {
604                 pdd p = s.m.zero();
605                 for (unsigned c : r) {
606                     if (c == const_idx) {
607                         p += s.m.one();
608                     }
609                     else {
610                         mon const& mm = idx2mon[c];
611                         p += s.m.mk_var(vars[mm.offset]);
612                     }
613                 }
614                 if (!p.is_zero()) {
615                     TRACE("dd.solver", tout << "new linear: " << p << "\n";);
616                     simp_eqs.push_back(p);
617                 }
618             }
619 
620             // could also consider singleton monomials as Bosphorus does
621             // Singleton monomials are of the form v*w*u*v == 0
622             // Generally want to deal with negations too
623             // v*(w+1)*u will have shared pdd under w,
624             // e.g, test every variable in p whether it has hi() == lo().
625             // maybe easier to read out of a pdd than the expanded form.
626         }
627     }
628 
629 }
630