1 /*********************                                                        */
2 /*! \file extended_rewrite.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Andrew Reynolds, Andres Noetzli
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** \brief Implementation of extended rewriting techniques
13  **/
14 
15 #include "theory/quantifiers/extended_rewrite.h"
16 
17 #include "options/quantifiers_options.h"
18 #include "theory/arith/arith_msum.h"
19 #include "theory/bv/theory_bv_utils.h"
20 #include "theory/datatypes/datatypes_rewriter.h"
21 #include "theory/quantifiers/term_util.h"
22 #include "theory/rewriter.h"
23 #include "theory/strings/theory_strings_rewriter.h"
24 
25 using namespace CVC4::kind;
26 using namespace std;
27 
28 namespace CVC4 {
29 namespace theory {
30 namespace quantifiers {
31 
32 struct ExtRewriteAttributeId
33 {
34 };
35 typedef expr::Attribute<ExtRewriteAttributeId, Node> ExtRewriteAttribute;
36 
ExtendedRewriter(bool aggr)37 ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr)
38 {
39   d_true = NodeManager::currentNM()->mkConst(true);
40   d_false = NodeManager::currentNM()->mkConst(false);
41 }
42 
setCache(Node n,Node ret)43 void ExtendedRewriter::setCache(Node n, Node ret)
44 {
45   ExtRewriteAttribute era;
46   n.setAttribute(era, ret);
47 }
48 
addToChildren(Node nc,std::vector<Node> & children,bool dropDup)49 bool ExtendedRewriter::addToChildren(Node nc,
50                                      std::vector<Node>& children,
51                                      bool dropDup)
52 {
53   // If the operator is non-additive, do not consider duplicates
54   if (dropDup
55       && std::find(children.begin(), children.end(), nc) != children.end())
56   {
57     return false;
58   }
59   children.push_back(nc);
60   return true;
61 }
62 
/**
 * Returns the extended-rewritten form of n.
 *
 * n is first rewritten with Rewriter::rewrite. If the sygusExtRew option is
 * disabled, that result is returned. Otherwise:
 *   1. (aggressive mode only) pre-rewrites: IMPLIES elimination, XOR
 *      elimination, and pushing a NOT inward (extendedRewriteNnf). If one
 *      applies, its result is extended-rewritten, cached and returned.
 *   2. The children are extended-rewritten recursively. Children of the same
 *      associative kind are flattened, and duplicates are dropped for
 *      non-additive kinds. Children of commutative kinds are sorted, but in
 *      non-aggressive mode only when there are at most 5 of them.
 *   3. Theory-independent post-rewrites: ITE, AND/OR, Boolean equality
 *      chains, then simple ITE pulling.
 *   4. Theory-specific post-rewrites for arithmetic and strings.
 *   5. (aggressive mode only) extendedRewriteAggr.
 * If a post-rewrite (3-5) applies, its result is itself extended-rewritten.
 *
 * Results are cached on the Rewriter::rewrite form of n via
 * ExtRewriteAttribute.
 *
 * @param n the term to rewrite
 * @return the extended-rewritten form of n
 */
Node ExtendedRewriter::extendedRewrite(Node n)
{
  n = Rewriter::rewrite(n);
  if (!options::sygusExtRew())
  {
    return n;
  }

  // has it already been computed?
  if (n.hasAttribute(ExtRewriteAttribute()))
  {
    return n.getAttribute(ExtRewriteAttribute());
  }

  Node ret = n;
  NodeManager* nm = NodeManager::currentNM();

  //--------------------pre-rewrite
  if (d_aggr)
  {
    Node pre_new_ret;
    if (ret.getKind() == IMPLIES)
    {
      // a => b  --->  ~a | b
      pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
      debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
    }
    else if (ret.getKind() == XOR)
    {
      // a xor b  --->  ~a = b
      pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
      debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
    }
    else if (ret.getKind() == NOT)
    {
      // push the negation inward; this is null if the NNF step does not apply
      pre_new_ret = extendedRewriteNnf(ret);
      debugExtendedRewrite(ret, pre_new_ret, "NNF");
    }
    if (!pre_new_ret.isNull())
    {
      ret = extendedRewrite(pre_new_ret);

      Trace("q-ext-rewrite-debug")
          << "...ext-pre-rewrite : " << n << " -> " << pre_new_ret << std::endl;
      setCache(n, ret);
      return ret;
    }
  }
  //--------------------end pre-rewrite

  //--------------------rewrite children
  if (n.getNumChildren() > 0)
  {
    std::vector<Node> children;
    // for parameterized kinds, the operator is the first entry of children
    if (n.getMetaKind() == metakind::PARAMETERIZED)
    {
      children.push_back(n.getOperator());
    }
    Kind k = n.getKind();
    bool childChanged = false;
    bool isNonAdditive = TermUtil::isNonAdditive(k);
    // We flatten associative operators below, which requires k to be n-ary.
    bool isAssoc = TermUtil::isAssoc(k, true);
    for (unsigned i = 0; i < n.getNumChildren(); i++)
    {
      Node nc = extendedRewrite(n[i]);
      childChanged = nc != n[i] || childChanged;
      if (isAssoc && nc.getKind() == n.getKind())
      {
        // flatten: f( .., f( a, b ), .. ) ---> f( .., a, b, .. )
        for (const Node& ncc : nc)
        {
          // a dropped duplicate also means the term must be rebuilt
          if (!addToChildren(ncc, children, isNonAdditive))
          {
            childChanged = true;
          }
        }
      }
      else if (!addToChildren(nc, children, isNonAdditive))
      {
        childChanged = true;
      }
    }
    Assert(!children.empty());
    // Some commutative operators have rewriters that are agnostic to order,
    // thus, we sort here.
    // NOTE(review): this sorts the whole vector, including an operator pushed
    // above for parameterized kinds; presumably no commutative kind is
    // parameterized -- confirm against TermUtil::isComm.
    if (TermUtil::isComm(k) && (d_aggr || children.size() <= 5))
    {
      // sorting always forces the term to be rebuilt
      childChanged = true;
      std::sort(children.begin(), children.end());
    }
    if (childChanged)
    {
      if (isNonAdditive && children.size() == 1)
      {
        // we may have subsumed children down to one
        ret = children[0];
      }
      else
      {
        ret = nm->mkNode(k, children);
      }
    }
  }
  ret = Rewriter::rewrite(ret);
  //--------------------end rewrite children

  // now, do extended rewrite
  Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
                               << " (from " << n << ")" << std::endl;
  Node new_ret;

  //---------------------- theory-independent post-rewriting
  if (ret.getKind() == ITE)
  {
    new_ret = extendedRewriteIte(ITE, ret);
  }
  else if (ret.getKind() == AND || ret.getKind() == OR)
  {
    new_ret = extendedRewriteAndOr(ret);
  }
  else if (ret.getKind() == EQUAL)
  {
    new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret);
    debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify");
  }
  Assert(new_ret.isNull() || new_ret != ret);
  // ITE pulling is only tried if nothing above applied and ret is not an ITE
  // (extendedRewritePullIte asserts its argument is not an ITE)
  if (new_ret.isNull() && ret.getKind() != ITE)
  {
    // simple ITE pulling
    new_ret = extendedRewritePullIte(ITE, ret);
  }
  //----------------------end theory-independent post-rewriting

  //----------------------theory-specific post-rewriting
  if (new_ret.isNull())
  {
    TheoryId tid;
    // for an ITE, dispatch on the theory of its type
    if (ret.getKind() == ITE)
    {
      tid = Theory::theoryOf(ret.getType());
    }
    else
    {
      tid = Theory::theoryOf(ret);
    }
    Trace("q-ext-rewrite-debug") << "theoryOf( " << ret << " )= " << tid
                                 << std::endl;
    if (tid == THEORY_ARITH)
    {
      new_ret = extendedRewriteArith(ret);
    }
    else if (tid == THEORY_STRINGS)
    {
      new_ret = extendedRewriteStrings(ret);
    }
  }
  //----------------------end theory-specific post-rewriting

  //----------------------aggressive rewrites
  if (new_ret.isNull() && d_aggr)
  {
    new_ret = extendedRewriteAggr(ret);
  }
  //----------------------end aggressive rewrites

  // Cache the intermediate result for n before recursing on new_ret.
  // Presumably this guards against unbounded recursion should new_ret rewrite
  // back to n; the entry is overwritten with the final result below.
  setCache(n, ret);
  if (!new_ret.isNull())
  {
    ret = extendedRewrite(new_ret);
  }
  Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret
                               << std::endl;
  if (Trace.isOn("q-ext-rewrite-nf"))
  {
    if (n == ret)
    {
      Trace("q-ext-rewrite-nf") << "ext-rew normal form : " << n << std::endl;
    }
  }
  setCache(n, ret);
  return ret;
}
243 
extendedRewriteAggr(Node n)244 Node ExtendedRewriter::extendedRewriteAggr(Node n)
245 {
246   Node new_ret;
247   Trace("q-ext-rewrite-debug2")
248       << "Do aggressive rewrites on " << n << std::endl;
249   bool polarity = n.getKind() != NOT;
250   Node ret_atom = n.getKind() == NOT ? n[0] : n;
251   if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal())
252       || ret_atom.getKind() == GEQ)
253   {
254     // ITE term removal in polynomials
255     // e.g. ite( x=0, x, y ) = x+1 ---> ( x=0 ^ y = x+1 )
256     Trace("q-ext-rewrite-debug2")
257         << "Compute monomial sum " << ret_atom << std::endl;
258     // compute monomial sum
259     std::map<Node, Node> msum;
260     if (ArithMSum::getMonomialSumLit(ret_atom, msum))
261     {
262       for (std::map<Node, Node>::iterator itm = msum.begin(); itm != msum.end();
263            ++itm)
264       {
265         Node v = itm->first;
266         Trace("q-ext-rewrite-debug2")
267             << itm->first << " * " << itm->second << std::endl;
268         if (v.getKind() == ITE)
269         {
270           Node veq;
271           int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind());
272           if (res != 0)
273           {
274             Trace("q-ext-rewrite-debug")
275                 << "  have ITE relation, solved form : " << veq << std::endl;
276             // try pulling ITE
277             new_ret = extendedRewritePullIte(ITE, veq);
278             if (!new_ret.isNull())
279             {
280               if (!polarity)
281               {
282                 new_ret = new_ret.negate();
283               }
284               break;
285             }
286           }
287           else
288           {
289             Trace("q-ext-rewrite-debug")
290                 << "  failed to isolate " << v << " in " << n << std::endl;
291           }
292         }
293       }
294     }
295     else
296     {
297       Trace("q-ext-rewrite-debug")
298           << "  failed to get monomial sum of " << n << std::endl;
299     }
300   }
301   // TODO (#1706) : conditional rewriting, condition merging
302   return new_ret;
303 }
304 
extendedRewriteIte(Kind itek,Node n,bool full)305 Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full)
306 {
307   Assert(n.getKind() == itek);
308   Assert(n[1] != n[2]);
309 
310   NodeManager* nm = NodeManager::currentNM();
311 
312   Trace("ext-rew-ite") << "Rewrite ITE : " << n << std::endl;
313 
314   Node flip_cond;
315   if (n[0].getKind() == NOT)
316   {
317     flip_cond = n[0][0];
318   }
319   else if (n[0].getKind() == OR)
320   {
321     // a | b ---> ~( ~a & ~b )
322     flip_cond = TermUtil::simpleNegate(n[0]);
323   }
324   if (!flip_cond.isNull())
325   {
326     Node new_ret = nm->mkNode(ITE, flip_cond, n[2], n[1]);
327     // only print debug trace if full=true
328     if (full)
329     {
330       debugExtendedRewrite(n, new_ret, "ITE flip");
331     }
332     return new_ret;
333   }
334   // Boolean true/false return
335   TypeNode tn = n.getType();
336   if (tn.isBoolean())
337   {
338     for (unsigned i = 1; i <= 2; i++)
339     {
340       if (n[i].isConst())
341       {
342         Node cond = i == 1 ? n[0] : n[0].negate();
343         Node other = n[i == 1 ? 2 : 1];
344         Kind retk = AND;
345         if (n[i].getConst<bool>())
346         {
347           retk = OR;
348         }
349         else
350         {
351           cond = cond.negate();
352         }
353         Node new_ret = nm->mkNode(retk, cond, other);
354         if (full)
355         {
356           // ite( A, true, B ) ---> A V B
357           // ite( A, false, B ) ---> ~A /\ B
358           // ite( A, B,  true ) ---> ~A V B
359           // ite( A, B, false ) ---> A /\ B
360           debugExtendedRewrite(n, new_ret, "ITE const return");
361         }
362         return new_ret;
363       }
364     }
365   }
366 
367   // get entailed equalities in the condition
368   std::vector<Node> eq_conds;
369   Kind ck = n[0].getKind();
370   if (ck == EQUAL)
371   {
372     eq_conds.push_back(n[0]);
373   }
374   else if (ck == AND)
375   {
376     for (const Node& cn : n[0])
377     {
378       if (cn.getKind() == EQUAL)
379       {
380         eq_conds.push_back(cn);
381       }
382     }
383   }
384 
385   Node new_ret;
386   Node b;
387   Node e;
388   Node t1 = n[1];
389   Node t2 = n[2];
390   std::stringstream ss_reason;
391 
392   for (const Node& eq : eq_conds)
393   {
394     // simple invariant ITE
395     for (unsigned i = 0; i <= 1; i++)
396     {
397       // ite( x = y ^ C, y, x ) ---> x
398       // this is subsumed by the rewrites below
399       if (t2 == eq[i] && t1 == eq[1 - i])
400       {
401         new_ret = t2;
402         ss_reason << "ITE simple rev subs";
403         break;
404       }
405     }
406     if (!new_ret.isNull())
407     {
408       break;
409     }
410   }
411   if (new_ret.isNull())
412   {
413     // merging branches
414     for (unsigned i = 1; i <= 2; i++)
415     {
416       if (n[i].getKind() == ITE)
417       {
418         Node no = n[3 - i];
419         for (unsigned j = 1; j <= 2; j++)
420         {
421           if (n[i][j] == no)
422           {
423             // e.g.
424             // ite( C1, ite( C2, t1, t2 ), t1 ) ----> ite( C1 ^ ~C2, t2, t1 )
425             Node nc1 = i == 2 ? n[0].negate() : n[0];
426             Node nc2 = j == 1 ? n[i][0].negate() : n[i][0];
427             Node new_cond = nm->mkNode(AND, nc1, nc2);
428             new_ret = nm->mkNode(ITE, new_cond, n[i][3 - j], no);
429             ss_reason << "ITE merge branch";
430             break;
431           }
432         }
433       }
434       if (!new_ret.isNull())
435       {
436         break;
437       }
438     }
439   }
440 
441   if (new_ret.isNull() && d_aggr)
442   {
443     // If x is less than t based on an ordering, then we use { x -> t } as a
444     // substitution to the children of ite( x = t ^ C, s, t ) below.
445     std::vector<Node> vars;
446     std::vector<Node> subs;
447     inferSubstitution(n[0], vars, subs, true);
448 
449     if (!vars.empty())
450     {
451       // reverse substitution to opposite child
452       // r{ x -> t } = s  implies  ite( x=t ^ C, s, r ) ---> r
453       Node nn =
454           t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
455       if (nn != t2)
456       {
457         nn = Rewriter::rewrite(nn);
458         if (nn == t1)
459         {
460           new_ret = t2;
461           ss_reason << "ITE rev subs";
462         }
463       }
464 
465       // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r )
466       nn = t1.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
467       if (nn != t1)
468       {
469         // If full=false, then we've duplicated a term u in the children of n.
470         // For example, when ITE pulling, we have n is of the form:
471         //   ite( C, f( u, t1 ), f( u, t2 ) )
472         // We must show that at least one copy of u dissappears in this case.
473         nn = Rewriter::rewrite(nn);
474         if (nn == t2)
475         {
476           new_ret = nn;
477           ss_reason << "ITE subs invariant";
478         }
479         else if (full || nn.isConst())
480         {
481           new_ret = nm->mkNode(itek, n[0], nn, t2);
482           ss_reason << "ITE subs";
483         }
484       }
485     }
486     if (new_ret.isNull())
487     {
488       // ite( C, t, s ) ----> ite( C, t, s { C -> false } )
489       TNode tv = n[0];
490       TNode ts = d_false;
491       Node nn = t2.substitute(tv, ts);
492       if (nn != t2)
493       {
494         nn = Rewriter::rewrite(nn);
495         if (nn == t1)
496         {
497           new_ret = nn;
498           ss_reason << "ITE subs invariant false";
499         }
500         else if (full || nn.isConst())
501         {
502           new_ret = nm->mkNode(itek, n[0], t1, nn);
503           ss_reason << "ITE subs false";
504         }
505       }
506     }
507   }
508 
509   // only print debug trace if full=true
510   if (!new_ret.isNull() && full)
511   {
512     debugExtendedRewrite(n, new_ret, ss_reason.str().c_str());
513   }
514 
515   return new_ret;
516 }
517 
extendedRewriteAndOr(Node n)518 Node ExtendedRewriter::extendedRewriteAndOr(Node n)
519 {
520   // all the below rewrites are aggressive
521   if (!d_aggr)
522   {
523     return Node::null();
524   }
525   Node new_ret;
526   // all kinds are legal to substitute over : hence we give the empty map
527   std::map<Kind, bool> bcp_kinds;
528   new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, n);
529   if (!new_ret.isNull())
530   {
531     debugExtendedRewrite(n, new_ret, "Bool bcp");
532     return new_ret;
533   }
534   // factoring
535   new_ret = extendedRewriteFactoring(AND, OR, NOT, n);
536   if (!new_ret.isNull())
537   {
538     debugExtendedRewrite(n, new_ret, "Bool factoring");
539     return new_ret;
540   }
541 
542   // equality resolution
543   new_ret = extendedRewriteEqRes(AND, OR, EQUAL, NOT, bcp_kinds, n, false);
544   debugExtendedRewrite(n, new_ret, "Bool eq res");
545   return new_ret;
546 }
547 
/**
 * ITE pulling: tries to eliminate or lift children of n whose kind is itek.
 *
 * For each such child (when not in aggressive mode, only if neither of its
 * branches is an ITE), the two terms obtained by replacing that child with
 * its then-branch and with its else-branch are built and rewritten. Then:
 *  - if both results are equal, that result is returned (dual invariance);
 *  - (aggressive) if n is Boolean-typed with exactly two children and the
 *    other child is a non-Boolean variable or constant, the pulled ITE is
 *    returned;
 *  - (aggressive) if one result is a constant or equals the corresponding
 *    branch, the pulled ITE is returned (as an AND/OR when that result is a
 *    Boolean constant).
 * Finally (aggressive), for each ITE child in turn, the pulled ITE is
 * rewritten. If it is still an ITE, extendedRewriteIte is tried on it with
 * full=false; otherwise the rewritten pulled term is returned directly.
 *
 * @param itek the ITE kind to pull
 * @param n a term whose kind is not ITE
 * @return the rewritten term, or the null node if no rewrite applies
 */
Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n)
{
  Assert(n.getKind() != ITE);
  NodeManager* nm = NodeManager::currentNM();
  TypeNode tn = n.getType();
  std::vector<Node> children;
  bool hasOp = (n.getMetaKind() == metakind::PARAMETERIZED);
  if (hasOp)
  {
    children.push_back(n.getOperator());
  }
  unsigned nchildren = n.getNumChildren();
  for (unsigned i = 0; i < nchildren; i++)
  {
    children.push_back(n[i]);
  }
  // ite_c[i][j] is the rewritten form of n with child i replaced by its
  // branch j+1 (j=0: then-branch, j=1: else-branch)
  std::map<unsigned, std::map<unsigned, Node> > ite_c;
  for (unsigned i = 0; i < nchildren; i++)
  {
    // only pull ITEs apart if we are aggressive
    if (n[i].getKind() == itek
        && (d_aggr || (n[i][1].getKind() != ITE && n[i][2].getKind() != ITE)))
    {
      // index of child i in children (shifted if the operator is stored)
      unsigned ii = hasOp ? i + 1 : i;
      for (unsigned j = 0; j < 2; j++)
      {
        children[ii] = n[i][j + 1];
        Node pull = nm->mkNode(n.getKind(), children);
        Node pullr = Rewriter::rewrite(pull);
        // restore the original child for the next iteration
        children[ii] = n[i];
        ite_c[i][j] = pullr;
      }
      if (ite_c[i][0] == ite_c[i][1])
      {
        // ITE dual invariance
        // f( t1..s1..tn ) ---> t  and  f( t1..s2..tn ) ---> t implies
        // f( t1..ite( A, s1, s2 )..tn ) ---> t
        debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant");
        return ite_c[i][0];
      }
      if (d_aggr)
      {
        if (nchildren == 2 && (n[1 - i].isVar() || n[1 - i].isConst())
            && !n[1 - i].getType().isBoolean() && tn.isBoolean())
        {
          // always pull variable or constant with binary (theory) predicate
          // e.g. P( x, ite( A, t1, t2 ) ) ---> ite( A, P( x, t1 ), P( x, t2 ) )
          Node new_ret = nm->mkNode(ITE, n[i][0], ite_c[i][0], ite_c[i][1]);
          debugExtendedRewrite(n, new_ret, "ITE pull var predicate");
          return new_ret;
        }
        for (unsigned j = 0; j < 2; j++)
        {
          Node pullr = ite_c[i][j];
          if (pullr.isConst() || pullr == n[i][j + 1])
          {
            // ITE single child elimination
            // f( t1..s1..tn ) ---> t  where t is a constant or s1 itself
            // implies
            // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) )
            Node new_ret;
            if (tn.isBoolean() && pullr.isConst())
            {
              // remove false/true child immediately
              // (e.g. for j=0: ite( A, true, B ) = A V B and
              //  ite( A, false, B ) = ~A ^ B)
              bool pol = pullr.getConst<bool>();
              std::vector<Node> new_children;
              new_children.push_back((j == 0) == pol ? n[i][0]
                                                     : n[i][0].negate());
              new_children.push_back(ite_c[i][1 - j]);
              new_ret = nm->mkNode(pol ? OR : AND, new_children);
              debugExtendedRewrite(n, new_ret, "ITE Bool single elim");
            }
            else
            {
              new_ret = nm->mkNode(itek, n[i][0], ite_c[i][0], ite_c[i][1]);
              debugExtendedRewrite(n, new_ret, "ITE single elim");
            }
            return new_ret;
          }
        }
      }
    }
  }
  if (d_aggr)
  {
    for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
    {
      Node nite = n[ip.first];
      Assert(nite.getKind() == itek);
      // now, simply pull the ITE and try ITE rewrites
      Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
      pull_ite = Rewriter::rewrite(pull_ite);
      if (pull_ite.getKind() == ITE)
      {
        // full=false: the pulled ITE duplicates the other children of n
        Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
        if (!new_pull_ite.isNull())
        {
          debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
          return new_pull_ite;
        }
      }
      else
      {
        // A general rewrite could eliminate the ITE by pulling.
        // An example is:
        //   ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
        //   ite( C, ~~x, ite( C, y, x ) ) --->
        //   x
        // where ~ is bitvector negation.
        debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
        return pull_ite;
      }
    }
  }

  return Node::null();
}
665 
extendedRewriteNnf(Node ret)666 Node ExtendedRewriter::extendedRewriteNnf(Node ret)
667 {
668   Assert(ret.getKind() == NOT);
669 
670   Kind nk = ret[0].getKind();
671   bool neg_ch = false;
672   bool neg_ch_1 = false;
673   if (nk == AND || nk == OR)
674   {
675     neg_ch = true;
676     nk = nk == AND ? OR : AND;
677   }
678   else if (nk == IMPLIES)
679   {
680     neg_ch = true;
681     neg_ch_1 = true;
682     nk = AND;
683   }
684   else if (nk == ITE)
685   {
686     neg_ch = true;
687     neg_ch_1 = true;
688   }
689   else if (nk == XOR)
690   {
691     nk = EQUAL;
692   }
693   else if (nk == EQUAL && ret[0][0].getType().isBoolean())
694   {
695     neg_ch_1 = true;
696   }
697   else
698   {
699     return Node::null();
700   }
701 
702   std::vector<Node> new_children;
703   for (unsigned i = 0, nchild = ret[0].getNumChildren(); i < nchild; i++)
704   {
705     Node c = ret[0][i];
706     c = (i == 0 ? neg_ch_1 : false) != neg_ch ? c.negate() : c;
707     new_children.push_back(c);
708   }
709   return NodeManager::currentNM()->mkNode(nk, new_children);
710 }
711 
/**
 * Boolean constraint propagation over a conjunction or disjunction ret.
 *
 * The (recursively flattened) literals of ret are assigned the max value
 * (TermUtil::mkTypeMaxValue) or zero value (TermUtil::mkTypeValue(tn, 0)) of
 * ret's type. When ret has kind ork, all polarities are inverted (gpol).
 * Literals whose atom has children and whose kind is allowed by bcp_kinds
 * (every kind, if bcp_kinds is empty) are also kept as clauses. The current
 * assignment is substituted into the children of each clause, and each
 * changed clause is rewritten and processed as a new literal. This repeats
 * until no clause changes.
 *
 * @param andk the conjunction kind
 * @param ork the disjunction kind
 * @param notk the negation kind
 * @param bcp_kinds kinds below which substitution is allowed. If empty, an
 *                  ordinary substitute is used; otherwise partialSubstitute.
 * @param ret a term of kind andk or ork
 * @return on a conflicting assignment, the zero value (andk) or max value
 *         (ork) of the type; if some clause changed, ret rebuilt from the
 *         untouched assigned literals; otherwise the null node
 */
Node ExtendedRewriter::extendedRewriteBcp(
    Kind andk, Kind ork, Kind notk, std::map<Kind, bool>& bcp_kinds, Node ret)
{
  Kind k = ret.getKind();
  Assert(k == andk || k == ork);
  Trace("ext-rew-bcp") << "BCP: **** INPUT: " << ret << std::endl;

  NodeManager* nm = NodeManager::currentNM();

  TypeNode tn = ret.getType();
  Node truen = TermUtil::mkTypeMaxValue(tn);
  Node falsen = TermUtil::mkTypeValue(tn, 0);

  // terms to process
  std::vector<Node> to_process;
  for (const Node& cn : ret)
  {
    to_process.push_back(cn);
  }
  // the processing terms
  std::vector<Node> clauses;
  // the terms we have propagated information to
  std::unordered_set<Node, NodeHashFunction> prop_clauses;
  // the assignment
  std::map<Node, Node> assign;
  // the assignment as parallel vectors, for Node::substitute
  std::vector<Node> avars;
  std::vector<Node> asubs;

  Kind ok = k == andk ? ork : andk;
  // global polarity : when k=ork, everything is negated
  bool gpol = k == andk;

  do
  {
    // process the current nodes
    while (!to_process.empty())
    {
      std::vector<Node> new_to_process;
      for (const Node& cn : to_process)
      {
        Trace("ext-rew-bcp-debug") << "process " << cn << std::endl;
        Kind cnk = cn.getKind();
        bool pol = cnk != notk;
        Node cln = cnk == notk ? cn[0] : cn;
        Assert(cln.getKind() != notk);
        if ((pol && cln.getKind() == k) || (!pol && cln.getKind() == ok))
        {
          // flatten
          for (const Node& ccln : cln)
          {
            Node lccln = pol ? ccln : TermUtil::mkNegate(notk, ccln);
            new_to_process.push_back(lccln);
          }
        }
        else
        {
          // add it to the assignment
          Node val = gpol == pol ? truen : falsen;
          std::map<Node, Node>::iterator it = assign.find(cln);
          Trace("ext-rew-bcp") << "BCP: assign " << cln << " -> " << val
                               << std::endl;
          if (it != assign.end())
          {
            if (val != it->second)
            {
              Trace("ext-rew-bcp") << "BCP: conflict!" << std::endl;
              // a conflicting assignment: we are done
              return gpol ? falsen : truen;
            }
          }
          else
          {
            assign[cln] = val;
            avars.push_back(cln);
            asubs.push_back(val);
          }

          // also, treat it as clause if possible
          if (cln.getNumChildren() > 0
              && (bcp_kinds.empty()
                  || bcp_kinds.find(cln.getKind()) != bcp_kinds.end()))
          {
            // NOTE(review): prop_clauses stores atoms (ca below), but cn may
            // be a negated literal here, so this lookup can miss an already
            // propagated negated clause -- confirm whether cln was intended.
            if (std::find(clauses.begin(), clauses.end(), cn) == clauses.end()
                && prop_clauses.find(cn) == prop_clauses.end())
            {
              Trace("ext-rew-bcp") << "BCP: new clause: " << cn << std::endl;
              clauses.push_back(cn);
            }
          }
        }
      }
      to_process.clear();
      to_process.insert(
          to_process.end(), new_to_process.begin(), new_to_process.end());
    }

    // apply substitution to all subterms of clauses
    std::vector<Node> new_clauses;
    for (const Node& c : clauses)
    {
      bool cpol = c.getKind() != notk;
      Node ca = c.getKind() == notk ? c[0] : c;
      bool childChanged = false;
      std::vector<Node> ccs_children;
      for (const Node& cc : ca)
      {
        Node ccs = cc;
        if (bcp_kinds.empty())
        {
          Trace("ext-rew-bcp-debug") << "...do ordinary substitute"
                                     << std::endl;
          ccs = cc.substitute(
              avars.begin(), avars.end(), asubs.begin(), asubs.end());
        }
        else
        {
          Trace("ext-rew-bcp-debug") << "...do partial substitute" << std::endl;
          // substitution is only applicable to compatible kinds
          ccs = partialSubstitute(ccs, assign, bcp_kinds);
        }
        childChanged = childChanged || ccs != cc;
        ccs_children.push_back(ccs);
      }
      if (childChanged)
      {
        if (ca.getMetaKind() == metakind::PARAMETERIZED)
        {
          ccs_children.insert(ccs_children.begin(), ca.getOperator());
        }
        Node ccs = nm->mkNode(ca.getKind(), ccs_children);
        ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs);
        Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs
                             << std::endl;
        ccs = Rewriter::rewrite(ccs);
        Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl;
        // the propagated clause is processed as a new literal in the next round
        to_process.push_back(ccs);
        // store this as a node that propagation touched. This marks c so that
        // it will not be included in the final construction.
        prop_clauses.insert(ca);
      }
      else
      {
        new_clauses.push_back(c);
      }
    }
    clauses.clear();
    clauses.insert(clauses.end(), new_clauses.begin(), new_clauses.end());
  } while (!to_process.empty());

  // remake the node
  if (!prop_clauses.empty())
  {
    std::vector<Node> children;
    for (std::pair<const Node, Node>& l : assign)
    {
      Node a = l.first;
      // if propagation did not touch a
      if (prop_clauses.find(a) == prop_clauses.end())
      {
        Assert(l.second == truen || l.second == falsen);
        // re-apply the polarity recorded in the assignment
        Node ln = (l.second == truen) == gpol ? a : TermUtil::mkNegate(notk, a);
        children.push_back(ln);
      }
    }
    Node new_ret = children.size() == 1 ? children[0] : nm->mkNode(k, children);
    Trace("ext-rew-bcp") << "BCP: **** OUTPUT: " << new_ret << std::endl;
    return new_ret;
  }

  return Node::null();
}
883 
extendedRewriteFactoring(Kind andk,Kind ork,Kind notk,Node n)884 Node ExtendedRewriter::extendedRewriteFactoring(Kind andk,
885                                                 Kind ork,
886                                                 Kind notk,
887                                                 Node n)
888 {
889   Trace("ext-rew-factoring") << "Factoring: *** INPUT: " << n << std::endl;
890   NodeManager* nm = NodeManager::currentNM();
891 
892   Kind nk = n.getKind();
893   Assert(nk == andk || nk == ork);
894   Kind onk = nk == andk ? ork : andk;
895   // count the number of times atoms occur
896   std::map<Node, std::vector<Node> > lit_to_cl;
897   std::map<Node, std::vector<Node> > cl_to_lits;
898   for (const Node& nc : n)
899   {
900     Kind nck = nc.getKind();
901     if (nck == onk)
902     {
903       for (const Node& ncl : nc)
904       {
905         if (std::find(lit_to_cl[ncl].begin(), lit_to_cl[ncl].end(), nc)
906             == lit_to_cl[ncl].end())
907         {
908           lit_to_cl[ncl].push_back(nc);
909           cl_to_lits[nc].push_back(ncl);
910         }
911       }
912     }
913     else
914     {
915       lit_to_cl[nc].push_back(nc);
916       cl_to_lits[nc].push_back(nc);
917     }
918   }
919   // get the maximum shared literal to factor
920   unsigned max_size = 0;
921   Node flit;
922   for (const std::pair<const Node, std::vector<Node> >& ltc : lit_to_cl)
923   {
924     if (ltc.second.size() > max_size)
925     {
926       max_size = ltc.second.size();
927       flit = ltc.first;
928     }
929   }
930   if (max_size > 1)
931   {
932     // do the factoring
933     std::vector<Node> children;
934     std::vector<Node> fchildren;
935     std::map<Node, std::vector<Node> >::iterator itl = lit_to_cl.find(flit);
936     std::vector<Node>& cls = itl->second;
937     for (const Node& nc : n)
938     {
939       if (std::find(cls.begin(), cls.end(), nc) == cls.end())
940       {
941         children.push_back(nc);
942       }
943       else
944       {
945         // rebuild
946         std::vector<Node>& lits = cl_to_lits[nc];
947         std::vector<Node>::iterator itlfl =
948             std::find(lits.begin(), lits.end(), flit);
949         Assert(itlfl != lits.end());
950         lits.erase(itlfl);
951         // rebuild
952         if (!lits.empty())
953         {
954           Node new_cl = lits.size() == 1 ? lits[0] : nm->mkNode(onk, lits);
955           fchildren.push_back(new_cl);
956         }
957       }
958     }
959     // rebuild the factored children
960     Assert(!fchildren.empty());
961     Node fcn = fchildren.size() == 1 ? fchildren[0] : nm->mkNode(nk, fchildren);
962     children.push_back(nm->mkNode(onk, flit, fcn));
963     Node ret = children.size() == 1 ? children[0] : nm->mkNode(nk, children);
964     Trace("ext-rew-factoring") << "Factoring: *** OUTPUT: " << ret << std::endl;
965     return ret;
966   }
967   else
968   {
969     Trace("ext-rew-factoring") << "Factoring: no change" << std::endl;
970   }
971   return Node::null();
972 }
973 
extendedRewriteEqRes(Kind andk,Kind ork,Kind eqk,Kind notk,std::map<Kind,bool> & bcp_kinds,Node n,bool isXor)974 Node ExtendedRewriter::extendedRewriteEqRes(Kind andk,
975                                             Kind ork,
976                                             Kind eqk,
977                                             Kind notk,
978                                             std::map<Kind, bool>& bcp_kinds,
979                                             Node n,
980                                             bool isXor)
981 {
982   Assert(n.getKind() == andk || n.getKind() == ork);
983   Trace("ext-rew-eqres") << "Eq res: **** INPUT: " << n << std::endl;
984 
985   NodeManager* nm = NodeManager::currentNM();
986   Kind nk = n.getKind();
987   bool gpol = (nk == andk);
988   for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++)
989   {
990     Node lit = n[i];
991     if (lit.getKind() == eqk)
992     {
993       // eq is the equality we are basing a substitution on
994       Node eq;
995       if (gpol == isXor)
996       {
997         // can only turn disequality into equality if types are the same
998         if (lit[1].getType() == lit.getType())
999         {
1000           // t != s ---> ~t = s
1001           if (lit[1].getKind() == notk && lit[0].getKind() != notk)
1002           {
1003             eq = nm->mkNode(EQUAL, lit[0], TermUtil::mkNegate(notk, lit[1]));
1004           }
1005           else
1006           {
1007             eq = nm->mkNode(EQUAL, TermUtil::mkNegate(notk, lit[0]), lit[1]);
1008           }
1009         }
1010       }
1011       else
1012       {
1013         eq = eqk == EQUAL ? lit : nm->mkNode(EQUAL, lit[0], lit[1]);
1014       }
1015       if (!eq.isNull())
1016       {
1017         // see if it corresponds to a substitution
1018         std::vector<Node> vars;
1019         std::vector<Node> subs;
1020         if (inferSubstitution(eq, vars, subs))
1021         {
1022           Assert(vars.size() == 1);
1023           std::vector<Node> children;
1024           bool childrenChanged = false;
1025           // apply to all other children
1026           for (unsigned j = 0; j < nchild; j++)
1027           {
1028             Node ccs = n[j];
1029             if (i != j)
1030             {
1031               if (bcp_kinds.empty())
1032               {
1033                 ccs = ccs.substitute(
1034                     vars.begin(), vars.end(), subs.begin(), subs.end());
1035               }
1036               else
1037               {
1038                 std::map<Node, Node> assign;
1039                 // vars.size()==subs.size()==1
1040                 assign[vars[0]] = subs[0];
1041                 // substitution is only applicable to compatible kinds
1042                 ccs = partialSubstitute(ccs, assign, bcp_kinds);
1043               }
1044               childrenChanged = childrenChanged || n[j] != ccs;
1045             }
1046             children.push_back(ccs);
1047           }
1048           if (childrenChanged)
1049           {
1050             return nm->mkNode(nk, children);
1051           }
1052         }
1053       }
1054     }
1055   }
1056 
1057   return Node::null();
1058 }
1059 
1060 /** sort pairs by their second (unsigned) argument */
sortPairSecond(const std::pair<Node,unsigned> & a,const std::pair<Node,unsigned> & b)1061 static bool sortPairSecond(const std::pair<Node, unsigned>& a,
1062                            const std::pair<Node, unsigned>& b)
1063 {
1064   return (a.second < b.second);
1065 }
1066 
1067 /** A simple subsumption trie used to compute pairwise list subsets */
1068 class SimpSubsumeTrie
1069 {
1070  public:
1071   /** the children of this node */
1072   std::map<Node, SimpSubsumeTrie> d_children;
1073   /** the term at this node */
1074   Node d_data;
1075   /** add term to the trie
1076    *
1077    * This adds term c to this trie, whose atom list is alist. This adds terms
1078    * s to subsumes such that the atom list of s is a subset of the atom list
1079    * of c. For example, say:
1080    *   c1.alist = { A }
1081    *   c2.alist = { C }
1082    *   c3.alist = { B, C }
1083    *   c4.alist = { A, B, D }
1084    *   c5.alist = { A, B, C }
1085    * If these terms are added in the order c1, c2, c3, c4, c5, then:
1086    *   addTerm c1 results in subsumes = {}
1087    *   addTerm c2 results in subsumes = {}
1088    *   addTerm c3 results in subsumes = { c2 }
1089    *   addTerm c4 results in subsumes = { c1 }
1090    *   addTerm c5 results in subsumes = { c1, c2, c3 }
1091    * Notice that the intended use case of this trie is to add term t before t'
1092    * only when size( t.alist ) <= size( t'.alist ).
1093    *
1094    * The last two arguments describe the state of the path [t0...tn] we
1095    * have followed in the trie during the recursive call.
1096    * If doAdd = true,
1097    *   then n+1 = index and alist[1]...alist[n] = t1...tn. If index=alist.size()
1098    *   we add c as the current node of this trie.
1099    * If doAdd = false,
1100    *   then t1...tn occur in alist.
1101    */
addTerm(Node c,std::vector<Node> & alist,std::vector<Node> & subsumes,unsigned index=0,bool doAdd=true)1102   void addTerm(Node c,
1103                std::vector<Node>& alist,
1104                std::vector<Node>& subsumes,
1105                unsigned index = 0,
1106                bool doAdd = true)
1107   {
1108     if (!d_data.isNull())
1109     {
1110       subsumes.push_back(d_data);
1111     }
1112     if (doAdd)
1113     {
1114       if (index == alist.size())
1115       {
1116         d_data = c;
1117         return;
1118       }
1119     }
1120     // try all children where we have this atom
1121     for (std::pair<const Node, SimpSubsumeTrie>& cp : d_children)
1122     {
1123       if (std::find(alist.begin(), alist.end(), cp.first) != alist.end())
1124       {
1125         cp.second.addTerm(c, alist, subsumes, 0, false);
1126       }
1127     }
1128     if (doAdd)
1129     {
1130       d_children[alist[index]].addTerm(c, alist, subsumes, index + 1, doAdd);
1131     }
1132   }
1133 };
1134 
extendedRewriteEqChain(Kind eqk,Kind andk,Kind ork,Kind notk,Node ret,bool isXor)1135 Node ExtendedRewriter::extendedRewriteEqChain(
1136     Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor)
1137 {
1138   Assert(ret.getKind() == eqk);
1139 
1140   NodeManager* nm = NodeManager::currentNM();
1141 
1142   TypeNode tn = ret[0].getType();
1143 
1144   // sort/cancelling for Boolean EQUAL/XOR-chains
1145   Trace("ext-rew-eqchain") << "Eq-Chain : " << ret << std::endl;
1146 
1147   // get the children on either side
1148   bool gpol = true;
1149   std::vector<Node> children;
1150   for (unsigned r = 0, size = ret.getNumChildren(); r < size; r++)
1151   {
1152     Node curr = ret[r];
1153     // assume, if necessary, right associative
1154     while (curr.getKind() == eqk && curr[0].getType() == tn)
1155     {
1156       children.push_back(curr[0]);
1157       curr = curr[1];
1158     }
1159     children.push_back(curr);
1160   }
1161 
1162   std::map<Node, bool> cstatus;
1163   // add children to status
1164   for (const Node& c : children)
1165   {
1166     Node a = c;
1167     if (a.getKind() == notk)
1168     {
1169       gpol = !gpol;
1170       a = a[0];
1171     }
1172     Trace("ext-rew-eqchain") << "...child : " << a << std::endl;
1173     std::map<Node, bool>::iterator itc = cstatus.find(a);
1174     if (itc == cstatus.end())
1175     {
1176       cstatus[a] = true;
1177     }
1178     else
1179     {
1180       // cancels
1181       cstatus.erase(a);
1182       if (isXor)
1183       {
1184         gpol = !gpol;
1185       }
1186     }
1187   }
1188   Trace("ext-rew-eqchain") << "Global polarity : " << gpol << std::endl;
1189 
1190   if (cstatus.empty())
1191   {
1192     return TermUtil::mkTypeConst(tn, gpol);
1193   }
1194 
1195   children.clear();
1196 
1197   // compute the atoms of each child
1198   Trace("ext-rew-eqchain") << "eqchain-simplify: begin\n";
1199   Trace("ext-rew-eqchain") << "  eqchain-simplify: get atoms...\n";
1200   std::map<Node, std::map<Node, bool> > atoms;
1201   std::map<Node, std::vector<Node> > alist;
1202   std::vector<std::pair<Node, unsigned> > atom_count;
1203   for (std::pair<const Node, bool>& cp : cstatus)
1204   {
1205     if (!cp.second)
1206     {
1207       // already eliminated
1208       continue;
1209     }
1210     Node c = cp.first;
1211     Kind ck = c.getKind();
1212     if (ck == andk || ck == ork)
1213     {
1214       for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++)
1215       {
1216         Node cl = c[j];
1217         bool pol = cl.getKind() != notk;
1218         Node ca = pol ? cl : cl[0];
1219         Assert(atoms[c].find(ca) == atoms[c].end());
1220         // polarity is flipped when we are AND
1221         atoms[c][ca] = (ck == andk ? !pol : pol);
1222         alist[c].push_back(ca);
1223 
1224         // if this already exists as a child of the equality chain, eliminate.
1225         // this catches cases like ( x & y ) = ( ( x & y ) | z ), where we
1226         // consider ( x & y ) a unit, whereas below it is expanded to
1227         // ~( ~x | ~y ).
1228         std::map<Node, bool>::iterator itc = cstatus.find(ca);
1229         if (itc != cstatus.end() && itc->second)
1230         {
1231           // cancel it
1232           cstatus[ca] = false;
1233           cstatus[c] = false;
1234           // make new child
1235           // x = ( y | ~x ) ---> y & x
1236           // x = ( y | x ) ---> ~y | x
1237           // x = ( y & x ) ---> y | ~x
1238           // x = ( y & ~x ) ---> ~y & ~x
1239           std::vector<Node> new_children;
1240           for (unsigned k = 0, nchild = c.getNumChildren(); k < nchild; k++)
1241           {
1242             if (j != k)
1243             {
1244               new_children.push_back(c[k]);
1245             }
1246           }
1247           Node nc[2];
1248           nc[0] = c[j];
1249           nc[1] = new_children.size() == 1 ? new_children[0]
1250                                            : nm->mkNode(ck, new_children);
1251           // negate the proper child
1252           unsigned nindex = (ck == andk) == pol ? 0 : 1;
1253           nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]);
1254           Kind nk = pol ? ork : andk;
1255           // store as new child
1256           children.push_back(nm->mkNode(nk, nc[0], nc[1]));
1257           if (isXor)
1258           {
1259             gpol = !gpol;
1260           }
1261           break;
1262         }
1263       }
1264     }
1265     else
1266     {
1267       bool pol = ck != notk;
1268       Node ca = pol ? c : c[0];
1269       atoms[c][ca] = pol;
1270       alist[c].push_back(ca);
1271     }
1272     atom_count.push_back(std::pair<Node, unsigned>(c, alist[c].size()));
1273   }
1274   // sort the atoms in each atom list
1275   for (std::map<Node, std::vector<Node> >::iterator it = alist.begin();
1276        it != alist.end();
1277        ++it)
1278   {
1279     std::sort(it->second.begin(), it->second.end());
1280   }
1281   // check subsumptions
1282   // sort by #atoms
1283   std::sort(atom_count.begin(), atom_count.end(), sortPairSecond);
1284   if (Trace.isOn("ext-rew-eqchain"))
1285   {
1286     for (const std::pair<Node, unsigned>& ac : atom_count)
1287     {
1288       Trace("ext-rew-eqchain") << "  eqchain-simplify: " << ac.first << " has "
1289                                << ac.second << " atoms." << std::endl;
1290     }
1291     Trace("ext-rew-eqchain") << "  eqchain-simplify: compute subsumptions...\n";
1292   }
1293   SimpSubsumeTrie sst;
1294   for (std::pair<const Node, bool>& cp : cstatus)
1295   {
1296     if (!cp.second)
1297     {
1298       // already eliminated
1299       continue;
1300     }
1301     Node c = cp.first;
1302     std::map<Node, std::map<Node, bool> >::iterator itc = atoms.find(c);
1303     Assert(itc != atoms.end());
1304     Trace("ext-rew-eqchain") << "  - add term " << c << " with atom list "
1305                              << alist[c] << "...\n";
1306     std::vector<Node> subsumes;
1307     sst.addTerm(c, alist[c], subsumes);
1308     for (const Node& cc : subsumes)
1309     {
1310       if (!cstatus[cc])
1311       {
1312         // subsumes a child that was already eliminated
1313         continue;
1314       }
1315       Trace("ext-rew-eqchain") << "  eqchain-simplify: " << c << " subsumes "
1316                                << cc << std::endl;
1317       // for each of the atoms in cc
1318       std::map<Node, std::map<Node, bool> >::iterator itcc = atoms.find(cc);
1319       Assert(itcc != atoms.end());
1320       std::vector<Node> common_children;
1321       std::vector<Node> diff_children;
1322       for (const std::pair<const Node, bool>& ap : itcc->second)
1323       {
1324         // compare the polarity
1325         Node a = ap.first;
1326         bool polcc = ap.second;
1327         Assert(itc->second.find(a) != itc->second.end());
1328         bool polc = itc->second[a];
1329         Trace("ext-rew-eqchain") << "    eqchain-simplify: atom " << a
1330                                  << " has polarities : " << polc << " " << polcc
1331                                  << "\n";
1332         Node lit = polc ? a : TermUtil::mkNegate(notk, a);
1333         if (polc != polcc)
1334         {
1335           diff_children.push_back(lit);
1336         }
1337         else
1338         {
1339           common_children.push_back(lit);
1340         }
1341       }
1342       std::vector<Node> rem_children;
1343       for (const std::pair<const Node, bool>& ap : itc->second)
1344       {
1345         Node a = ap.first;
1346         if (atoms[cc].find(a) == atoms[cc].end())
1347         {
1348           bool polc = ap.second;
1349           rem_children.push_back(polc ? a : TermUtil::mkNegate(notk, a));
1350         }
1351       }
1352       Trace("ext-rew-eqchain")
1353           << "    #common/diff/rem: " << common_children.size() << "/"
1354           << diff_children.size() << "/" << rem_children.size() << "\n";
1355       bool do_rewrite = false;
1356       if (common_children.empty() && itc->second.size() == itcc->second.size()
1357           && itcc->second.size() == 2)
1358       {
1359         // x | y = ~x | ~y ---> ~( x = y )
1360         do_rewrite = true;
1361         children.push_back(diff_children[0]);
1362         children.push_back(diff_children[1]);
1363         gpol = !gpol;
1364         Trace("ext-rew-eqchain") << "    apply 2-child all-diff\n";
1365       }
1366       else if (common_children.empty() && diff_children.size() == 1)
1367       {
1368         do_rewrite = true;
1369         // x = ( ~x | y ) ---> ~( ~x | ~y )
1370         Node remn = rem_children.size() == 1 ? rem_children[0]
1371                                              : nm->mkNode(ork, rem_children);
1372         remn = TermUtil::mkNegate(notk, remn);
1373         children.push_back(nm->mkNode(ork, diff_children[0], remn));
1374         if (!isXor)
1375         {
1376           gpol = !gpol;
1377         }
1378         Trace("ext-rew-eqchain") << "    apply unit resolution\n";
1379       }
1380       else if (diff_children.size() == 1
1381                && itc->second.size() == itcc->second.size())
1382       {
1383         // ( x | y | z ) = ( x | ~y | z ) ---> ( x | z )
1384         do_rewrite = true;
1385         Assert(!common_children.empty());
1386         Node comn = common_children.size() == 1
1387                         ? common_children[0]
1388                         : nm->mkNode(ork, common_children);
1389         children.push_back(comn);
1390         if (isXor)
1391         {
1392           gpol = !gpol;
1393         }
1394         Trace("ext-rew-eqchain") << "    apply resolution\n";
1395       }
1396       else if (diff_children.empty())
1397       {
1398         do_rewrite = true;
1399         if (rem_children.empty())
1400         {
1401           // x | y = x | y ---> true
1402           // this can happen if we have ( ~x & ~y ) = ( x | y )
1403           children.push_back(TermUtil::mkTypeMaxValue(tn));
1404           if (isXor)
1405           {
1406             gpol = !gpol;
1407           }
1408           Trace("ext-rew-eqchain") << "    apply cancel\n";
1409         }
1410         else
1411         {
1412           // x | y = ( x | y | z ) ---> ( x | y | ~z )
1413           Node remn = rem_children.size() == 1 ? rem_children[0]
1414                                                : nm->mkNode(ork, rem_children);
1415           remn = TermUtil::mkNegate(notk, remn);
1416           Node comn = common_children.size() == 1
1417                           ? common_children[0]
1418                           : nm->mkNode(ork, common_children);
1419           children.push_back(nm->mkNode(ork, comn, remn));
1420           if (isXor)
1421           {
1422             gpol = !gpol;
1423           }
1424           Trace("ext-rew-eqchain") << "    apply subsume\n";
1425         }
1426       }
1427       if (do_rewrite)
1428       {
1429         // eliminate the children, reverse polarity as needed
1430         for (unsigned r = 0; r < 2; r++)
1431         {
1432           Node c_rem = r == 0 ? c : cc;
1433           cstatus[c_rem] = false;
1434           if (c_rem.getKind() == andk)
1435           {
1436             gpol = !gpol;
1437           }
1438         }
1439         break;
1440       }
1441     }
1442   }
1443   Trace("ext-rew-eqchain") << "eqchain-simplify: finish" << std::endl;
1444 
1445   // sorted right associative chain
1446   bool has_nvar = false;
1447   unsigned nvar_index = 0;
1448   for (std::pair<const Node, bool>& cp : cstatus)
1449   {
1450     if (cp.second)
1451     {
1452       if (!cp.first.isVar())
1453       {
1454         has_nvar = true;
1455         nvar_index = children.size();
1456       }
1457       children.push_back(cp.first);
1458     }
1459   }
1460   std::sort(children.begin(), children.end());
1461 
1462   Node new_ret;
1463   if (!gpol)
1464   {
1465     // negate the constant child if it exists
1466     unsigned nindex = has_nvar ? nvar_index : 0;
1467     children[nindex] = TermUtil::mkNegate(notk, children[nindex]);
1468   }
1469   new_ret = children.back();
1470   unsigned index = children.size() - 1;
1471   while (index > 0)
1472   {
1473     index--;
1474     new_ret = nm->mkNode(eqk, children[index], new_ret);
1475   }
1476   new_ret = Rewriter::rewrite(new_ret);
1477   if (new_ret != ret)
1478   {
1479     return new_ret;
1480   }
1481   return Node::null();
1482 }
1483 
partialSubstitute(Node n,std::map<Node,Node> & assign,std::map<Kind,bool> & rkinds)1484 Node ExtendedRewriter::partialSubstitute(Node n,
1485                                          std::map<Node, Node>& assign,
1486                                          std::map<Kind, bool>& rkinds)
1487 {
1488   std::unordered_map<TNode, Node, TNodeHashFunction> visited;
1489   std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
1490   std::vector<TNode> visit;
1491   TNode cur;
1492   visit.push_back(n);
1493   do
1494   {
1495     cur = visit.back();
1496     visit.pop_back();
1497     it = visited.find(cur);
1498 
1499     if (it == visited.end())
1500     {
1501       std::map<Node, Node>::iterator it = assign.find(cur);
1502       if (it != assign.end())
1503       {
1504         visited[cur] = it->second;
1505       }
1506       else
1507       {
1508         // can only recurse on these kinds
1509         Kind k = cur.getKind();
1510         if (rkinds.find(k) != rkinds.end())
1511         {
1512           visited[cur] = Node::null();
1513           visit.push_back(cur);
1514           for (const Node& cn : cur)
1515           {
1516             visit.push_back(cn);
1517           }
1518         }
1519         else
1520         {
1521           visited[cur] = cur;
1522         }
1523       }
1524     }
1525     else if (it->second.isNull())
1526     {
1527       Node ret = cur;
1528       bool childChanged = false;
1529       std::vector<Node> children;
1530       if (cur.getMetaKind() == metakind::PARAMETERIZED)
1531       {
1532         children.push_back(cur.getOperator());
1533       }
1534       for (const Node& cn : cur)
1535       {
1536         it = visited.find(cn);
1537         Assert(it != visited.end());
1538         Assert(!it->second.isNull());
1539         childChanged = childChanged || cn != it->second;
1540         children.push_back(it->second);
1541       }
1542       if (childChanged)
1543       {
1544         ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
1545       }
1546       visited[cur] = ret;
1547     }
1548   } while (!visit.empty());
1549   Assert(visited.find(n) != visited.end());
1550   Assert(!visited.find(n)->second.isNull());
1551   return visited[n];
1552 }
1553 
solveEquality(Node n)1554 Node ExtendedRewriter::solveEquality(Node n)
1555 {
1556   // TODO (#1706) : implement
1557   Assert(n.getKind() == EQUAL);
1558 
1559   return Node::null();
1560 }
1561 
inferSubstitution(Node n,std::vector<Node> & vars,std::vector<Node> & subs,bool usePred)1562 bool ExtendedRewriter::inferSubstitution(Node n,
1563                                          std::vector<Node>& vars,
1564                                          std::vector<Node>& subs,
1565                                          bool usePred)
1566 {
1567   if (n.getKind() == AND)
1568   {
1569     bool ret = false;
1570     for (const Node& nc : n)
1571     {
1572       bool cret = inferSubstitution(nc, vars, subs, usePred);
1573       ret = ret || cret;
1574     }
1575     return ret;
1576   }
1577   if (n.getKind() == EQUAL)
1578   {
1579     // see if it can be put into form x = y
1580     Node slv_eq = solveEquality(n);
1581     if (!slv_eq.isNull())
1582     {
1583       n = slv_eq;
1584     }
1585     Node v[2];
1586     for (unsigned i = 0; i < 2; i++)
1587     {
1588       if (n[i].isConst())
1589       {
1590         vars.push_back(n[1 - i]);
1591         subs.push_back(n[i]);
1592         return true;
1593       }
1594       if (n[i].isVar())
1595       {
1596         v[i] = n[i];
1597       }
1598       else if (TermUtil::isNegate(n[i].getKind()) && n[i][0].isVar())
1599       {
1600         v[i] = n[i][0];
1601       }
1602     }
1603     for (unsigned i = 0; i < 2; i++)
1604     {
1605       TNode r1 = v[i];
1606       Node r2 = v[1 - i];
1607       if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
1608       {
1609         r2 = n[1 - i];
1610         if (v[i] != n[i])
1611         {
1612           Assert(TermUtil::isNegate(n[i].getKind()));
1613           r2 = TermUtil::mkNegate(n[i].getKind(), r2);
1614         }
1615         // TODO (#1706) : union find
1616         if (std::find(vars.begin(), vars.end(), r1) == vars.end())
1617         {
1618           vars.push_back(r1);
1619           subs.push_back(r2);
1620           return true;
1621         }
1622       }
1623     }
1624   }
1625   if (usePred)
1626   {
1627     bool negated = n.getKind() == NOT;
1628     vars.push_back(negated ? n[0] : n);
1629     subs.push_back(negated ? d_false : d_true);
1630     return true;
1631   }
1632   return false;
1633 }
1634 
extendedRewriteArith(Node ret)1635 Node ExtendedRewriter::extendedRewriteArith(Node ret)
1636 {
1637   Kind k = ret.getKind();
1638   NodeManager* nm = NodeManager::currentNM();
1639   Node new_ret;
1640   if (k == DIVISION || k == INTS_DIVISION || k == INTS_MODULUS)
1641   {
1642     // rewrite as though total
1643     std::vector<Node> children;
1644     bool all_const = true;
1645     for (unsigned i = 0, size = ret.getNumChildren(); i < size; i++)
1646     {
1647       if (ret[i].isConst())
1648       {
1649         children.push_back(ret[i]);
1650       }
1651       else
1652       {
1653         all_const = false;
1654         break;
1655       }
1656     }
1657     if (all_const)
1658     {
1659       Kind new_k = (ret.getKind() == DIVISION ? DIVISION_TOTAL
1660                                               : (ret.getKind() == INTS_DIVISION
1661                                                      ? INTS_DIVISION_TOTAL
1662                                                      : INTS_MODULUS_TOTAL));
1663       new_ret = nm->mkNode(new_k, children);
1664       debugExtendedRewrite(ret, new_ret, "total-interpretation");
1665     }
1666   }
1667   return new_ret;
1668 }
1669 
extendedRewriteStrings(Node ret)1670 Node ExtendedRewriter::extendedRewriteStrings(Node ret)
1671 {
1672   Node new_ret;
1673   Trace("q-ext-rewrite-debug")
1674       << "Extended rewrite strings : " << ret << std::endl;
1675 
1676   if (ret.getKind() == EQUAL)
1677   {
1678     new_ret = strings::TheoryStringsRewriter::rewriteEqualityExt(ret);
1679   }
1680 
1681   return new_ret;
1682 }
1683 
debugExtendedRewrite(Node n,Node ret,const char * c) const1684 void ExtendedRewriter::debugExtendedRewrite(Node n,
1685                                             Node ret,
1686                                             const char* c) const
1687 {
1688   if (Trace.isOn("q-ext-rewrite"))
1689   {
1690     if (!ret.isNull())
1691     {
1692       Trace("q-ext-rewrite-apply") << "sygus-extr : apply " << c << std::endl;
1693       Trace("q-ext-rewrite") << "sygus-extr : " << c << " : " << n
1694                              << " rewrites to " << ret << std::endl;
1695     }
1696   }
1697 }
1698 
1699 } /* CVC4::theory::quantifiers namespace */
1700 } /* CVC4::theory namespace */
1701 } /* CVC4 namespace */
1702