1 /*********************                                                        */
2 /*! \file datatypes_rewriter.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Andrew Reynolds, Morgan Deters, Paul Meng
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
 ** in the top-level source directory and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** \brief Implementation of rewriter for the theory of (co)inductive datatypes.
13  **
14  ** Implementation of rewriter for the theory of (co)inductive datatypes.
15  **/
16 
17 #include "theory/datatypes/datatypes_rewriter.h"
18 
19 using namespace CVC4;
20 using namespace CVC4::kind;
21 
22 namespace CVC4 {
23 namespace theory {
24 namespace datatypes {
25 
/**
 * Post-rewrite step for datatypes terms.
 *
 * Constructor, selector and tester applications are delegated to
 * rewriteConstructor, rewriteSelector and rewriteTester respectively. The
 * other cases handled here are:
 * - DT_SIZE of a constructor term: the sum of DT_SIZE over its
 *   datatype-typed arguments plus the constructor's weight.
 * - DT_HEIGHT_BOUND of a constructor term with bound r: false if r is zero
 *   and some argument is datatype-typed; otherwise the conjunction of
 *   DT_HEIGHT_BOUND with bound r-1 over its datatype-typed arguments (true
 *   if there are none).
 * - DT_SIZE_BOUND of a constant t with bound b: (<= (DT_SIZE t) b).
 * - DT_SYGUS_EVAL of a constructor term: unfolded by one step into the term
 *   built by mkSygusTerm, with DT_SYGUS_EVAL applied to each child.
 * - EQUAL: true if both sides are identical, false if they clash (see
 *   checkClash), otherwise oriented so that the smaller node comes first.
 * Anything else is returned unchanged with REWRITE_DONE.
 *
 * @param in the term to rewrite
 * @return the rewrite response for in
 */
RewriteResponse DatatypesRewriter::postRewrite(TNode in)
{
  Trace("datatypes-rewrite-debug") << "post-rewriting " << in << std::endl;
  Kind k = in.getKind();
  NodeManager* nm = NodeManager::currentNM();
  if (k == kind::APPLY_CONSTRUCTOR)
  {
    return rewriteConstructor(in);
  }
  else if (k == kind::APPLY_SELECTOR_TOTAL || k == kind::APPLY_SELECTOR)
  {
    return rewriteSelector(in);
  }
  else if (k == kind::APPLY_TESTER)
  {
    return rewriteTester(in);
  }
  else if (k == kind::DT_SIZE)
  {
    // The size of a constructor term is expanded in terms of its arguments.
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
    {
      std::vector<Node> children;
      // only datatype-typed arguments contribute a size term
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
      {
        if (in[0][i].getType().isDatatype())
        {
          children.push_back(nm->mkNode(kind::DT_SIZE, in[0][i]));
        }
      }
      // add the weight of the constructor itself
      TNode constructor = in[0].getOperator();
      size_t constructorIndex = indexOf(constructor);
      const Datatype& dt = Datatype::datatypeOf(constructor.toExpr());
      const DatatypeConstructor& c = dt[constructorIndex];
      unsigned weight = c.getWeight();
      children.push_back(nm->mkConst(Rational(weight)));
      // a single child means there were no datatype-typed arguments
      Node res =
          children.size() == 1 ? children[0] : nm->mkNode(kind::PLUS, children);
      Trace("datatypes-rewrite")
          << "DatatypesRewriter::postRewrite: rewrite size " << in << " to "
          << res << std::endl;
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
    }
  }
  else if (k == kind::DT_HEIGHT_BOUND)
  {
    // The height bound of a constructor term is pushed to its arguments.
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
    {
      std::vector<Node> children;
      Node res;
      // the bound in[1] is read as a constant Rational
      Rational r = in[1].getConst<Rational>();
      Rational rmo = Rational(r - Rational(1));
      for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++)
      {
        if (in[0][i].getType().isDatatype())
        {
          // a datatype-typed argument cannot fit within a zero height bound
          if (r.isZero())
          {
            res = nm->mkConst(false);
            break;
          }
          children.push_back(
              nm->mkNode(kind::DT_HEIGHT_BOUND, in[0][i], nm->mkConst(rmo)));
        }
      }
      if (res.isNull())
      {
        res = children.size() == 0
                  ? nm->mkConst(true)
                  : (children.size() == 1 ? children[0]
                                          : nm->mkNode(kind::AND, children));
      }
      Trace("datatypes-rewrite")
          << "DatatypesRewriter::postRewrite: rewrite height " << in << " to "
          << res << std::endl;
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
    }
  }
  else if (k == kind::DT_SIZE_BOUND)
  {
    // For constants, a size bound becomes an arithmetic comparison on DT_SIZE.
    if (in[0].isConst())
    {
      Node res = nm->mkNode(kind::LEQ, nm->mkNode(kind::DT_SIZE, in[0]), in[1]);
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
    }
  }
  else if (k == DT_SYGUS_EVAL)
  {
    // sygus evaluation function
    Node ev = in[0];
    if (ev.getKind() == APPLY_CONSTRUCTOR)
    {
      Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n";
      const Datatype& dt =
          static_cast<DatatypeType>(ev.getType().toType()).getDatatype();
      unsigned i = indexOf(ev.getOperator());
      Node op = Node::fromExpr(dt[i].getSygusOp());
      // if it is the "any constant" constructor, return its argument
      if (op.getAttribute(SygusAnyConstAttribute()))
      {
        Assert(ev.getNumChildren() == 1);
        Assert(ev[0].getType().isComparableTo(in.getType()));
        return RewriteResponse(REWRITE_AGAIN_FULL, ev[0]);
      }
      // the evaluation arguments are in[1], ..., in[n-1]
      std::vector<Node> args;
      for (unsigned j = 1, nchild = in.getNumChildren(); j < nchild; j++)
      {
        args.push_back(in[j]);
      }
      Assert(!dt.isParametric());
      // apply the evaluation function (with the same arguments) to each child
      std::vector<Node> children;
      for (const Node& evc : ev)
      {
        std::vector<Node> cc;
        cc.push_back(evc);
        cc.insert(cc.end(), args.begin(), args.end());
        children.push_back(nm->mkNode(DT_SYGUS_EVAL, cc));
      }
      Node ret = mkSygusTerm(dt, i, children);
      // if it is a variable, apply the substitution
      if (ret.getKind() == BOUND_VARIABLE)
      {
        // the variable's position in the sygus variable list selects which
        // evaluation argument replaces it
        Assert(ret.hasAttribute(SygusVarNumAttribute()));
        int vn = ret.getAttribute(SygusVarNumAttribute());
        Assert(Node::fromExpr(dt.getSygusVarList())[vn] == ret);
        ret = args[vn];
      }
      Trace("dt-sygus-util") << "...got " << ret << "\n";
      return RewriteResponse(REWRITE_AGAIN_FULL, ret);
    }
  }

  if (k == kind::EQUAL)
  {
    if (in[0] == in[1])
    {
      return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
    }
    // rew collects the non-clashing sub-equalities; it is currently unused
    std::vector<Node> rew;
    if (checkClash(in[0], in[1], rew))
    {
      Trace("datatypes-rewrite")
          << "Rewrite clashing equality " << in << " to false" << std::endl;
      return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
      //}else if( rew.size()==1 && rew[0]!=in ){
      //  Trace("datatypes-rewrite") << "Rewrite equality " << in << " to " <<
      //  rew[0] << std::endl;
      //  return RewriteResponse(REWRITE_AGAIN_FULL, rew[0] );
    }
    else if (in[1] < in[0])
    {
      // normalize orientation: the smaller node goes on the left
      Node ins = nm->mkNode(in.getKind(), in[1], in[0]);
      Trace("datatypes-rewrite")
          << "Swap equality " << in << " to " << ins << std::endl;
      return RewriteResponse(REWRITE_DONE, ins);
    }
    Trace("datatypes-rewrite-debug")
        << "Did not rewrite equality " << in << " " << in[0].getKind() << " "
        << in[1].getKind() << std::endl;
  }

  return RewriteResponse(REWRITE_DONE, in);
}
188 
getOperatorKindForSygusBuiltin(Node op)189 Kind DatatypesRewriter::getOperatorKindForSygusBuiltin(Node op)
190 {
191   Assert(op.getKind() != BUILTIN);
192   if (op.getKind() == LAMBDA)
193   {
194     // we use APPLY_UF instead of APPLY, since the rewriter for APPLY_UF
195     // does beta-reduction but does not for APPLY
196     return APPLY_UF;
197   }
198   TypeNode tn = op.getType();
199   if (tn.isConstructor())
200   {
201     return APPLY_CONSTRUCTOR;
202   }
203   else if (tn.isSelector())
204   {
205     return APPLY_SELECTOR;
206   }
207   else if (tn.isTester())
208   {
209     return APPLY_TESTER;
210   }
211   else if (tn.isFunction())
212   {
213     return APPLY_UF;
214   }
215   return UNDEFINED_KIND;
216 }
217 
mkSygusTerm(const Datatype & dt,unsigned i,const std::vector<Node> & children)218 Node DatatypesRewriter::mkSygusTerm(const Datatype& dt,
219                                     unsigned i,
220                                     const std::vector<Node>& children)
221 {
222   Trace("dt-sygus-util") << "Make sygus term " << dt.getName() << "[" << i
223                          << "] with children: " << children << std::endl;
224   Assert(i < dt.getNumConstructors());
225   Assert(dt.isSygus());
226   Assert(!dt[i].getSygusOp().isNull());
227   std::vector<Node> schildren;
228   Node op = Node::fromExpr(dt[i].getSygusOp());
229   // if it is the any constant, we simply return the child
230   if (op.getAttribute(SygusAnyConstAttribute()))
231   {
232     Assert(children.size() == 1);
233     return children[0];
234   }
235   if (op.getKind() != BUILTIN)
236   {
237     schildren.push_back(op);
238   }
239   schildren.insert(schildren.end(), children.begin(), children.end());
240   Node ret;
241   if (op.getKind() == BUILTIN)
242   {
243     ret = NodeManager::currentNM()->mkNode(op, schildren);
244     Trace("dt-sygus-util") << "...return (builtin) " << ret << std::endl;
245     return ret;
246   }
247   Kind ok = NodeManager::operatorToKind(op);
248   if (ok != UNDEFINED_KIND)
249   {
250     ret = NodeManager::currentNM()->mkNode(ok, schildren);
251     Trace("dt-sygus-util") << "...return (op) " << ret << std::endl;
252     return ret;
253   }
254   Kind tok = getOperatorKindForSygusBuiltin(op);
255   if (schildren.size() == 1 && tok == kind::UNDEFINED_KIND)
256   {
257     ret = schildren[0];
258   }
259   else
260   {
261     ret = NodeManager::currentNM()->mkNode(tok, schildren);
262   }
263   Trace("dt-sygus-util") << "...return " << ret << std::endl;
264   return ret;
265 }
266 
preRewrite(TNode in)267 RewriteResponse DatatypesRewriter::preRewrite(TNode in)
268 {
269   Trace("datatypes-rewrite-debug") << "pre-rewriting " << in << std::endl;
270   // must prewrite to apply type ascriptions since rewriting does not preserve
271   // types
272   if (in.getKind() == kind::APPLY_CONSTRUCTOR)
273   {
274     TypeNode tn = in.getType();
275     Type t = tn.toType();
276     DatatypeType dt = DatatypeType(t);
277 
278     // check for parametric datatype constructors
279     // to ensure a normal form, all parameteric datatype constructors must have
280     // a type ascription
281     if (dt.isParametric())
282     {
283       if (in.getOperator().getKind() != kind::APPLY_TYPE_ASCRIPTION)
284       {
285         Trace("datatypes-rewrite-debug")
286             << "Ascribing type to parametric datatype constructor " << in
287             << std::endl;
288         Node op = in.getOperator();
289         // get the constructor object
290         const DatatypeConstructor& dtc =
291             Datatype::datatypeOf(op.toExpr())[indexOf(op)];
292         // create ascribed constructor type
293         Node tc = NodeManager::currentNM()->mkConst(
294             AscriptionType(dtc.getSpecializedConstructorType(t)));
295         Node op_new = NodeManager::currentNM()->mkNode(
296             kind::APPLY_TYPE_ASCRIPTION, tc, op);
297         // make new node
298         std::vector<Node> children;
299         children.push_back(op_new);
300         children.insert(children.end(), in.begin(), in.end());
301         Node inr =
302             NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, children);
303         Trace("datatypes-rewrite-debug") << "Created " << inr << std::endl;
304         return RewriteResponse(REWRITE_DONE, inr);
305       }
306     }
307   }
308   return RewriteResponse(REWRITE_DONE, in);
309 }
310 
rewriteConstructor(TNode in)311 RewriteResponse DatatypesRewriter::rewriteConstructor(TNode in)
312 {
313   if (in.isConst())
314   {
315     Trace("datatypes-rewrite-debug") << "Normalizing constant " << in
316                                      << std::endl;
317     Node inn = normalizeConstant(in);
318     // constant may be a subterm of another constant, so cannot assume that this
319     // will succeed for codatatypes
320     // Assert( !inn.isNull() );
321     if (!inn.isNull() && inn != in)
322     {
323       Trace("datatypes-rewrite") << "Normalized constant " << in << " -> "
324                                  << inn << std::endl;
325       return RewriteResponse(REWRITE_DONE, inn);
326     }
327     return RewriteResponse(REWRITE_DONE, in);
328   }
329   return RewriteResponse(REWRITE_DONE, in);
330 }
331 
/**
 * Rewrites a selector application (APPLY_SELECTOR or APPLY_SELECTOR_TOTAL)
 * whose argument is a constructor term.
 *
 * - If the selector belongs to the argument's constructor, the application
 *   collapses to the selected argument. For codatatypes, when the selected
 *   argument is a constant, De Bruijn indices in it are replaced by the
 *   constructor term itself (via replaceDebruijn).
 * - If the selector does not match (e.g. "pred(zero)") and the application
 *   is APPLY_SELECTOR_TOTAL, it is rewritten to a distinguished ground term
 *   of the result type.
 * - Otherwise the term is returned unchanged.
 */
RewriteResponse DatatypesRewriter::rewriteSelector(TNode in)
{
  Kind k = in.getKind();
  if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
  {
    // Have to be careful not to rewrite well-typed expressions
    // where the selector doesn't match the constructor,
    // e.g. "pred(zero)".
    TypeNode tn = in.getType();
    TypeNode argType = in[0].getType();
    Expr selector = in.getOperator().toExpr();
    TNode constructor = in[0].getOperator();
    size_t constructorIndex = indexOf(constructor);
    const Datatype& dt = Datatype::datatypeOf(selector);
    const DatatypeConstructor& c = dt[constructorIndex];
    Trace("datatypes-rewrite-debug") << "Rewriting collapsable selector : "
                                     << in;
    Trace("datatypes-rewrite-debug") << ", cindex = " << constructorIndex
                                     << ", selector is " << selector
                                     << std::endl;
    // The argument that the selector extracts, or -1 if the selector is
    // wrongly applied.
    int selectorIndex = -1;
    if (k == kind::APPLY_SELECTOR_TOTAL)
    {
      // The argument index of internal selectors is obtained by
      // getSelectorIndexInternal.
      selectorIndex = c.getSelectorIndexInternal(selector);
    }
    else
    {
      // The argument index of external selectors (applications of
      // APPLY_SELECTOR) is given by an attribute and obtained via indexOf below
      // The argument is only valid if it is the proper constructor.
      selectorIndex = Datatype::indexOf(selector);
      if (selectorIndex < 0
          || selectorIndex >= static_cast<int>(c.getNumArgs()))
      {
        selectorIndex = -1;
      }
      else if (c[selectorIndex].getSelector() != selector)
      {
        // same position, but the selector belongs to another constructor
        selectorIndex = -1;
      }
    }
    Trace("datatypes-rewrite-debug") << "Internal selector index is "
                                     << selectorIndex << std::endl;
    if (selectorIndex >= 0)
    {
      Assert(selectorIndex < (int)c.getNumArgs());
      if (dt.isCodatatype() && in[0][selectorIndex].isConst())
      {
        // must replace all debruijn indices with self
        Node sub = replaceDebruijn(in[0][selectorIndex], in[0], argType, 0);
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
                                   << "Rewrite trivial codatatype selector "
                                   << in << " to " << sub << std::endl;
        if (sub != in)
        {
          return RewriteResponse(REWRITE_AGAIN_FULL, sub);
        }
        // otherwise fall through and leave the term unchanged
      }
      else
      {
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
                                   << "Rewrite trivial selector " << in
                                   << std::endl;
        return RewriteResponse(REWRITE_DONE, in[0][selectorIndex]);
      }
    }
    else if (k == kind::APPLY_SELECTOR_TOTAL)
    {
      // Wrongly applied total selector: rewrite to a distinguished ground
      // term of the result type.
      Node gt;
      bool useTe = true;
      // if( !tn.isSort() ){
      //  useTe = false;
      //}
      // for codatatype result types, use mkGroundTerm instead of the first
      // value of a type enumerator
      if (tn.isDatatype())
      {
        const Datatype& dta = ((DatatypeType)(tn).toType()).getDatatype();
        useTe = !dta.isCodatatype();
      }
      if (useTe)
      {
        TypeEnumerator te(tn);
        gt = *te;
      }
      else
      {
        gt = tn.mkGroundTerm();
      }
      if (!gt.isNull())
      {
        // Assert( gtt.isDatatype() || gtt.isParametricDatatype() );
        // ascribe the ground term when tn is a datatype type that is not an
        // instantiated datatype
        if (tn.isDatatype() && !tn.isInstantiatedDatatype())
        {
          gt = NodeManager::currentNM()->mkNode(
              kind::APPLY_TYPE_ASCRIPTION,
              NodeManager::currentNM()->mkConst(AscriptionType(tn.toType())),
              gt);
        }
        Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
                                   << "Rewrite trivial selector " << in
                                   << " to distinguished ground term " << gt
                                   << std::endl;
        return RewriteResponse(REWRITE_DONE, gt);
      }
    }
  }
  return RewriteResponse(REWRITE_DONE, in);
}
443 
rewriteTester(TNode in)444 RewriteResponse DatatypesRewriter::rewriteTester(TNode in)
445 {
446   if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
447   {
448     bool result = indexOf(in.getOperator()) == indexOf(in[0].getOperator());
449     Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
450                                << "Rewrite trivial tester " << in << " "
451                                << result << std::endl;
452     return RewriteResponse(REWRITE_DONE,
453                            NodeManager::currentNM()->mkConst(result));
454   }
455   const Datatype& dt = static_cast<DatatypeType>(in[0].getType().toType()).getDatatype();
456   if (dt.getNumConstructors() == 1 && !dt.isSygus())
457   {
458     // only one constructor, so it must be
459     Trace("datatypes-rewrite")
460         << "DatatypesRewriter::postRewrite: "
461         << "only one ctor for " << dt.getName() << " and that is "
462         << dt[0].getName() << std::endl;
463     return RewriteResponse(REWRITE_DONE,
464                            NodeManager::currentNM()->mkConst(true));
465   }
466   // could try dt.getNumConstructors()==2 && indexOf(in.getOperator())==1 ?
467   return RewriteResponse(REWRITE_DONE, in);
468 }
469 
checkClash(Node n1,Node n2,std::vector<Node> & rew)470 bool DatatypesRewriter::checkClash(Node n1, Node n2, std::vector<Node>& rew)
471 {
472   Trace("datatypes-rewrite-debug") << "Check clash : " << n1 << " " << n2
473                                    << std::endl;
474   if (n1.getKind() == kind::APPLY_CONSTRUCTOR
475       && n2.getKind() == kind::APPLY_CONSTRUCTOR)
476   {
477     if (n1.getOperator() != n2.getOperator())
478     {
479       Trace("datatypes-rewrite-debug") << "Clash operators : " << n1 << " "
480                                        << n2 << " " << n1.getOperator() << " "
481                                        << n2.getOperator() << std::endl;
482       return true;
483     }
484     Assert(n1.getNumChildren() == n2.getNumChildren());
485     for (unsigned i = 0, size = n1.getNumChildren(); i < size; i++)
486     {
487       if (checkClash(n1[i], n2[i], rew))
488       {
489         return true;
490       }
491     }
492   }
493   else if (n1 != n2)
494   {
495     if (n1.isConst() && n2.isConst())
496     {
497       Trace("datatypes-rewrite-debug") << "Clash constants : " << n1 << " "
498                                        << n2 << std::endl;
499       return true;
500     }
501     else
502     {
503       Node eq = NodeManager::currentNM()->mkNode(kind::EQUAL, n1, n2);
504       rew.push_back(eq);
505     }
506   }
507   return false;
508 }
509 /** get instantiate cons */
getInstCons(Node n,const Datatype & dt,int index)510 Node DatatypesRewriter::getInstCons(Node n, const Datatype& dt, int index)
511 {
512   Assert(index >= 0 && index < (int)dt.getNumConstructors());
513   std::vector<Node> children;
514   NodeManager* nm = NodeManager::currentNM();
515   children.push_back(Node::fromExpr(dt[index].getConstructor()));
516   Type t = n.getType().toType();
517   for (unsigned i = 0, nargs = dt[index].getNumArgs(); i < nargs; i++)
518   {
519     Node nc = nm->mkNode(kind::APPLY_SELECTOR_TOTAL,
520                          Node::fromExpr(dt[index].getSelectorInternal(t, i)),
521                          n);
522     children.push_back(nc);
523   }
524   Node n_ic = nm->mkNode(kind::APPLY_CONSTRUCTOR, children);
525   if (dt.isParametric())
526   {
527     TypeNode tn = TypeNode::fromType(t);
528     // add type ascription for ambiguous constructor types
529     if (!n_ic.getType().isComparableTo(tn))
530     {
531       Debug("datatypes-parametric") << "DtInstantiate: ambiguous type for "
532                                     << n_ic << ", ascribe to " << n.getType()
533                                     << std::endl;
534       Debug("datatypes-parametric") << "Constructor is " << dt[index]
535                                     << std::endl;
536       Type tspec =
537           dt[index].getSpecializedConstructorType(n.getType().toType());
538       Debug("datatypes-parametric") << "Type specification is " << tspec
539                                     << std::endl;
540       children[0] = nm->mkNode(kind::APPLY_TYPE_ASCRIPTION,
541                                nm->mkConst(AscriptionType(tspec)),
542                                children[0]);
543       n_ic = nm->mkNode(kind::APPLY_CONSTRUCTOR, children);
544       Assert(n_ic.getType() == tn);
545     }
546   }
547   Assert(isInstCons(n, n_ic, dt) == index);
548   // n_ic = Rewriter::rewrite( n_ic );
549   return n_ic;
550 }
551 
isInstCons(Node t,Node n,const Datatype & dt)552 int DatatypesRewriter::isInstCons(Node t, Node n, const Datatype& dt)
553 {
554   if (n.getKind() == kind::APPLY_CONSTRUCTOR)
555   {
556     int index = indexOf(n.getOperator());
557     const DatatypeConstructor& c = dt[index];
558     Type nt = n.getType().toType();
559     for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
560     {
561       if (n[i].getKind() != kind::APPLY_SELECTOR_TOTAL
562           || n[i].getOperator() != Node::fromExpr(c.getSelectorInternal(nt, i))
563           || n[i][0] != t)
564       {
565         return -1;
566       }
567     }
568     return index;
569   }
570   return -1;
571 }
572 
isTester(Node n,Node & a)573 int DatatypesRewriter::isTester(Node n, Node& a)
574 {
575   if (n.getKind() == kind::APPLY_TESTER)
576   {
577     a = n[0];
578     return indexOf(n.getOperator());
579   }
580   return -1;
581 }
582 
isTester(Node n)583 int DatatypesRewriter::isTester(Node n)
584 {
585   if (n.getKind() == kind::APPLY_TESTER)
586   {
587     return indexOf(n.getOperator().toExpr());
588   }
589   return -1;
590 }
591 
592 struct DtIndexAttributeId
593 {
594 };
595 typedef expr::Attribute<DtIndexAttributeId, uint64_t> DtIndexAttribute;
596 
indexOf(Node n)597 unsigned DatatypesRewriter::indexOf(Node n)
598 {
599   if (!n.hasAttribute(DtIndexAttribute()))
600   {
601     Assert(n.getType().isConstructor() || n.getType().isTester()
602            || n.getType().isSelector());
603     unsigned index = Datatype::indexOfInternal(n.toExpr());
604     n.setAttribute(DtIndexAttribute(), index);
605     return index;
606   }
607   return n.getAttribute(DtIndexAttribute());
608 }
609 
mkTester(Node n,int i,const Datatype & dt)610 Node DatatypesRewriter::mkTester(Node n, int i, const Datatype& dt)
611 {
612   return NodeManager::currentNM()->mkNode(
613       kind::APPLY_TESTER, Node::fromExpr(dt[i].getTester()), n);
614 }
615 
mkSplit(Node n,const Datatype & dt)616 Node DatatypesRewriter::mkSplit(Node n, const Datatype& dt)
617 {
618   std::vector<Node> splits;
619   for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++)
620   {
621     Node test = mkTester(n, i, dt);
622     splits.push_back(test);
623   }
624   NodeManager* nm = NodeManager::currentNM();
625   return splits.size() == 1 ? splits[0] : nm->mkNode(kind::OR, splits);
626 }
627 
isNullaryApplyConstructor(Node n)628 bool DatatypesRewriter::isNullaryApplyConstructor(Node n)
629 {
630   Assert(n.getKind() == kind::APPLY_CONSTRUCTOR);
631   for (unsigned i = 0; i < n.getNumChildren(); i++)
632   {
633     if (n[i].getType().isDatatype())
634     {
635       return false;
636     }
637   }
638   return true;
639 }
640 
isNullaryConstructor(const DatatypeConstructor & c)641 bool DatatypesRewriter::isNullaryConstructor(const DatatypeConstructor& c)
642 {
643   for (unsigned j = 0, nargs = c.getNumArgs(); j < nargs; j++)
644   {
645     if (c[j].getType().getRangeType().isDatatype())
646     {
647       return false;
648     }
649   }
650   return true;
651 }
652 
/**
 * Normalize the codatatype constant n.
 *
 * Steps:
 * 1. collectRef turns n into a symbolic term s. Loops (De Bruijn references)
 *    are replaced by bound variables, and rf maps each such variable to the
 *    term it stands for. terms lists the distinct subterms, and cdts records
 *    which of them are codatatype terms.
 * 2. The subterms are partitioned. Codatatype terms are grouped by
 *    constructor operator, and every other term gets its own class. The
 *    classes are then refined by the classes of their children until a fixed
 *    point is reached (DFA minimization).
 * 3. Each loop variable is put in the class of the term it refers to.
 * 4. normalizeCodatatypeConstantEqc rebuilds the constant from s using the
 *    computed classes.
 *
 * @param n the constant to normalize
 * @return the normalized constant, or null if collectRef fails on n
 */
Node DatatypesRewriter::normalizeCodatatypeConstant(Node n)
{
  Trace("dt-nconst") << "Normalize " << n << std::endl;
  std::map<Node, Node> rf;
  std::vector<Node> sk;
  std::vector<Node> rf_pending;
  std::vector<Node> terms;
  std::map<Node, bool> cdts;
  Node s = collectRef(n, sk, rf, rf_pending, terms, cdts);
  if (!s.isNull())
  {
    Trace("dt-nconst") << "...symbolic normalized is : " << s << std::endl;
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
    {
      Trace("dt-nconst") << "  " << it->first << " = " << it->second
                         << std::endl;
    }
    // now run DFA minimization on term structure
    Trace("dt-nconst") << "  " << terms.size()
                       << " total subterms :" << std::endl;
    // eqc maps each term to its class id; eqc_nodes maps each class id that
    // is still pending refinement to its members
    int eqc_count = 0;
    std::map<Node, int> eqc_op_map;
    std::map<Node, int> eqc;
    std::map<int, std::vector<Node> > eqc_nodes;
    // partition based on top symbol
    for (unsigned j = 0, size = terms.size(); j < size; j++)
    {
      Node t = terms[j];
      Trace("dt-nconst") << "    " << t << ", cdt=" << cdts[t] << std::endl;
      int e;
      if (cdts[t])
      {
        // codatatype terms with the same constructor share an initial class
        Assert(t.getKind() == kind::APPLY_CONSTRUCTOR);
        Node op = t.getOperator();
        std::map<Node, int>::iterator it = eqc_op_map.find(op);
        if (it == eqc_op_map.end())
        {
          e = eqc_count;
          eqc_op_map[op] = eqc_count;
          eqc_count++;
        }
        else
        {
          e = it->second;
        }
      }
      else
      {
        // every non-codatatype term gets a class of its own
        e = eqc_count;
        eqc_count++;
      }
      eqc[t] = e;
      eqc_nodes[e].push_back(t);
    }
    // partition until fixed point
    // Each pass processes classes [eqc_curr, eqc_end); classes created in
    // a pass are handled by the next pass, which runs only if some class was
    // split.
    int eqc_curr = 0;
    bool success = true;
    while (success)
    {
      success = false;
      int eqc_end = eqc_count;
      while (eqc_curr < eqc_end)
      {
        if (eqc_nodes[eqc_curr].size() > 1)
        {
          // look at all nodes in this equivalence class
          unsigned nchildren = eqc_nodes[eqc_curr][0].getNumChildren();
          std::map<int, std::vector<Node> > prt;
          for (unsigned j = 0; j < nchildren; j++)
          {
            prt.clear();
            // partition based on children : for the first child that causes a
            // split, break
            for (unsigned k = 0, size = eqc_nodes[eqc_curr].size(); k < size;
                 k++)
            {
              Node t = eqc_nodes[eqc_curr][k];
              Assert(t.getNumChildren() == nchildren);
              Node tc = t[j];
              // refer to loops
              std::map<Node, Node>::iterator itr = rf.find(tc);
              if (itr != rf.end())
              {
                tc = itr->second;
              }
              Assert(eqc.find(tc) != eqc.end());
              prt[eqc[tc]].push_back(t);
            }
            if (prt.size() > 1)
            {
              success = true;
              break;
            }
          }
          // move into new eqc(s)
          // (if no child split the class, prt holds a single part and the
          // whole class is moved to one fresh id)
          for (const std::pair<const int, std::vector<Node> >& p : prt)
          {
            int e = eqc_count;
            for (unsigned j = 0, size = p.second.size(); j < size; j++)
            {
              Node t = p.second[j];
              eqc[t] = e;
              eqc_nodes[e].push_back(t);
            }
            eqc_count++;
          }
        }
        eqc_nodes.erase(eqc_curr);
        eqc_curr++;
      }
    }
    // add in already occurring loop variables
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
    {
      Trace("dt-nconst-debug") << "Mapping equivalence class of " << it->first
                               << " -> " << it->second << std::endl;
      Assert(eqc.find(it->second) != eqc.end());
      eqc[it->first] = eqc[it->second];
    }
    // we now have a partition of equivalent terms
    Trace("dt-nconst") << "Computed equivalence classes ids : " << std::endl;
    for (std::map<Node, int>::iterator it = eqc.begin(); it != eqc.end(); ++it)
    {
      Trace("dt-nconst") << "  " << it->first << " -> " << it->second
                         << std::endl;
    }
    // traverse top-down based on equivalence class
    std::map<int, int> eqc_stack;
    return normalizeCodatatypeConstantEqc(s, eqc_stack, eqc, 0);
  }
  Trace("dt-nconst") << "...invalid." << std::endl;
  return Node::null();
}
786 
787 // normalize constant : apply to top-level codatatype constants
normalizeConstant(Node n)788 Node DatatypesRewriter::normalizeConstant(Node n)
789 {
790   TypeNode tn = n.getType();
791   if (tn.isDatatype())
792   {
793     if (tn.isCodatatype())
794     {
795       return normalizeCodatatypeConstant(n);
796     }
797     else
798     {
799       std::vector<Node> children;
800       bool childrenChanged = false;
801       for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
802       {
803         Node nc = normalizeConstant(n[i]);
804         children.push_back(nc);
805         childrenChanged = childrenChanged || nc != n[i];
806       }
807       if (childrenChanged)
808       {
809         return NodeManager::currentNM()->mkNode(n.getKind(), children);
810       }
811     }
812   }
813   return n;
814 }
815 
/**
 * Convert the constant n, which occurs inside a codatatype constant being
 * normalized, into a symbolic term whose loops are explicit bound variables.
 *
 * @param n the constant subterm to process
 * @param sk stack of the enclosing codatatype constructor terms currently
 *        being visited (innermost last)
 * @param rf output map from each introduced bound variable to the rebuilt
 *        constructor term it refers to
 * @param rf_pending stack parallel to sk; an entry is set to a fresh bound
 *        variable once some loop inside refers to that ancestor
 * @param terms output list of the distinct terms returned for subterms that
 *        are not loops
 * @param cdts output map recording, for each term in terms, whether it is of
 *        codatatype type
 * @return the symbolic term for n, or null if some loop index refers beyond
 *         the enclosing constructor terms
 */
Node DatatypesRewriter::collectRef(Node n,
                                   std::vector<Node>& sk,
                                   std::map<Node, Node>& rf,
                                   std::vector<Node>& rf_pending,
                                   std::vector<Node>& terms,
                                   std::map<Node, bool>& cdts)
{
  Assert(n.isConst());
  TypeNode tn = n.getType();
  Node ret = n;
  bool isCdt = false;
  if (tn.isDatatype())
  {
    if (!tn.isCodatatype())
    {
      // nested datatype within codatatype : can be normalized independently
      // since all loops should be self-contained
      ret = normalizeConstant(n);
    }
    else
    {
      isCdt = true;
      if (n.getKind() == kind::APPLY_CONSTRUCTOR)
      {
        // enter this constructor term: it may be the target of loops below
        sk.push_back(n);
        rf_pending.push_back(Node::null());
        std::vector<Node> children;
        children.push_back(n.getOperator());
        bool childChanged = false;
        for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
        {
          Node nc = collectRef(n[i], sk, rf, rf_pending, terms, cdts);
          if (nc.isNull())
          {
            return Node::null();
          }
          childChanged = nc != n[i] || childChanged;
          children.push_back(nc);
        }
        sk.pop_back();
        if (childChanged)
        {
          ret = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR,
                                                 children);
          // if a loop referred to this term, bind its variable to the result
          if (!rf_pending.back().isNull())
          {
            rf[rf_pending.back()] = ret;
          }
        }
        else
        {
          // a loop referring here would have replaced some child
          Assert(rf_pending.back().isNull());
        }
        rf_pending.pop_back();
      }
      else
      {
        // a loop
        // the index counts enclosing constructor terms outward from the
        // innermost one (index 0)
        const Integer& i = n.getConst<UninterpretedConstant>().getIndex();
        uint32_t index = i.toUnsignedInt();
        if (index >= sk.size())
        {
          return Node::null();
        }
        Assert(sk.size() == rf_pending.size());
        Node r = rf_pending[rf_pending.size() - 1 - index];
        if (r.isNull())
        {
          r = NodeManager::currentNM()->mkBoundVar(
              sk[rf_pending.size() - 1 - index].getType());
          rf_pending[rf_pending.size() - 1 - index] = r;
        }
        return r;
      }
    }
  }
  Trace("dt-nconst-debug") << "Return term : " << ret << " from " << n
                           << std::endl;
  if (std::find(terms.begin(), terms.end(), ret) == terms.end())
  {
    terms.push_back(ret);
    Assert(ret.getType() == tn);
    cdts[ret] = isCdt;
  }
  return ret;
}
// eqc_stack maps each equivalence class on the current recursion path to the
// depth at which it was entered
normalizeCodatatypeConstantEqc(Node n,std::map<int,int> & eqc_stack,std::map<Node,int> & eqc,int depth)903 Node DatatypesRewriter::normalizeCodatatypeConstantEqc(
904     Node n, std::map<int, int>& eqc_stack, std::map<Node, int>& eqc, int depth)
905 {
906   Trace("dt-nconst-debug") << "normalizeCodatatypeConstantEqc: " << n
907                            << " depth=" << depth << std::endl;
908   if (eqc.find(n) != eqc.end())
909   {
910     int e = eqc[n];
911     std::map<int, int>::iterator it = eqc_stack.find(e);
912     if (it != eqc_stack.end())
913     {
914       int debruijn = depth - it->second - 1;
915       return NodeManager::currentNM()->mkConst(
916           UninterpretedConstant(n.getType().toType(), debruijn));
917     }
918     std::vector<Node> children;
919     bool childChanged = false;
920     eqc_stack[e] = depth;
921     for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
922     {
923       Node nc = normalizeCodatatypeConstantEqc(n[i], eqc_stack, eqc, depth + 1);
924       children.push_back(nc);
925       childChanged = childChanged || nc != n[i];
926     }
927     eqc_stack.erase(e);
928     if (childChanged)
929     {
930       Assert(n.getKind() == kind::APPLY_CONSTRUCTOR);
931       children.insert(children.begin(), n.getOperator());
932       return NodeManager::currentNM()->mkNode(n.getKind(), children);
933     }
934   }
935   return n;
936 }
937 
replaceDebruijn(Node n,Node orig,TypeNode orig_tn,unsigned depth)938 Node DatatypesRewriter::replaceDebruijn(Node n,
939                                         Node orig,
940                                         TypeNode orig_tn,
941                                         unsigned depth)
942 {
943   if (n.getKind() == kind::UNINTERPRETED_CONSTANT && n.getType() == orig_tn)
944   {
945     unsigned index =
946         n.getConst<UninterpretedConstant>().getIndex().toUnsignedInt();
947     if (index == depth)
948     {
949       return orig;
950     }
951   }
952   else if (n.getNumChildren() > 0)
953   {
954     std::vector<Node> children;
955     bool childChanged = false;
956     for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
957     {
958       Node nc = replaceDebruijn(n[i], orig, orig_tn, depth + 1);
959       children.push_back(nc);
960       childChanged = childChanged || nc != n[i];
961     }
962     if (childChanged)
963     {
964       if (n.hasOperator())
965       {
966         children.insert(children.begin(), n.getOperator());
967       }
968       return NodeManager::currentNM()->mkNode(n.getKind(), children);
969     }
970   }
971   return n;
972 }
973 
974 } /* CVC4::theory::datatypes namespace */
975 } /* CVC4::theory namespace */
976 } /* CVC4 namespace */
977