1 /********************* */
2 /*! \file extended_rewrite.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Andres Noetzli
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of extended rewriting techniques
13 **/
14
15 #include "theory/quantifiers/extended_rewrite.h"
16
17 #include "options/quantifiers_options.h"
18 #include "theory/arith/arith_msum.h"
19 #include "theory/bv/theory_bv_utils.h"
20 #include "theory/datatypes/datatypes_rewriter.h"
21 #include "theory/quantifiers/term_util.h"
22 #include "theory/rewriter.h"
23 #include "theory/strings/theory_strings_rewriter.h"
24
25 using namespace CVC4::kind;
26 using namespace std;
27
28 namespace CVC4 {
29 namespace theory {
30 namespace quantifiers {
31
32 struct ExtRewriteAttributeId
33 {
34 };
35 typedef expr::Attribute<ExtRewriteAttributeId, Node> ExtRewriteAttribute;
36
ExtendedRewriter(bool aggr)37 ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr)
38 {
39 d_true = NodeManager::currentNM()->mkConst(true);
40 d_false = NodeManager::currentNM()->mkConst(false);
41 }
42
setCache(Node n,Node ret)43 void ExtendedRewriter::setCache(Node n, Node ret)
44 {
45 ExtRewriteAttribute era;
46 n.setAttribute(era, ret);
47 }
48
addToChildren(Node nc,std::vector<Node> & children,bool dropDup)49 bool ExtendedRewriter::addToChildren(Node nc,
50 std::vector<Node>& children,
51 bool dropDup)
52 {
53 // If the operator is non-additive, do not consider duplicates
54 if (dropDup
55 && std::find(children.begin(), children.end(), nc) != children.end())
56 {
57 return false;
58 }
59 children.push_back(nc);
60 return true;
61 }
62
/**
 * Returns the extended-rewritten form of n.
 *
 * n is first rewritten with Rewriter::rewrite. When the sygusExtRew option is
 * off, that result is returned as-is. Otherwise results are memoized on the
 * rewritten n through ExtRewriteAttribute (see setCache). The computation
 * proceeds in phases:
 *   1. pre-rewrite (aggressive mode only): IMPLIES/XOR elimination and NNF
 *      for NOT. The extended rewrite of the result is cached and returned
 *      directly.
 *   2. extended-rewrite the children, flattening associative operators,
 *      dropping duplicate children of non-additive operators and sorting the
 *      children of commutative operators, then rewrite the rebuilt node;
 *   3. theory-independent post-rewrites (ITE, AND/OR, EQUAL chains, ITE
 *      pulling);
 *   4. theory-specific post-rewrites (arithmetic, strings);
 *   5. aggressive rewrites (aggressive mode only).
 * If a post-rewrite in 3-5 produced a new term, the result is the extended
 * rewrite of that term.
 */
Node ExtendedRewriter::extendedRewrite(Node n)
{
  n = Rewriter::rewrite(n);
  if (!options::sygusExtRew())
  {
    return n;
  }

  // has it already been computed?
  if (n.hasAttribute(ExtRewriteAttribute()))
  {
    return n.getAttribute(ExtRewriteAttribute());
  }

  Node ret = n;
  NodeManager* nm = NodeManager::currentNM();

  //--------------------pre-rewrite
  if (d_aggr)
  {
    Node pre_new_ret;
    if (ret.getKind() == IMPLIES)
    {
      // A => B ---> ~A V B
      pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
      debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
    }
    else if (ret.getKind() == XOR)
    {
      // A xor B ---> ~A = B
      pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
      debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
    }
    else if (ret.getKind() == NOT)
    {
      // push the negation one level inward; null if this is not possible
      pre_new_ret = extendedRewriteNnf(ret);
      debugExtendedRewrite(ret, pre_new_ret, "NNF");
    }
    if (!pre_new_ret.isNull())
    {
      ret = extendedRewrite(pre_new_ret);

      Trace("q-ext-rewrite-debug")
          << "...ext-pre-rewrite : " << n << " -> " << pre_new_ret << std::endl;
      setCache(n, ret);
      return ret;
    }
  }
  //--------------------end pre-rewrite

  //--------------------rewrite children
  if (n.getNumChildren() > 0)
  {
    std::vector<Node> children;
    if (n.getMetaKind() == metakind::PARAMETERIZED)
    {
      children.push_back(n.getOperator());
    }
    Kind k = n.getKind();
    bool childChanged = false;
    // for non-additive operators, duplicate children are dropped
    bool isNonAdditive = TermUtil::isNonAdditive(k);
    // We flatten associative operators below, which requires k to be n-ary.
    bool isAssoc = TermUtil::isAssoc(k, true);
    for (unsigned i = 0; i < n.getNumChildren(); i++)
    {
      Node nc = extendedRewrite(n[i]);
      childChanged = nc != n[i] || childChanged;
      if (isAssoc && nc.getKind() == n.getKind())
      {
        // flatten: add the children of nc in place of nc
        for (const Node& ncc : nc)
        {
          if (!addToChildren(ncc, children, isNonAdditive))
          {
            childChanged = true;
          }
        }
      }
      else if (!addToChildren(nc, children, isNonAdditive))
      {
        // nc was a duplicate and was dropped
        childChanged = true;
      }
    }
    Assert(!children.empty());
    // Some commutative operators have rewriters that are agnostic to order,
    // thus, we sort here.
    // Outside aggressive mode only nodes with at most 5 children are sorted.
    // Sorting always forces the node to be rebuilt below.
    if (TermUtil::isComm(k) && (d_aggr || children.size() <= 5))
    {
      childChanged = true;
      std::sort(children.begin(), children.end());
    }
    if (childChanged)
    {
      if (isNonAdditive && children.size() == 1)
      {
        // we may have subsumed children down to one
        ret = children[0];
      }
      else
      {
        ret = nm->mkNode(k, children);
      }
    }
  }
  ret = Rewriter::rewrite(ret);
  //--------------------end rewrite children

  // now, do extended rewrite
  Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
                               << " (from " << n << ")" << std::endl;
  // the result of one post-rewrite step, or null if none applied
  Node new_ret;

  //---------------------- theory-independent post-rewriting
  if (ret.getKind() == ITE)
  {
    new_ret = extendedRewriteIte(ITE, ret);
  }
  else if (ret.getKind() == AND || ret.getKind() == OR)
  {
    new_ret = extendedRewriteAndOr(ret);
  }
  else if (ret.getKind() == EQUAL)
  {
    new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret);
    debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify");
  }
  Assert(new_ret.isNull() || new_ret != ret);
  if (new_ret.isNull() && ret.getKind() != ITE)
  {
    // simple ITE pulling
    new_ret = extendedRewritePullIte(ITE, ret);
  }
  //----------------------end theory-independent post-rewriting

  //----------------------theory-specific post-rewriting
  if (new_ret.isNull())
  {
    TheoryId tid;
    if (ret.getKind() == ITE)
    {
      // an ITE is dispatched on the theory of its type
      tid = Theory::theoryOf(ret.getType());
    }
    else
    {
      tid = Theory::theoryOf(ret);
    }
    Trace("q-ext-rewrite-debug") << "theoryOf( " << ret << " )= " << tid
                                 << std::endl;
    if (tid == THEORY_ARITH)
    {
      new_ret = extendedRewriteArith(ret);
    }
    else if (tid == THEORY_STRINGS)
    {
      new_ret = extendedRewriteStrings(ret);
    }
  }
  //----------------------end theory-specific post-rewriting

  //----------------------aggressive rewrites
  if (new_ret.isNull() && d_aggr)
  {
    new_ret = extendedRewriteAggr(ret);
  }
  //----------------------end aggressive rewrites

  // Cache the intermediate result before recursing on new_ret; it is
  // overwritten with the final result below.
  // NOTE(review): presumably this stops the recursion on new_ret from
  // re-entering the full computation for n — confirm.
  setCache(n, ret);
  if (!new_ret.isNull())
  {
    ret = extendedRewrite(new_ret);
  }
  Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret
                               << std::endl;
  if (Trace.isOn("q-ext-rewrite-nf"))
  {
    if (n == ret)
    {
      Trace("q-ext-rewrite-nf") << "ext-rew normal form : " << n << std::endl;
    }
  }
  setCache(n, ret);
  return ret;
}
243
extendedRewriteAggr(Node n)244 Node ExtendedRewriter::extendedRewriteAggr(Node n)
245 {
246 Node new_ret;
247 Trace("q-ext-rewrite-debug2")
248 << "Do aggressive rewrites on " << n << std::endl;
249 bool polarity = n.getKind() != NOT;
250 Node ret_atom = n.getKind() == NOT ? n[0] : n;
251 if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal())
252 || ret_atom.getKind() == GEQ)
253 {
254 // ITE term removal in polynomials
255 // e.g. ite( x=0, x, y ) = x+1 ---> ( x=0 ^ y = x+1 )
256 Trace("q-ext-rewrite-debug2")
257 << "Compute monomial sum " << ret_atom << std::endl;
258 // compute monomial sum
259 std::map<Node, Node> msum;
260 if (ArithMSum::getMonomialSumLit(ret_atom, msum))
261 {
262 for (std::map<Node, Node>::iterator itm = msum.begin(); itm != msum.end();
263 ++itm)
264 {
265 Node v = itm->first;
266 Trace("q-ext-rewrite-debug2")
267 << itm->first << " * " << itm->second << std::endl;
268 if (v.getKind() == ITE)
269 {
270 Node veq;
271 int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind());
272 if (res != 0)
273 {
274 Trace("q-ext-rewrite-debug")
275 << " have ITE relation, solved form : " << veq << std::endl;
276 // try pulling ITE
277 new_ret = extendedRewritePullIte(ITE, veq);
278 if (!new_ret.isNull())
279 {
280 if (!polarity)
281 {
282 new_ret = new_ret.negate();
283 }
284 break;
285 }
286 }
287 else
288 {
289 Trace("q-ext-rewrite-debug")
290 << " failed to isolate " << v << " in " << n << std::endl;
291 }
292 }
293 }
294 }
295 else
296 {
297 Trace("q-ext-rewrite-debug")
298 << " failed to get monomial sum of " << n << std::endl;
299 }
300 }
301 // TODO (#1706) : conditional rewriting, condition merging
302 return new_ret;
303 }
304
/**
 * Extended rewrite for an ITE term n whose kind is itek.
 *
 * Returns a term equivalent to n, or null if no rewrite applies. The rewrites
 * are tried in order:
 *  - condition flipping when the condition is a NOT or an OR;
 *  - for Boolean ITEs, a constant branch turns the ITE into an AND/OR;
 *  - invariance under an equality entailed by the condition;
 *  - merging a nested ITE branch that shares a branch with n.
 * In aggressive mode, substitutions inferred from the condition are also
 * applied to the branches.
 *
 * full: when false, debug traces are suppressed. The "ITE subs" and
 * "ITE subs false" rewrites then only fire when the substituted branch
 * becomes a constant.
 */
Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full)
{
  Assert(n.getKind() == itek);
  Assert(n[1] != n[2]);

  NodeManager* nm = NodeManager::currentNM();

  Trace("ext-rew-ite") << "Rewrite ITE : " << n << std::endl;

  // a negative condition is flipped and the branches swapped
  Node flip_cond;
  if (n[0].getKind() == NOT)
  {
    flip_cond = n[0][0];
  }
  else if (n[0].getKind() == OR)
  {
    // a | b ---> ~( ~a & ~b )
    flip_cond = TermUtil::simpleNegate(n[0]);
  }
  if (!flip_cond.isNull())
  {
    Node new_ret = nm->mkNode(ITE, flip_cond, n[2], n[1]);
    // only print debug trace if full=true
    if (full)
    {
      debugExtendedRewrite(n, new_ret, "ITE flip");
    }
    return new_ret;
  }
  // Boolean true/false return
  TypeNode tn = n.getType();
  if (tn.isBoolean())
  {
    for (unsigned i = 1; i <= 2; i++)
    {
      if (n[i].isConst())
      {
        Node cond = i == 1 ? n[0] : n[0].negate();
        Node other = n[i == 1 ? 2 : 1];
        Kind retk = AND;
        if (n[i].getConst<bool>())
        {
          retk = OR;
        }
        else
        {
          cond = cond.negate();
        }
        Node new_ret = nm->mkNode(retk, cond, other);
        if (full)
        {
          // ite( A, true, B ) ---> A V B
          // ite( A, false, B ) ---> ~A /\ B
          // ite( A, B,  true ) ---> ~A V B
          // ite( A, B, false ) ---> A /\ B
          debugExtendedRewrite(n, new_ret, "ITE const return");
        }
        return new_ret;
      }
    }
  }

  // get entailed equalities in the condition
  std::vector<Node> eq_conds;
  Kind ck = n[0].getKind();
  if (ck == EQUAL)
  {
    eq_conds.push_back(n[0]);
  }
  else if (ck == AND)
  {
    for (const Node& cn : n[0])
    {
      if (cn.getKind() == EQUAL)
      {
        eq_conds.push_back(cn);
      }
    }
  }

  Node new_ret;
  // NOTE(review): b and e are not used in this function.
  Node b;
  Node e;
  Node t1 = n[1];
  Node t2 = n[2];
  // name of the rewrite that fired, reported in the debug trace
  std::stringstream ss_reason;

  for (const Node& eq : eq_conds)
  {
    // simple invariant ITE
    for (unsigned i = 0; i <= 1; i++)
    {
      // ite( x = y ^ C, y, x ) ---> x
      // this is subsumed by the rewrites below
      if (t2 == eq[i] && t1 == eq[1 - i])
      {
        new_ret = t2;
        ss_reason << "ITE simple rev subs";
        break;
      }
    }
    if (!new_ret.isNull())
    {
      break;
    }
  }
  if (new_ret.isNull())
  {
    // merging branches
    for (unsigned i = 1; i <= 2; i++)
    {
      if (n[i].getKind() == ITE)
      {
        // no is the branch of n opposite the nested ITE n[i]
        Node no = n[3 - i];
        for (unsigned j = 1; j <= 2; j++)
        {
          if (n[i][j] == no)
          {
            // e.g.
            // ite( C1, ite( C2, t1, t2 ), t1 ) ----> ite( C1 ^ ~C2, t2, t1 )
            Node nc1 = i == 2 ? n[0].negate() : n[0];
            Node nc2 = j == 1 ? n[i][0].negate() : n[i][0];
            Node new_cond = nm->mkNode(AND, nc1, nc2);
            new_ret = nm->mkNode(ITE, new_cond, n[i][3 - j], no);
            ss_reason << "ITE merge branch";
            break;
          }
        }
      }
      if (!new_ret.isNull())
      {
        break;
      }
    }
  }

  if (new_ret.isNull() && d_aggr)
  {
    // If x is less than t based on an ordering, then we use { x -> t } as a
    // substitution to the children of ite( x = t ^ C, s, t ) below.
    std::vector<Node> vars;
    std::vector<Node> subs;
    inferSubstitution(n[0], vars, subs, true);

    if (!vars.empty())
    {
      // reverse substitution to opposite child
      // r{ x -> t } = s  implies  ite( x=t ^ C, s, r ) ---> r
      Node nn =
          t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
      if (nn != t2)
      {
        nn = Rewriter::rewrite(nn);
        if (nn == t1)
        {
          new_ret = t2;
          ss_reason << "ITE rev subs";
        }
      }

      // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r )
      // NOTE(review): this step is not guarded by new_ret.isNull(). If
      // "ITE rev subs" fired above, its result may be replaced here, and
      // both reasons are then concatenated in ss_reason.
      nn = t1.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
      if (nn != t1)
      {
        // If full=false, then we've duplicated a term u in the children of n.
        // For example, when ITE pulling, we have n is of the form:
        //   ite( C, f( u, t1 ), f( u, t2 ) )
        // We must show that at least one copy of u dissappears in this case.
        nn = Rewriter::rewrite(nn);
        if (nn == t2)
        {
          new_ret = nn;
          ss_reason << "ITE subs invariant";
        }
        else if (full || nn.isConst())
        {
          new_ret = nm->mkNode(itek, n[0], nn, t2);
          ss_reason << "ITE subs";
        }
      }
    }
    if (new_ret.isNull())
    {
      // ite( C, t, s ) ----> ite( C, t, s { C -> false } )
      TNode tv = n[0];
      TNode ts = d_false;
      Node nn = t2.substitute(tv, ts);
      if (nn != t2)
      {
        nn = Rewriter::rewrite(nn);
        if (nn == t1)
        {
          new_ret = nn;
          ss_reason << "ITE subs invariant false";
        }
        else if (full || nn.isConst())
        {
          new_ret = nm->mkNode(itek, n[0], t1, nn);
          ss_reason << "ITE subs false";
        }
      }
    }
  }

  // only print debug trace if full=true
  if (!new_ret.isNull() && full)
  {
    debugExtendedRewrite(n, new_ret, ss_reason.str().c_str());
  }

  return new_ret;
}
517
extendedRewriteAndOr(Node n)518 Node ExtendedRewriter::extendedRewriteAndOr(Node n)
519 {
520 // all the below rewrites are aggressive
521 if (!d_aggr)
522 {
523 return Node::null();
524 }
525 Node new_ret;
526 // all kinds are legal to substitute over : hence we give the empty map
527 std::map<Kind, bool> bcp_kinds;
528 new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, n);
529 if (!new_ret.isNull())
530 {
531 debugExtendedRewrite(n, new_ret, "Bool bcp");
532 return new_ret;
533 }
534 // factoring
535 new_ret = extendedRewriteFactoring(AND, OR, NOT, n);
536 if (!new_ret.isNull())
537 {
538 debugExtendedRewrite(n, new_ret, "Bool factoring");
539 return new_ret;
540 }
541
542 // equality resolution
543 new_ret = extendedRewriteEqRes(AND, OR, EQUAL, NOT, bcp_kinds, n, false);
544 debugExtendedRewrite(n, new_ret, "Bool eq res");
545 return new_ret;
546 }
547
/**
 * ITE pulling for a non-ITE term n.
 *
 * For each child of n of kind itek, say ite( A, s1, s2 ), this rewrites
 * f( ..s1.. ) and f( ..s2.. ) and tries to derive a simpler term equivalent
 * to n from them. Returns null if no rewrite applies. Outside aggressive mode
 * only "ITE dual invariant" can fire; every other rewrite requires d_aggr.
 */
Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n)
{
  Assert(n.getKind() != ITE);
  NodeManager* nm = NodeManager::currentNM();
  TypeNode tn = n.getType();
  // the children of n, preceded by its operator if n is parameterized; an
  // entry is temporarily replaced by an ITE branch to build pulled terms
  std::vector<Node> children;
  bool hasOp = (n.getMetaKind() == metakind::PARAMETERIZED);
  if (hasOp)
  {
    children.push_back(n.getOperator());
  }
  unsigned nchildren = n.getNumChildren();
  for (unsigned i = 0; i < nchildren; i++)
  {
    children.push_back(n[i]);
  }
  // ite_c[i][j] is the rewritten form of n with its i^th child (an ITE)
  // replaced by that ITE's (j+1)^th child
  std::map<unsigned, std::map<unsigned, Node> > ite_c;
  for (unsigned i = 0; i < nchildren; i++)
  {
    // In aggressive mode every ITE child is considered. Otherwise only ITE
    // children whose branches are not themselves ITEs are considered.
    if (n[i].getKind() == itek
        && (d_aggr || (n[i][1].getKind() != ITE && n[i][2].getKind() != ITE)))
    {
      // position of n[i] within children
      unsigned ii = hasOp ? i + 1 : i;
      for (unsigned j = 0; j < 2; j++)
      {
        children[ii] = n[i][j + 1];
        Node pull = nm->mkNode(n.getKind(), children);
        Node pullr = Rewriter::rewrite(pull);
        children[ii] = n[i];
        ite_c[i][j] = pullr;
      }
      if (ite_c[i][0] == ite_c[i][1])
      {
        // ITE dual invariance
        // f( t1..s1..tn ) ---> t  and  f( t1..s2..tn ) ---> t implies
        // f( t1..ite( A, s1, s2 )..tn ) ---> t
        debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant");
        return ite_c[i][0];
      }
      if (d_aggr)
      {
        if (nchildren == 2 && (n[1 - i].isVar() || n[1 - i].isConst())
            && !n[1 - i].getType().isBoolean() && tn.isBoolean())
        {
          // always pull variable or constant with binary (theory) predicate
          // e.g. P( x, ite( A, t1, t2 ) ) ---> ite( A, P( x, t1 ), P( x, t2 ) )
          Node new_ret = nm->mkNode(ITE, n[i][0], ite_c[i][0], ite_c[i][1]);
          debugExtendedRewrite(n, new_ret, "ITE pull var predicate");
          return new_ret;
        }
        for (unsigned j = 0; j < 2; j++)
        {
          Node pullr = ite_c[i][j];
          if (pullr.isConst() || pullr == n[i][j + 1])
          {
            // ITE single child elimination
            // f( t1..s1..tn ) ---> t  where t is a constant or s1 itself
            // implies
            // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) )
            Node new_ret;
            if (tn.isBoolean() && pullr.isConst())
            {
              // remove false/true child immediately
              // (a true branch gives an OR, a false branch gives an AND)
              bool pol = pullr.getConst<bool>();
              std::vector<Node> new_children;
              new_children.push_back((j == 0) == pol ? n[i][0]
                                                     : n[i][0].negate());
              new_children.push_back(ite_c[i][1 - j]);
              new_ret = nm->mkNode(pol ? OR : AND, new_children);
              debugExtendedRewrite(n, new_ret, "ITE Bool single elim");
            }
            else
            {
              new_ret = nm->mkNode(itek, n[i][0], ite_c[i][0], ite_c[i][1]);
              debugExtendedRewrite(n, new_ret, "ITE single elim");
            }
            return new_ret;
          }
        }
      }
    }
  }
  if (d_aggr)
  {
    for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
    {
      Node nite = n[ip.first];
      Assert(nite.getKind() == itek);
      // now, simply pull the ITE and try ITE rewrites
      Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
      pull_ite = Rewriter::rewrite(pull_ite);
      if (pull_ite.getKind() == ITE)
      {
        // full=false: restricted ITE rewrites, no debug trace
        Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
        if (!new_pull_ite.isNull())
        {
          debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
          return new_pull_ite;
        }
      }
      else
      {
        // A general rewrite could eliminate the ITE by pulling.
        // An example is:
        //   ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
        //   ite( C, ~~x, ite( C, y, x ) ) --->
        //   x
        // where ~ is bitvector negation.
        debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
        return pull_ite;
      }
    }
  }

  return Node::null();
}
665
extendedRewriteNnf(Node ret)666 Node ExtendedRewriter::extendedRewriteNnf(Node ret)
667 {
668 Assert(ret.getKind() == NOT);
669
670 Kind nk = ret[0].getKind();
671 bool neg_ch = false;
672 bool neg_ch_1 = false;
673 if (nk == AND || nk == OR)
674 {
675 neg_ch = true;
676 nk = nk == AND ? OR : AND;
677 }
678 else if (nk == IMPLIES)
679 {
680 neg_ch = true;
681 neg_ch_1 = true;
682 nk = AND;
683 }
684 else if (nk == ITE)
685 {
686 neg_ch = true;
687 neg_ch_1 = true;
688 }
689 else if (nk == XOR)
690 {
691 nk = EQUAL;
692 }
693 else if (nk == EQUAL && ret[0][0].getType().isBoolean())
694 {
695 neg_ch_1 = true;
696 }
697 else
698 {
699 return Node::null();
700 }
701
702 std::vector<Node> new_children;
703 for (unsigned i = 0, nchild = ret[0].getNumChildren(); i < nchild; i++)
704 {
705 Node c = ret[0][i];
706 c = (i == 0 ? neg_ch_1 : false) != neg_ch ? c.negate() : c;
707 new_children.push_back(c);
708 }
709 return NodeManager::currentNM()->mkNode(nk, new_children);
710 }
711
/**
 * Boolean constraint propagation over ret, a term of kind andk or ork. The
 * conjunction, disjunction and negation kinds are parameters (andk, ork,
 * notk) rather than fixed to AND/OR/NOT.
 *
 * Each flattened literal of ret is assigned a value of ret's type: truen
 * (TermUtil::mkTypeMaxValue) if the literal's polarity agrees with the global
 * polarity (true for andk), otherwise falsen (TermUtil::mkTypeValue(tn, 0)).
 * Assigned literals that have children ("clauses") then have the assignment
 * substituted into their children. Each changed clause is rewritten and
 * processed again until no clause changes.
 *
 * bcp_kinds: if empty, every kind may form a clause and ordinary substitution
 * is used. Otherwise only these kinds form clauses, and partialSubstitute
 * restricts the substitution to compatible kinds.
 *
 * Returns gpol ? falsen : truen on a conflicting assignment. Returns a
 * rebuilt term if some clause was propagated, and null otherwise.
 */
Node ExtendedRewriter::extendedRewriteBcp(
    Kind andk, Kind ork, Kind notk, std::map<Kind, bool>& bcp_kinds, Node ret)
{
  Kind k = ret.getKind();
  Assert(k == andk || k == ork);
  Trace("ext-rew-bcp") << "BCP: **** INPUT: " << ret << std::endl;

  NodeManager* nm = NodeManager::currentNM();

  TypeNode tn = ret.getType();
  Node truen = TermUtil::mkTypeMaxValue(tn);
  Node falsen = TermUtil::mkTypeValue(tn, 0);

  // terms to process
  std::vector<Node> to_process;
  for (const Node& cn : ret)
  {
    to_process.push_back(cn);
  }
  // the processing terms
  std::vector<Node> clauses;
  // the terms we have propagated information to
  std::unordered_set<Node, NodeHashFunction> prop_clauses;
  // the assignment
  std::map<Node, Node> assign;
  std::vector<Node> avars;
  std::vector<Node> asubs;

  Kind ok = k == andk ? ork : andk;
  // global polarity : when k=ork, everything is negated
  bool gpol = k == andk;

  do
  {
    // process the current nodes
    while (!to_process.empty())
    {
      std::vector<Node> new_to_process;
      for (const Node& cn : to_process)
      {
        Trace("ext-rew-bcp-debug") << "process " << cn << std::endl;
        Kind cnk = cn.getKind();
        bool pol = cnk != notk;
        Node cln = cnk == notk ? cn[0] : cn;
        Assert(cln.getKind() != notk);
        if ((pol && cln.getKind() == k) || (!pol && cln.getKind() == ok))
        {
          // flatten
          for (const Node& ccln : cln)
          {
            Node lccln = pol ? ccln : TermUtil::mkNegate(notk, ccln);
            new_to_process.push_back(lccln);
          }
        }
        else
        {
          // add it to the assignment
          Node val = gpol == pol ? truen : falsen;
          std::map<Node, Node>::iterator it = assign.find(cln);
          Trace("ext-rew-bcp") << "BCP: assign " << cln << " -> " << val
                               << std::endl;
          if (it != assign.end())
          {
            if (val != it->second)
            {
              Trace("ext-rew-bcp") << "BCP: conflict!" << std::endl;
              // a conflicting assignment: we are done
              return gpol ? falsen : truen;
            }
          }
          else
          {
            assign[cln] = val;
            avars.push_back(cln);
            asubs.push_back(val);
          }

          // also, treat it as clause if possible
          if (cln.getNumChildren() > 0
              && (bcp_kinds.empty()
                  || bcp_kinds.find(cln.getKind()) != bcp_kinds.end()))
          {
            // NOTE(review): prop_clauses stores atoms (see the insert of ca
            // below), but here it is queried with cn, which may be a
            // negation, so a negated clause that was already propagated is
            // not recognized by this check.
            if (std::find(clauses.begin(), clauses.end(), cn) == clauses.end()
                && prop_clauses.find(cn) == prop_clauses.end())
            {
              Trace("ext-rew-bcp") << "BCP: new clause: " << cn << std::endl;
              clauses.push_back(cn);
            }
          }
        }
      }
      to_process.clear();
      to_process.insert(
          to_process.end(), new_to_process.begin(), new_to_process.end());
    }

    // apply substitution to all subterms of clauses
    std::vector<Node> new_clauses;
    for (const Node& c : clauses)
    {
      bool cpol = c.getKind() != notk;
      Node ca = c.getKind() == notk ? c[0] : c;
      bool childChanged = false;
      std::vector<Node> ccs_children;
      for (const Node& cc : ca)
      {
        Node ccs = cc;
        if (bcp_kinds.empty())
        {
          Trace("ext-rew-bcp-debug") << "...do ordinary substitute"
                                     << std::endl;
          ccs = cc.substitute(
              avars.begin(), avars.end(), asubs.begin(), asubs.end());
        }
        else
        {
          Trace("ext-rew-bcp-debug") << "...do partial substitute" << std::endl;
          // substitution is only applicable to compatible kinds
          ccs = partialSubstitute(ccs, assign, bcp_kinds);
        }
        childChanged = childChanged || ccs != cc;
        ccs_children.push_back(ccs);
      }
      if (childChanged)
      {
        if (ca.getMetaKind() == metakind::PARAMETERIZED)
        {
          ccs_children.insert(ccs_children.begin(), ca.getOperator());
        }
        Node ccs = nm->mkNode(ca.getKind(), ccs_children);
        ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs);
        Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs
                             << std::endl;
        ccs = Rewriter::rewrite(ccs);
        Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl;
        to_process.push_back(ccs);
        // store this as a node that propagation touched. This marks c so that
        // it will not be included in the final construction.
        prop_clauses.insert(ca);
      }
      else
      {
        new_clauses.push_back(c);
      }
    }
    clauses.clear();
    clauses.insert(clauses.end(), new_clauses.begin(), new_clauses.end());
  } while (!to_process.empty());

  // remake the node
  if (!prop_clauses.empty())
  {
    std::vector<Node> children;
    for (std::pair<const Node, Node>& l : assign)
    {
      Node a = l.first;
      // if propagation did not touch a
      if (prop_clauses.find(a) == prop_clauses.end())
      {
        Assert(l.second == truen || l.second == falsen);
        // re-apply the polarity the literal had when it was assigned
        Node ln = (l.second == truen) == gpol ? a : TermUtil::mkNegate(notk, a);
        children.push_back(ln);
      }
    }
    Node new_ret = children.size() == 1 ? children[0] : nm->mkNode(k, children);
    Trace("ext-rew-bcp") << "BCP: **** OUTPUT: " << new_ret << std::endl;
    return new_ret;
  }

  return Node::null();
}
883
extendedRewriteFactoring(Kind andk,Kind ork,Kind notk,Node n)884 Node ExtendedRewriter::extendedRewriteFactoring(Kind andk,
885 Kind ork,
886 Kind notk,
887 Node n)
888 {
889 Trace("ext-rew-factoring") << "Factoring: *** INPUT: " << n << std::endl;
890 NodeManager* nm = NodeManager::currentNM();
891
892 Kind nk = n.getKind();
893 Assert(nk == andk || nk == ork);
894 Kind onk = nk == andk ? ork : andk;
895 // count the number of times atoms occur
896 std::map<Node, std::vector<Node> > lit_to_cl;
897 std::map<Node, std::vector<Node> > cl_to_lits;
898 for (const Node& nc : n)
899 {
900 Kind nck = nc.getKind();
901 if (nck == onk)
902 {
903 for (const Node& ncl : nc)
904 {
905 if (std::find(lit_to_cl[ncl].begin(), lit_to_cl[ncl].end(), nc)
906 == lit_to_cl[ncl].end())
907 {
908 lit_to_cl[ncl].push_back(nc);
909 cl_to_lits[nc].push_back(ncl);
910 }
911 }
912 }
913 else
914 {
915 lit_to_cl[nc].push_back(nc);
916 cl_to_lits[nc].push_back(nc);
917 }
918 }
919 // get the maximum shared literal to factor
920 unsigned max_size = 0;
921 Node flit;
922 for (const std::pair<const Node, std::vector<Node> >& ltc : lit_to_cl)
923 {
924 if (ltc.second.size() > max_size)
925 {
926 max_size = ltc.second.size();
927 flit = ltc.first;
928 }
929 }
930 if (max_size > 1)
931 {
932 // do the factoring
933 std::vector<Node> children;
934 std::vector<Node> fchildren;
935 std::map<Node, std::vector<Node> >::iterator itl = lit_to_cl.find(flit);
936 std::vector<Node>& cls = itl->second;
937 for (const Node& nc : n)
938 {
939 if (std::find(cls.begin(), cls.end(), nc) == cls.end())
940 {
941 children.push_back(nc);
942 }
943 else
944 {
945 // rebuild
946 std::vector<Node>& lits = cl_to_lits[nc];
947 std::vector<Node>::iterator itlfl =
948 std::find(lits.begin(), lits.end(), flit);
949 Assert(itlfl != lits.end());
950 lits.erase(itlfl);
951 // rebuild
952 if (!lits.empty())
953 {
954 Node new_cl = lits.size() == 1 ? lits[0] : nm->mkNode(onk, lits);
955 fchildren.push_back(new_cl);
956 }
957 }
958 }
959 // rebuild the factored children
960 Assert(!fchildren.empty());
961 Node fcn = fchildren.size() == 1 ? fchildren[0] : nm->mkNode(nk, fchildren);
962 children.push_back(nm->mkNode(onk, flit, fcn));
963 Node ret = children.size() == 1 ? children[0] : nm->mkNode(nk, children);
964 Trace("ext-rew-factoring") << "Factoring: *** OUTPUT: " << ret << std::endl;
965 return ret;
966 }
967 else
968 {
969 Trace("ext-rew-factoring") << "Factoring: no change" << std::endl;
970 }
971 return Node::null();
972 }
973
extendedRewriteEqRes(Kind andk,Kind ork,Kind eqk,Kind notk,std::map<Kind,bool> & bcp_kinds,Node n,bool isXor)974 Node ExtendedRewriter::extendedRewriteEqRes(Kind andk,
975 Kind ork,
976 Kind eqk,
977 Kind notk,
978 std::map<Kind, bool>& bcp_kinds,
979 Node n,
980 bool isXor)
981 {
982 Assert(n.getKind() == andk || n.getKind() == ork);
983 Trace("ext-rew-eqres") << "Eq res: **** INPUT: " << n << std::endl;
984
985 NodeManager* nm = NodeManager::currentNM();
986 Kind nk = n.getKind();
987 bool gpol = (nk == andk);
988 for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++)
989 {
990 Node lit = n[i];
991 if (lit.getKind() == eqk)
992 {
993 // eq is the equality we are basing a substitution on
994 Node eq;
995 if (gpol == isXor)
996 {
997 // can only turn disequality into equality if types are the same
998 if (lit[1].getType() == lit.getType())
999 {
1000 // t != s ---> ~t = s
1001 if (lit[1].getKind() == notk && lit[0].getKind() != notk)
1002 {
1003 eq = nm->mkNode(EQUAL, lit[0], TermUtil::mkNegate(notk, lit[1]));
1004 }
1005 else
1006 {
1007 eq = nm->mkNode(EQUAL, TermUtil::mkNegate(notk, lit[0]), lit[1]);
1008 }
1009 }
1010 }
1011 else
1012 {
1013 eq = eqk == EQUAL ? lit : nm->mkNode(EQUAL, lit[0], lit[1]);
1014 }
1015 if (!eq.isNull())
1016 {
1017 // see if it corresponds to a substitution
1018 std::vector<Node> vars;
1019 std::vector<Node> subs;
1020 if (inferSubstitution(eq, vars, subs))
1021 {
1022 Assert(vars.size() == 1);
1023 std::vector<Node> children;
1024 bool childrenChanged = false;
1025 // apply to all other children
1026 for (unsigned j = 0; j < nchild; j++)
1027 {
1028 Node ccs = n[j];
1029 if (i != j)
1030 {
1031 if (bcp_kinds.empty())
1032 {
1033 ccs = ccs.substitute(
1034 vars.begin(), vars.end(), subs.begin(), subs.end());
1035 }
1036 else
1037 {
1038 std::map<Node, Node> assign;
1039 // vars.size()==subs.size()==1
1040 assign[vars[0]] = subs[0];
1041 // substitution is only applicable to compatible kinds
1042 ccs = partialSubstitute(ccs, assign, bcp_kinds);
1043 }
1044 childrenChanged = childrenChanged || n[j] != ccs;
1045 }
1046 children.push_back(ccs);
1047 }
1048 if (childrenChanged)
1049 {
1050 return nm->mkNode(nk, children);
1051 }
1052 }
1053 }
1054 }
1055 }
1056
1057 return Node::null();
1058 }
1059
1060 /** sort pairs by their second (unsigned) argument */
sortPairSecond(const std::pair<Node,unsigned> & a,const std::pair<Node,unsigned> & b)1061 static bool sortPairSecond(const std::pair<Node, unsigned>& a,
1062 const std::pair<Node, unsigned>& b)
1063 {
1064 return (a.second < b.second);
1065 }
1066
1067 /** A simple subsumption trie used to compute pairwise list subsets */
1068 class SimpSubsumeTrie
1069 {
1070 public:
1071 /** the children of this node */
1072 std::map<Node, SimpSubsumeTrie> d_children;
1073 /** the term at this node */
1074 Node d_data;
1075 /** add term to the trie
1076 *
1077 * This adds term c to this trie, whose atom list is alist. This adds terms
1078 * s to subsumes such that the atom list of s is a subset of the atom list
1079 * of c. For example, say:
1080 * c1.alist = { A }
1081 * c2.alist = { C }
1082 * c3.alist = { B, C }
1083 * c4.alist = { A, B, D }
1084 * c5.alist = { A, B, C }
1085 * If these terms are added in the order c1, c2, c3, c4, c5, then:
1086 * addTerm c1 results in subsumes = {}
1087 * addTerm c2 results in subsumes = {}
1088 * addTerm c3 results in subsumes = { c2 }
1089 * addTerm c4 results in subsumes = { c1 }
1090 * addTerm c5 results in subsumes = { c1, c2, c3 }
1091 * Notice that the intended use case of this trie is to add term t before t'
1092 * only when size( t.alist ) <= size( t'.alist ).
1093 *
1094 * The last two arguments describe the state of the path [t0...tn] we
1095 * have followed in the trie during the recursive call.
1096 * If doAdd = true,
1097 * then n+1 = index and alist[1]...alist[n] = t1...tn. If index=alist.size()
1098 * we add c as the current node of this trie.
1099 * If doAdd = false,
1100 * then t1...tn occur in alist.
1101 */
addTerm(Node c,std::vector<Node> & alist,std::vector<Node> & subsumes,unsigned index=0,bool doAdd=true)1102 void addTerm(Node c,
1103 std::vector<Node>& alist,
1104 std::vector<Node>& subsumes,
1105 unsigned index = 0,
1106 bool doAdd = true)
1107 {
1108 if (!d_data.isNull())
1109 {
1110 subsumes.push_back(d_data);
1111 }
1112 if (doAdd)
1113 {
1114 if (index == alist.size())
1115 {
1116 d_data = c;
1117 return;
1118 }
1119 }
1120 // try all children where we have this atom
1121 for (std::pair<const Node, SimpSubsumeTrie>& cp : d_children)
1122 {
1123 if (std::find(alist.begin(), alist.end(), cp.first) != alist.end())
1124 {
1125 cp.second.addTerm(c, alist, subsumes, 0, false);
1126 }
1127 }
1128 if (doAdd)
1129 {
1130 d_children[alist[index]].addTerm(c, alist, subsumes, index + 1, doAdd);
1131 }
1132 }
1133 };
1134
extendedRewriteEqChain(Kind eqk,Kind andk,Kind ork,Kind notk,Node ret,bool isXor)1135 Node ExtendedRewriter::extendedRewriteEqChain(
1136 Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor)
1137 {
1138 Assert(ret.getKind() == eqk);
1139
1140 NodeManager* nm = NodeManager::currentNM();
1141
1142 TypeNode tn = ret[0].getType();
1143
1144 // sort/cancelling for Boolean EQUAL/XOR-chains
1145 Trace("ext-rew-eqchain") << "Eq-Chain : " << ret << std::endl;
1146
1147 // get the children on either side
1148 bool gpol = true;
1149 std::vector<Node> children;
1150 for (unsigned r = 0, size = ret.getNumChildren(); r < size; r++)
1151 {
1152 Node curr = ret[r];
1153 // assume, if necessary, right associative
1154 while (curr.getKind() == eqk && curr[0].getType() == tn)
1155 {
1156 children.push_back(curr[0]);
1157 curr = curr[1];
1158 }
1159 children.push_back(curr);
1160 }
1161
1162 std::map<Node, bool> cstatus;
1163 // add children to status
1164 for (const Node& c : children)
1165 {
1166 Node a = c;
1167 if (a.getKind() == notk)
1168 {
1169 gpol = !gpol;
1170 a = a[0];
1171 }
1172 Trace("ext-rew-eqchain") << "...child : " << a << std::endl;
1173 std::map<Node, bool>::iterator itc = cstatus.find(a);
1174 if (itc == cstatus.end())
1175 {
1176 cstatus[a] = true;
1177 }
1178 else
1179 {
1180 // cancels
1181 cstatus.erase(a);
1182 if (isXor)
1183 {
1184 gpol = !gpol;
1185 }
1186 }
1187 }
1188 Trace("ext-rew-eqchain") << "Global polarity : " << gpol << std::endl;
1189
1190 if (cstatus.empty())
1191 {
1192 return TermUtil::mkTypeConst(tn, gpol);
1193 }
1194
1195 children.clear();
1196
1197 // compute the atoms of each child
1198 Trace("ext-rew-eqchain") << "eqchain-simplify: begin\n";
1199 Trace("ext-rew-eqchain") << " eqchain-simplify: get atoms...\n";
1200 std::map<Node, std::map<Node, bool> > atoms;
1201 std::map<Node, std::vector<Node> > alist;
1202 std::vector<std::pair<Node, unsigned> > atom_count;
1203 for (std::pair<const Node, bool>& cp : cstatus)
1204 {
1205 if (!cp.second)
1206 {
1207 // already eliminated
1208 continue;
1209 }
1210 Node c = cp.first;
1211 Kind ck = c.getKind();
1212 if (ck == andk || ck == ork)
1213 {
1214 for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++)
1215 {
1216 Node cl = c[j];
1217 bool pol = cl.getKind() != notk;
1218 Node ca = pol ? cl : cl[0];
1219 Assert(atoms[c].find(ca) == atoms[c].end());
1220 // polarity is flipped when we are AND
1221 atoms[c][ca] = (ck == andk ? !pol : pol);
1222 alist[c].push_back(ca);
1223
1224 // if this already exists as a child of the equality chain, eliminate.
1225 // this catches cases like ( x & y ) = ( ( x & y ) | z ), where we
1226 // consider ( x & y ) a unit, whereas below it is expanded to
1227 // ~( ~x | ~y ).
1228 std::map<Node, bool>::iterator itc = cstatus.find(ca);
1229 if (itc != cstatus.end() && itc->second)
1230 {
1231 // cancel it
1232 cstatus[ca] = false;
1233 cstatus[c] = false;
1234 // make new child
1235 // x = ( y | ~x ) ---> y & x
1236 // x = ( y | x ) ---> ~y | x
1237 // x = ( y & x ) ---> y | ~x
1238 // x = ( y & ~x ) ---> ~y & ~x
1239 std::vector<Node> new_children;
1240 for (unsigned k = 0, nchild = c.getNumChildren(); k < nchild; k++)
1241 {
1242 if (j != k)
1243 {
1244 new_children.push_back(c[k]);
1245 }
1246 }
1247 Node nc[2];
1248 nc[0] = c[j];
1249 nc[1] = new_children.size() == 1 ? new_children[0]
1250 : nm->mkNode(ck, new_children);
1251 // negate the proper child
1252 unsigned nindex = (ck == andk) == pol ? 0 : 1;
1253 nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]);
1254 Kind nk = pol ? ork : andk;
1255 // store as new child
1256 children.push_back(nm->mkNode(nk, nc[0], nc[1]));
1257 if (isXor)
1258 {
1259 gpol = !gpol;
1260 }
1261 break;
1262 }
1263 }
1264 }
1265 else
1266 {
1267 bool pol = ck != notk;
1268 Node ca = pol ? c : c[0];
1269 atoms[c][ca] = pol;
1270 alist[c].push_back(ca);
1271 }
1272 atom_count.push_back(std::pair<Node, unsigned>(c, alist[c].size()));
1273 }
1274 // sort the atoms in each atom list
1275 for (std::map<Node, std::vector<Node> >::iterator it = alist.begin();
1276 it != alist.end();
1277 ++it)
1278 {
1279 std::sort(it->second.begin(), it->second.end());
1280 }
1281 // check subsumptions
1282 // sort by #atoms
1283 std::sort(atom_count.begin(), atom_count.end(), sortPairSecond);
1284 if (Trace.isOn("ext-rew-eqchain"))
1285 {
1286 for (const std::pair<Node, unsigned>& ac : atom_count)
1287 {
1288 Trace("ext-rew-eqchain") << " eqchain-simplify: " << ac.first << " has "
1289 << ac.second << " atoms." << std::endl;
1290 }
1291 Trace("ext-rew-eqchain") << " eqchain-simplify: compute subsumptions...\n";
1292 }
1293 SimpSubsumeTrie sst;
1294 for (std::pair<const Node, bool>& cp : cstatus)
1295 {
1296 if (!cp.second)
1297 {
1298 // already eliminated
1299 continue;
1300 }
1301 Node c = cp.first;
1302 std::map<Node, std::map<Node, bool> >::iterator itc = atoms.find(c);
1303 Assert(itc != atoms.end());
1304 Trace("ext-rew-eqchain") << " - add term " << c << " with atom list "
1305 << alist[c] << "...\n";
1306 std::vector<Node> subsumes;
1307 sst.addTerm(c, alist[c], subsumes);
1308 for (const Node& cc : subsumes)
1309 {
1310 if (!cstatus[cc])
1311 {
1312 // subsumes a child that was already eliminated
1313 continue;
1314 }
1315 Trace("ext-rew-eqchain") << " eqchain-simplify: " << c << " subsumes "
1316 << cc << std::endl;
1317 // for each of the atoms in cc
1318 std::map<Node, std::map<Node, bool> >::iterator itcc = atoms.find(cc);
1319 Assert(itcc != atoms.end());
1320 std::vector<Node> common_children;
1321 std::vector<Node> diff_children;
1322 for (const std::pair<const Node, bool>& ap : itcc->second)
1323 {
1324 // compare the polarity
1325 Node a = ap.first;
1326 bool polcc = ap.second;
1327 Assert(itc->second.find(a) != itc->second.end());
1328 bool polc = itc->second[a];
1329 Trace("ext-rew-eqchain") << " eqchain-simplify: atom " << a
1330 << " has polarities : " << polc << " " << polcc
1331 << "\n";
1332 Node lit = polc ? a : TermUtil::mkNegate(notk, a);
1333 if (polc != polcc)
1334 {
1335 diff_children.push_back(lit);
1336 }
1337 else
1338 {
1339 common_children.push_back(lit);
1340 }
1341 }
1342 std::vector<Node> rem_children;
1343 for (const std::pair<const Node, bool>& ap : itc->second)
1344 {
1345 Node a = ap.first;
1346 if (atoms[cc].find(a) == atoms[cc].end())
1347 {
1348 bool polc = ap.second;
1349 rem_children.push_back(polc ? a : TermUtil::mkNegate(notk, a));
1350 }
1351 }
1352 Trace("ext-rew-eqchain")
1353 << " #common/diff/rem: " << common_children.size() << "/"
1354 << diff_children.size() << "/" << rem_children.size() << "\n";
1355 bool do_rewrite = false;
1356 if (common_children.empty() && itc->second.size() == itcc->second.size()
1357 && itcc->second.size() == 2)
1358 {
1359 // x | y = ~x | ~y ---> ~( x = y )
1360 do_rewrite = true;
1361 children.push_back(diff_children[0]);
1362 children.push_back(diff_children[1]);
1363 gpol = !gpol;
1364 Trace("ext-rew-eqchain") << " apply 2-child all-diff\n";
1365 }
1366 else if (common_children.empty() && diff_children.size() == 1)
1367 {
1368 do_rewrite = true;
1369 // x = ( ~x | y ) ---> ~( ~x | ~y )
1370 Node remn = rem_children.size() == 1 ? rem_children[0]
1371 : nm->mkNode(ork, rem_children);
1372 remn = TermUtil::mkNegate(notk, remn);
1373 children.push_back(nm->mkNode(ork, diff_children[0], remn));
1374 if (!isXor)
1375 {
1376 gpol = !gpol;
1377 }
1378 Trace("ext-rew-eqchain") << " apply unit resolution\n";
1379 }
1380 else if (diff_children.size() == 1
1381 && itc->second.size() == itcc->second.size())
1382 {
1383 // ( x | y | z ) = ( x | ~y | z ) ---> ( x | z )
1384 do_rewrite = true;
1385 Assert(!common_children.empty());
1386 Node comn = common_children.size() == 1
1387 ? common_children[0]
1388 : nm->mkNode(ork, common_children);
1389 children.push_back(comn);
1390 if (isXor)
1391 {
1392 gpol = !gpol;
1393 }
1394 Trace("ext-rew-eqchain") << " apply resolution\n";
1395 }
1396 else if (diff_children.empty())
1397 {
1398 do_rewrite = true;
1399 if (rem_children.empty())
1400 {
1401 // x | y = x | y ---> true
1402 // this can happen if we have ( ~x & ~y ) = ( x | y )
1403 children.push_back(TermUtil::mkTypeMaxValue(tn));
1404 if (isXor)
1405 {
1406 gpol = !gpol;
1407 }
1408 Trace("ext-rew-eqchain") << " apply cancel\n";
1409 }
1410 else
1411 {
1412 // x | y = ( x | y | z ) ---> ( x | y | ~z )
1413 Node remn = rem_children.size() == 1 ? rem_children[0]
1414 : nm->mkNode(ork, rem_children);
1415 remn = TermUtil::mkNegate(notk, remn);
1416 Node comn = common_children.size() == 1
1417 ? common_children[0]
1418 : nm->mkNode(ork, common_children);
1419 children.push_back(nm->mkNode(ork, comn, remn));
1420 if (isXor)
1421 {
1422 gpol = !gpol;
1423 }
1424 Trace("ext-rew-eqchain") << " apply subsume\n";
1425 }
1426 }
1427 if (do_rewrite)
1428 {
1429 // eliminate the children, reverse polarity as needed
1430 for (unsigned r = 0; r < 2; r++)
1431 {
1432 Node c_rem = r == 0 ? c : cc;
1433 cstatus[c_rem] = false;
1434 if (c_rem.getKind() == andk)
1435 {
1436 gpol = !gpol;
1437 }
1438 }
1439 break;
1440 }
1441 }
1442 }
1443 Trace("ext-rew-eqchain") << "eqchain-simplify: finish" << std::endl;
1444
1445 // sorted right associative chain
1446 bool has_nvar = false;
1447 unsigned nvar_index = 0;
1448 for (std::pair<const Node, bool>& cp : cstatus)
1449 {
1450 if (cp.second)
1451 {
1452 if (!cp.first.isVar())
1453 {
1454 has_nvar = true;
1455 nvar_index = children.size();
1456 }
1457 children.push_back(cp.first);
1458 }
1459 }
1460 std::sort(children.begin(), children.end());
1461
1462 Node new_ret;
1463 if (!gpol)
1464 {
1465 // negate the constant child if it exists
1466 unsigned nindex = has_nvar ? nvar_index : 0;
1467 children[nindex] = TermUtil::mkNegate(notk, children[nindex]);
1468 }
1469 new_ret = children.back();
1470 unsigned index = children.size() - 1;
1471 while (index > 0)
1472 {
1473 index--;
1474 new_ret = nm->mkNode(eqk, children[index], new_ret);
1475 }
1476 new_ret = Rewriter::rewrite(new_ret);
1477 if (new_ret != ret)
1478 {
1479 return new_ret;
1480 }
1481 return Node::null();
1482 }
1483
partialSubstitute(Node n,std::map<Node,Node> & assign,std::map<Kind,bool> & rkinds)1484 Node ExtendedRewriter::partialSubstitute(Node n,
1485 std::map<Node, Node>& assign,
1486 std::map<Kind, bool>& rkinds)
1487 {
1488 std::unordered_map<TNode, Node, TNodeHashFunction> visited;
1489 std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
1490 std::vector<TNode> visit;
1491 TNode cur;
1492 visit.push_back(n);
1493 do
1494 {
1495 cur = visit.back();
1496 visit.pop_back();
1497 it = visited.find(cur);
1498
1499 if (it == visited.end())
1500 {
1501 std::map<Node, Node>::iterator it = assign.find(cur);
1502 if (it != assign.end())
1503 {
1504 visited[cur] = it->second;
1505 }
1506 else
1507 {
1508 // can only recurse on these kinds
1509 Kind k = cur.getKind();
1510 if (rkinds.find(k) != rkinds.end())
1511 {
1512 visited[cur] = Node::null();
1513 visit.push_back(cur);
1514 for (const Node& cn : cur)
1515 {
1516 visit.push_back(cn);
1517 }
1518 }
1519 else
1520 {
1521 visited[cur] = cur;
1522 }
1523 }
1524 }
1525 else if (it->second.isNull())
1526 {
1527 Node ret = cur;
1528 bool childChanged = false;
1529 std::vector<Node> children;
1530 if (cur.getMetaKind() == metakind::PARAMETERIZED)
1531 {
1532 children.push_back(cur.getOperator());
1533 }
1534 for (const Node& cn : cur)
1535 {
1536 it = visited.find(cn);
1537 Assert(it != visited.end());
1538 Assert(!it->second.isNull());
1539 childChanged = childChanged || cn != it->second;
1540 children.push_back(it->second);
1541 }
1542 if (childChanged)
1543 {
1544 ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
1545 }
1546 visited[cur] = ret;
1547 }
1548 } while (!visit.empty());
1549 Assert(visited.find(n) != visited.end());
1550 Assert(!visited.find(n)->second.isNull());
1551 return visited[n];
1552 }
1553
solveEquality(Node n)1554 Node ExtendedRewriter::solveEquality(Node n)
1555 {
1556 // TODO (#1706) : implement
1557 Assert(n.getKind() == EQUAL);
1558
1559 return Node::null();
1560 }
1561
inferSubstitution(Node n,std::vector<Node> & vars,std::vector<Node> & subs,bool usePred)1562 bool ExtendedRewriter::inferSubstitution(Node n,
1563 std::vector<Node>& vars,
1564 std::vector<Node>& subs,
1565 bool usePred)
1566 {
1567 if (n.getKind() == AND)
1568 {
1569 bool ret = false;
1570 for (const Node& nc : n)
1571 {
1572 bool cret = inferSubstitution(nc, vars, subs, usePred);
1573 ret = ret || cret;
1574 }
1575 return ret;
1576 }
1577 if (n.getKind() == EQUAL)
1578 {
1579 // see if it can be put into form x = y
1580 Node slv_eq = solveEquality(n);
1581 if (!slv_eq.isNull())
1582 {
1583 n = slv_eq;
1584 }
1585 Node v[2];
1586 for (unsigned i = 0; i < 2; i++)
1587 {
1588 if (n[i].isConst())
1589 {
1590 vars.push_back(n[1 - i]);
1591 subs.push_back(n[i]);
1592 return true;
1593 }
1594 if (n[i].isVar())
1595 {
1596 v[i] = n[i];
1597 }
1598 else if (TermUtil::isNegate(n[i].getKind()) && n[i][0].isVar())
1599 {
1600 v[i] = n[i][0];
1601 }
1602 }
1603 for (unsigned i = 0; i < 2; i++)
1604 {
1605 TNode r1 = v[i];
1606 Node r2 = v[1 - i];
1607 if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
1608 {
1609 r2 = n[1 - i];
1610 if (v[i] != n[i])
1611 {
1612 Assert(TermUtil::isNegate(n[i].getKind()));
1613 r2 = TermUtil::mkNegate(n[i].getKind(), r2);
1614 }
1615 // TODO (#1706) : union find
1616 if (std::find(vars.begin(), vars.end(), r1) == vars.end())
1617 {
1618 vars.push_back(r1);
1619 subs.push_back(r2);
1620 return true;
1621 }
1622 }
1623 }
1624 }
1625 if (usePred)
1626 {
1627 bool negated = n.getKind() == NOT;
1628 vars.push_back(negated ? n[0] : n);
1629 subs.push_back(negated ? d_false : d_true);
1630 return true;
1631 }
1632 return false;
1633 }
1634
extendedRewriteArith(Node ret)1635 Node ExtendedRewriter::extendedRewriteArith(Node ret)
1636 {
1637 Kind k = ret.getKind();
1638 NodeManager* nm = NodeManager::currentNM();
1639 Node new_ret;
1640 if (k == DIVISION || k == INTS_DIVISION || k == INTS_MODULUS)
1641 {
1642 // rewrite as though total
1643 std::vector<Node> children;
1644 bool all_const = true;
1645 for (unsigned i = 0, size = ret.getNumChildren(); i < size; i++)
1646 {
1647 if (ret[i].isConst())
1648 {
1649 children.push_back(ret[i]);
1650 }
1651 else
1652 {
1653 all_const = false;
1654 break;
1655 }
1656 }
1657 if (all_const)
1658 {
1659 Kind new_k = (ret.getKind() == DIVISION ? DIVISION_TOTAL
1660 : (ret.getKind() == INTS_DIVISION
1661 ? INTS_DIVISION_TOTAL
1662 : INTS_MODULUS_TOTAL));
1663 new_ret = nm->mkNode(new_k, children);
1664 debugExtendedRewrite(ret, new_ret, "total-interpretation");
1665 }
1666 }
1667 return new_ret;
1668 }
1669
extendedRewriteStrings(Node ret)1670 Node ExtendedRewriter::extendedRewriteStrings(Node ret)
1671 {
1672 Node new_ret;
1673 Trace("q-ext-rewrite-debug")
1674 << "Extended rewrite strings : " << ret << std::endl;
1675
1676 if (ret.getKind() == EQUAL)
1677 {
1678 new_ret = strings::TheoryStringsRewriter::rewriteEqualityExt(ret);
1679 }
1680
1681 return new_ret;
1682 }
1683
debugExtendedRewrite(Node n,Node ret,const char * c) const1684 void ExtendedRewriter::debugExtendedRewrite(Node n,
1685 Node ret,
1686 const char* c) const
1687 {
1688 if (Trace.isOn("q-ext-rewrite"))
1689 {
1690 if (!ret.isNull())
1691 {
1692 Trace("q-ext-rewrite-apply") << "sygus-extr : apply " << c << std::endl;
1693 Trace("q-ext-rewrite") << "sygus-extr : " << c << " : " << n
1694 << " rewrites to " << ret << std::endl;
1695 }
1696 }
1697 }
1698
1699 } /* CVC4::theory::quantifiers namespace */
1700 } /* CVC4::theory namespace */
1701 } /* CVC4 namespace */
1702