1 /********************* */
2 /*! \file datatypes_rewriter.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Morgan Deters, Paul Meng
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of rewriter for the theory of (co)inductive datatypes.
13 **
14 ** Implementation of rewriter for the theory of (co)inductive datatypes.
15 **/
16
17 #include "theory/datatypes/datatypes_rewriter.h"
18
19 using namespace CVC4;
20 using namespace CVC4::kind;
21
22 namespace CVC4 {
23 namespace theory {
24 namespace datatypes {
25
/**
 * Post-rewrite a term of the theory of (co)inductive datatypes.
 *
 * Dispatches on the kind of `in`:
 * - APPLY_CONSTRUCTOR, APPLY_SELECTOR / APPLY_SELECTOR_TOTAL and APPLY_TESTER
 *   are delegated to rewriteConstructor, rewriteSelector and rewriteTester.
 * - DT_SIZE of a constructor application is expanded to the sum of DT_SIZE
 *   of its datatype-typed arguments plus the constructor's weight.
 * - DT_HEIGHT_BOUND of a constructor application with bound r is rewritten
 *   to the conjunction of DT_HEIGHT_BOUND with bound r-1 on its
 *   datatype-typed arguments (false if r is zero and such an argument
 *   exists, true if there are no such arguments).
 * - DT_SIZE_BOUND of a constant is rewritten to (LEQ (DT_SIZE x) bound).
 * - DT_SYGUS_EVAL of a constructor application is unfolded by one level.
 * - EQUAL is rewritten to true for identical sides, to false when
 *   checkClash finds a clash, and otherwise its sides are swapped if
 *   in[1] < in[0].
 * All other terms are returned unchanged (REWRITE_DONE).
 */
RewriteResponse DatatypesRewriter::postRewrite(TNode in)
{
  Trace("datatypes-rewrite-debug") << "post-rewriting " << in << std::endl;
  Kind k = in.getKind();
  NodeManager* nm = NodeManager::currentNM();
  if (k == kind::APPLY_CONSTRUCTOR)
  {
    return rewriteConstructor(in);
  }
  else if (k == kind::APPLY_SELECTOR_TOTAL || k == kind::APPLY_SELECTOR)
  {
    return rewriteSelector(in);
  }
  else if (k == kind::APPLY_TESTER)
  {
    return rewriteTester(in);
  }
  else if (k == kind::DT_SIZE)
  {
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
    {
      // collect DT_SIZE of every datatype-typed argument of the constructor
      std::vector<Node> children;
      for (unsigned i = 0, size = in[0].getNumChildren(); i < size; i++)
      {
        if (in[0][i].getType().isDatatype())
        {
          children.push_back(nm->mkNode(kind::DT_SIZE, in[0][i]));
        }
      }
      TNode constructor = in[0].getOperator();
      size_t constructorIndex = indexOf(constructor);
      const Datatype& dt = Datatype::datatypeOf(constructor.toExpr());
      const DatatypeConstructor& c = dt[constructorIndex];
      unsigned weight = c.getWeight();
      // the constructor's own weight is always added, so children is
      // non-empty below
      children.push_back(nm->mkConst(Rational(weight)));
      Node res =
          children.size() == 1 ? children[0] : nm->mkNode(kind::PLUS, children);
      Trace("datatypes-rewrite")
          << "DatatypesRewriter::postRewrite: rewrite size " << in << " to "
          << res << std::endl;
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
    }
  }
  else if (k == kind::DT_HEIGHT_BOUND)
  {
    if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
    {
      std::vector<Node> children;
      // res stays null unless a zero bound meets a datatype-typed argument
      Node res;
      Rational r = in[1].getConst<Rational>();
      // bound for the arguments: one less than the bound of the parent
      Rational rmo = Rational(r - Rational(1));
      for (unsigned i = 0, size = in[0].getNumChildren(); i < size; i++)
      {
        if (in[0][i].getType().isDatatype())
        {
          if (r.isZero())
          {
            res = nm->mkConst(false);
            break;
          }
          children.push_back(
              nm->mkNode(kind::DT_HEIGHT_BOUND, in[0][i], nm->mkConst(rmo)));
        }
      }
      if (res.isNull())
      {
        // no datatype-typed arguments: true; one: that bound; else AND
        res = children.size() == 0
                  ? nm->mkConst(true)
                  : (children.size() == 1 ? children[0]
                                          : nm->mkNode(kind::AND, children));
      }
      Trace("datatypes-rewrite")
          << "DatatypesRewriter::postRewrite: rewrite height " << in << " to "
          << res << std::endl;
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
    }
  }
  else if (k == kind::DT_SIZE_BOUND)
  {
    if (in[0].isConst())
    {
      // for constants the size bound is expressed via DT_SIZE directly
      Node res = nm->mkNode(kind::LEQ, nm->mkNode(kind::DT_SIZE, in[0]), in[1]);
      return RewriteResponse(REWRITE_AGAIN_FULL, res);
    }
  }
  else if (k == DT_SYGUS_EVAL)
  {
    // sygus evaluation function
    Node ev = in[0];
    if (ev.getKind() == APPLY_CONSTRUCTOR)
    {
      Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n";
      const Datatype& dt =
          static_cast<DatatypeType>(ev.getType().toType()).getDatatype();
      unsigned i = indexOf(ev.getOperator());
      Node op = Node::fromExpr(dt[i].getSygusOp());
      // if it is the "any constant" constructor, return its argument
      if (op.getAttribute(SygusAnyConstAttribute()))
      {
        Assert(ev.getNumChildren() == 1);
        Assert(ev[0].getType().isComparableTo(in.getType()));
        return RewriteResponse(REWRITE_AGAIN_FULL, ev[0]);
      }
      // in[1..] are the evaluation arguments shared by every recursive call
      std::vector<Node> args;
      for (unsigned j = 1, nchild = in.getNumChildren(); j < nchild; j++)
      {
        args.push_back(in[j]);
      }
      Assert(!dt.isParametric());
      // wrap each constructor argument in its own DT_SYGUS_EVAL over args
      std::vector<Node> children;
      for (const Node& evc : ev)
      {
        std::vector<Node> cc;
        cc.push_back(evc);
        cc.insert(cc.end(), args.begin(), args.end());
        children.push_back(nm->mkNode(DT_SYGUS_EVAL, cc));
      }
      Node ret = mkSygusTerm(dt, i, children);
      // if it is a variable, apply the substitution
      if (ret.getKind() == BOUND_VARIABLE)
      {
        Assert(ret.hasAttribute(SygusVarNumAttribute()));
        int vn = ret.getAttribute(SygusVarNumAttribute());
        Assert(Node::fromExpr(dt.getSygusVarList())[vn] == ret);
        ret = args[vn];
      }
      Trace("dt-sygus-util") << "...got " << ret << "\n";
      return RewriteResponse(REWRITE_AGAIN_FULL, ret);
    }
  }

  if (k == kind::EQUAL)
  {
    if (in[0] == in[1])
    {
      return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
    }
    // rew receives the sub-equalities found by checkClash; it is not used
    // further here
    std::vector<Node> rew;
    if (checkClash(in[0], in[1], rew))
    {
      Trace("datatypes-rewrite")
          << "Rewrite clashing equality " << in << " to false" << std::endl;
      return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
      // NOTE(review): a disabled variant here rewrote the equality to rew[0]
      // when checkClash produced exactly one sub-equality different from in.
    }
    else if (in[1] < in[0])
    {
      // orient the equality so that the smaller node is on the left
      Node ins = nm->mkNode(in.getKind(), in[1], in[0]);
      Trace("datatypes-rewrite")
          << "Swap equality " << in << " to " << ins << std::endl;
      return RewriteResponse(REWRITE_DONE, ins);
    }
    Trace("datatypes-rewrite-debug")
        << "Did not rewrite equality " << in << " " << in[0].getKind() << " "
        << in[1].getKind() << std::endl;
  }

  return RewriteResponse(REWRITE_DONE, in);
}
188
getOperatorKindForSygusBuiltin(Node op)189 Kind DatatypesRewriter::getOperatorKindForSygusBuiltin(Node op)
190 {
191 Assert(op.getKind() != BUILTIN);
192 if (op.getKind() == LAMBDA)
193 {
194 // we use APPLY_UF instead of APPLY, since the rewriter for APPLY_UF
195 // does beta-reduction but does not for APPLY
196 return APPLY_UF;
197 }
198 TypeNode tn = op.getType();
199 if (tn.isConstructor())
200 {
201 return APPLY_CONSTRUCTOR;
202 }
203 else if (tn.isSelector())
204 {
205 return APPLY_SELECTOR;
206 }
207 else if (tn.isTester())
208 {
209 return APPLY_TESTER;
210 }
211 else if (tn.isFunction())
212 {
213 return APPLY_UF;
214 }
215 return UNDEFINED_KIND;
216 }
217
mkSygusTerm(const Datatype & dt,unsigned i,const std::vector<Node> & children)218 Node DatatypesRewriter::mkSygusTerm(const Datatype& dt,
219 unsigned i,
220 const std::vector<Node>& children)
221 {
222 Trace("dt-sygus-util") << "Make sygus term " << dt.getName() << "[" << i
223 << "] with children: " << children << std::endl;
224 Assert(i < dt.getNumConstructors());
225 Assert(dt.isSygus());
226 Assert(!dt[i].getSygusOp().isNull());
227 std::vector<Node> schildren;
228 Node op = Node::fromExpr(dt[i].getSygusOp());
229 // if it is the any constant, we simply return the child
230 if (op.getAttribute(SygusAnyConstAttribute()))
231 {
232 Assert(children.size() == 1);
233 return children[0];
234 }
235 if (op.getKind() != BUILTIN)
236 {
237 schildren.push_back(op);
238 }
239 schildren.insert(schildren.end(), children.begin(), children.end());
240 Node ret;
241 if (op.getKind() == BUILTIN)
242 {
243 ret = NodeManager::currentNM()->mkNode(op, schildren);
244 Trace("dt-sygus-util") << "...return (builtin) " << ret << std::endl;
245 return ret;
246 }
247 Kind ok = NodeManager::operatorToKind(op);
248 if (ok != UNDEFINED_KIND)
249 {
250 ret = NodeManager::currentNM()->mkNode(ok, schildren);
251 Trace("dt-sygus-util") << "...return (op) " << ret << std::endl;
252 return ret;
253 }
254 Kind tok = getOperatorKindForSygusBuiltin(op);
255 if (schildren.size() == 1 && tok == kind::UNDEFINED_KIND)
256 {
257 ret = schildren[0];
258 }
259 else
260 {
261 ret = NodeManager::currentNM()->mkNode(tok, schildren);
262 }
263 Trace("dt-sygus-util") << "...return " << ret << std::endl;
264 return ret;
265 }
266
preRewrite(TNode in)267 RewriteResponse DatatypesRewriter::preRewrite(TNode in)
268 {
269 Trace("datatypes-rewrite-debug") << "pre-rewriting " << in << std::endl;
270 // must prewrite to apply type ascriptions since rewriting does not preserve
271 // types
272 if (in.getKind() == kind::APPLY_CONSTRUCTOR)
273 {
274 TypeNode tn = in.getType();
275 Type t = tn.toType();
276 DatatypeType dt = DatatypeType(t);
277
278 // check for parametric datatype constructors
279 // to ensure a normal form, all parameteric datatype constructors must have
280 // a type ascription
281 if (dt.isParametric())
282 {
283 if (in.getOperator().getKind() != kind::APPLY_TYPE_ASCRIPTION)
284 {
285 Trace("datatypes-rewrite-debug")
286 << "Ascribing type to parametric datatype constructor " << in
287 << std::endl;
288 Node op = in.getOperator();
289 // get the constructor object
290 const DatatypeConstructor& dtc =
291 Datatype::datatypeOf(op.toExpr())[indexOf(op)];
292 // create ascribed constructor type
293 Node tc = NodeManager::currentNM()->mkConst(
294 AscriptionType(dtc.getSpecializedConstructorType(t)));
295 Node op_new = NodeManager::currentNM()->mkNode(
296 kind::APPLY_TYPE_ASCRIPTION, tc, op);
297 // make new node
298 std::vector<Node> children;
299 children.push_back(op_new);
300 children.insert(children.end(), in.begin(), in.end());
301 Node inr =
302 NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, children);
303 Trace("datatypes-rewrite-debug") << "Created " << inr << std::endl;
304 return RewriteResponse(REWRITE_DONE, inr);
305 }
306 }
307 }
308 return RewriteResponse(REWRITE_DONE, in);
309 }
310
rewriteConstructor(TNode in)311 RewriteResponse DatatypesRewriter::rewriteConstructor(TNode in)
312 {
313 if (in.isConst())
314 {
315 Trace("datatypes-rewrite-debug") << "Normalizing constant " << in
316 << std::endl;
317 Node inn = normalizeConstant(in);
318 // constant may be a subterm of another constant, so cannot assume that this
319 // will succeed for codatatypes
320 // Assert( !inn.isNull() );
321 if (!inn.isNull() && inn != in)
322 {
323 Trace("datatypes-rewrite") << "Normalized constant " << in << " -> "
324 << inn << std::endl;
325 return RewriteResponse(REWRITE_DONE, inn);
326 }
327 return RewriteResponse(REWRITE_DONE, in);
328 }
329 return RewriteResponse(REWRITE_DONE, in);
330 }
331
rewriteSelector(TNode in)332 RewriteResponse DatatypesRewriter::rewriteSelector(TNode in)
333 {
334 Kind k = in.getKind();
335 if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
336 {
337 // Have to be careful not to rewrite well-typed expressions
338 // where the selector doesn't match the constructor,
339 // e.g. "pred(zero)".
340 TypeNode tn = in.getType();
341 TypeNode argType = in[0].getType();
342 Expr selector = in.getOperator().toExpr();
343 TNode constructor = in[0].getOperator();
344 size_t constructorIndex = indexOf(constructor);
345 const Datatype& dt = Datatype::datatypeOf(selector);
346 const DatatypeConstructor& c = dt[constructorIndex];
347 Trace("datatypes-rewrite-debug") << "Rewriting collapsable selector : "
348 << in;
349 Trace("datatypes-rewrite-debug") << ", cindex = " << constructorIndex
350 << ", selector is " << selector
351 << std::endl;
352 // The argument that the selector extracts, or -1 if the selector is
353 // is wrongly applied.
354 int selectorIndex = -1;
355 if (k == kind::APPLY_SELECTOR_TOTAL)
356 {
357 // The argument index of internal selectors is obtained by
358 // getSelectorIndexInternal.
359 selectorIndex = c.getSelectorIndexInternal(selector);
360 }
361 else
362 {
363 // The argument index of external selectors (applications of
364 // APPLY_SELECTOR) is given by an attribute and obtained via indexOf below
365 // The argument is only valid if it is the proper constructor.
366 selectorIndex = Datatype::indexOf(selector);
367 if (selectorIndex < 0
368 || selectorIndex >= static_cast<int>(c.getNumArgs()))
369 {
370 selectorIndex = -1;
371 }
372 else if (c[selectorIndex].getSelector() != selector)
373 {
374 selectorIndex = -1;
375 }
376 }
377 Trace("datatypes-rewrite-debug") << "Internal selector index is "
378 << selectorIndex << std::endl;
379 if (selectorIndex >= 0)
380 {
381 Assert(selectorIndex < (int)c.getNumArgs());
382 if (dt.isCodatatype() && in[0][selectorIndex].isConst())
383 {
384 // must replace all debruijn indices with self
385 Node sub = replaceDebruijn(in[0][selectorIndex], in[0], argType, 0);
386 Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
387 << "Rewrite trivial codatatype selector "
388 << in << " to " << sub << std::endl;
389 if (sub != in)
390 {
391 return RewriteResponse(REWRITE_AGAIN_FULL, sub);
392 }
393 }
394 else
395 {
396 Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
397 << "Rewrite trivial selector " << in
398 << std::endl;
399 return RewriteResponse(REWRITE_DONE, in[0][selectorIndex]);
400 }
401 }
402 else if (k == kind::APPLY_SELECTOR_TOTAL)
403 {
404 Node gt;
405 bool useTe = true;
406 // if( !tn.isSort() ){
407 // useTe = false;
408 //}
409 if (tn.isDatatype())
410 {
411 const Datatype& dta = ((DatatypeType)(tn).toType()).getDatatype();
412 useTe = !dta.isCodatatype();
413 }
414 if (useTe)
415 {
416 TypeEnumerator te(tn);
417 gt = *te;
418 }
419 else
420 {
421 gt = tn.mkGroundTerm();
422 }
423 if (!gt.isNull())
424 {
425 // Assert( gtt.isDatatype() || gtt.isParametricDatatype() );
426 if (tn.isDatatype() && !tn.isInstantiatedDatatype())
427 {
428 gt = NodeManager::currentNM()->mkNode(
429 kind::APPLY_TYPE_ASCRIPTION,
430 NodeManager::currentNM()->mkConst(AscriptionType(tn.toType())),
431 gt);
432 }
433 Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
434 << "Rewrite trivial selector " << in
435 << " to distinguished ground term " << gt
436 << std::endl;
437 return RewriteResponse(REWRITE_DONE, gt);
438 }
439 }
440 }
441 return RewriteResponse(REWRITE_DONE, in);
442 }
443
rewriteTester(TNode in)444 RewriteResponse DatatypesRewriter::rewriteTester(TNode in)
445 {
446 if (in[0].getKind() == kind::APPLY_CONSTRUCTOR)
447 {
448 bool result = indexOf(in.getOperator()) == indexOf(in[0].getOperator());
449 Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
450 << "Rewrite trivial tester " << in << " "
451 << result << std::endl;
452 return RewriteResponse(REWRITE_DONE,
453 NodeManager::currentNM()->mkConst(result));
454 }
455 const Datatype& dt = static_cast<DatatypeType>(in[0].getType().toType()).getDatatype();
456 if (dt.getNumConstructors() == 1 && !dt.isSygus())
457 {
458 // only one constructor, so it must be
459 Trace("datatypes-rewrite")
460 << "DatatypesRewriter::postRewrite: "
461 << "only one ctor for " << dt.getName() << " and that is "
462 << dt[0].getName() << std::endl;
463 return RewriteResponse(REWRITE_DONE,
464 NodeManager::currentNM()->mkConst(true));
465 }
466 // could try dt.getNumConstructors()==2 && indexOf(in.getOperator())==1 ?
467 return RewriteResponse(REWRITE_DONE, in);
468 }
469
checkClash(Node n1,Node n2,std::vector<Node> & rew)470 bool DatatypesRewriter::checkClash(Node n1, Node n2, std::vector<Node>& rew)
471 {
472 Trace("datatypes-rewrite-debug") << "Check clash : " << n1 << " " << n2
473 << std::endl;
474 if (n1.getKind() == kind::APPLY_CONSTRUCTOR
475 && n2.getKind() == kind::APPLY_CONSTRUCTOR)
476 {
477 if (n1.getOperator() != n2.getOperator())
478 {
479 Trace("datatypes-rewrite-debug") << "Clash operators : " << n1 << " "
480 << n2 << " " << n1.getOperator() << " "
481 << n2.getOperator() << std::endl;
482 return true;
483 }
484 Assert(n1.getNumChildren() == n2.getNumChildren());
485 for (unsigned i = 0, size = n1.getNumChildren(); i < size; i++)
486 {
487 if (checkClash(n1[i], n2[i], rew))
488 {
489 return true;
490 }
491 }
492 }
493 else if (n1 != n2)
494 {
495 if (n1.isConst() && n2.isConst())
496 {
497 Trace("datatypes-rewrite-debug") << "Clash constants : " << n1 << " "
498 << n2 << std::endl;
499 return true;
500 }
501 else
502 {
503 Node eq = NodeManager::currentNM()->mkNode(kind::EQUAL, n1, n2);
504 rew.push_back(eq);
505 }
506 }
507 return false;
508 }
509 /** get instantiate cons */
getInstCons(Node n,const Datatype & dt,int index)510 Node DatatypesRewriter::getInstCons(Node n, const Datatype& dt, int index)
511 {
512 Assert(index >= 0 && index < (int)dt.getNumConstructors());
513 std::vector<Node> children;
514 NodeManager* nm = NodeManager::currentNM();
515 children.push_back(Node::fromExpr(dt[index].getConstructor()));
516 Type t = n.getType().toType();
517 for (unsigned i = 0, nargs = dt[index].getNumArgs(); i < nargs; i++)
518 {
519 Node nc = nm->mkNode(kind::APPLY_SELECTOR_TOTAL,
520 Node::fromExpr(dt[index].getSelectorInternal(t, i)),
521 n);
522 children.push_back(nc);
523 }
524 Node n_ic = nm->mkNode(kind::APPLY_CONSTRUCTOR, children);
525 if (dt.isParametric())
526 {
527 TypeNode tn = TypeNode::fromType(t);
528 // add type ascription for ambiguous constructor types
529 if (!n_ic.getType().isComparableTo(tn))
530 {
531 Debug("datatypes-parametric") << "DtInstantiate: ambiguous type for "
532 << n_ic << ", ascribe to " << n.getType()
533 << std::endl;
534 Debug("datatypes-parametric") << "Constructor is " << dt[index]
535 << std::endl;
536 Type tspec =
537 dt[index].getSpecializedConstructorType(n.getType().toType());
538 Debug("datatypes-parametric") << "Type specification is " << tspec
539 << std::endl;
540 children[0] = nm->mkNode(kind::APPLY_TYPE_ASCRIPTION,
541 nm->mkConst(AscriptionType(tspec)),
542 children[0]);
543 n_ic = nm->mkNode(kind::APPLY_CONSTRUCTOR, children);
544 Assert(n_ic.getType() == tn);
545 }
546 }
547 Assert(isInstCons(n, n_ic, dt) == index);
548 // n_ic = Rewriter::rewrite( n_ic );
549 return n_ic;
550 }
551
isInstCons(Node t,Node n,const Datatype & dt)552 int DatatypesRewriter::isInstCons(Node t, Node n, const Datatype& dt)
553 {
554 if (n.getKind() == kind::APPLY_CONSTRUCTOR)
555 {
556 int index = indexOf(n.getOperator());
557 const DatatypeConstructor& c = dt[index];
558 Type nt = n.getType().toType();
559 for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
560 {
561 if (n[i].getKind() != kind::APPLY_SELECTOR_TOTAL
562 || n[i].getOperator() != Node::fromExpr(c.getSelectorInternal(nt, i))
563 || n[i][0] != t)
564 {
565 return -1;
566 }
567 }
568 return index;
569 }
570 return -1;
571 }
572
isTester(Node n,Node & a)573 int DatatypesRewriter::isTester(Node n, Node& a)
574 {
575 if (n.getKind() == kind::APPLY_TESTER)
576 {
577 a = n[0];
578 return indexOf(n.getOperator());
579 }
580 return -1;
581 }
582
isTester(Node n)583 int DatatypesRewriter::isTester(Node n)
584 {
585 if (n.getKind() == kind::APPLY_TESTER)
586 {
587 return indexOf(n.getOperator().toExpr());
588 }
589 return -1;
590 }
591
592 struct DtIndexAttributeId
593 {
594 };
595 typedef expr::Attribute<DtIndexAttributeId, uint64_t> DtIndexAttribute;
596
indexOf(Node n)597 unsigned DatatypesRewriter::indexOf(Node n)
598 {
599 if (!n.hasAttribute(DtIndexAttribute()))
600 {
601 Assert(n.getType().isConstructor() || n.getType().isTester()
602 || n.getType().isSelector());
603 unsigned index = Datatype::indexOfInternal(n.toExpr());
604 n.setAttribute(DtIndexAttribute(), index);
605 return index;
606 }
607 return n.getAttribute(DtIndexAttribute());
608 }
609
mkTester(Node n,int i,const Datatype & dt)610 Node DatatypesRewriter::mkTester(Node n, int i, const Datatype& dt)
611 {
612 return NodeManager::currentNM()->mkNode(
613 kind::APPLY_TESTER, Node::fromExpr(dt[i].getTester()), n);
614 }
615
mkSplit(Node n,const Datatype & dt)616 Node DatatypesRewriter::mkSplit(Node n, const Datatype& dt)
617 {
618 std::vector<Node> splits;
619 for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++)
620 {
621 Node test = mkTester(n, i, dt);
622 splits.push_back(test);
623 }
624 NodeManager* nm = NodeManager::currentNM();
625 return splits.size() == 1 ? splits[0] : nm->mkNode(kind::OR, splits);
626 }
627
isNullaryApplyConstructor(Node n)628 bool DatatypesRewriter::isNullaryApplyConstructor(Node n)
629 {
630 Assert(n.getKind() == kind::APPLY_CONSTRUCTOR);
631 for (unsigned i = 0; i < n.getNumChildren(); i++)
632 {
633 if (n[i].getType().isDatatype())
634 {
635 return false;
636 }
637 }
638 return true;
639 }
640
isNullaryConstructor(const DatatypeConstructor & c)641 bool DatatypesRewriter::isNullaryConstructor(const DatatypeConstructor& c)
642 {
643 for (unsigned j = 0, nargs = c.getNumArgs(); j < nargs; j++)
644 {
645 if (c[j].getType().getRangeType().isDatatype())
646 {
647 return false;
648 }
649 }
650 return true;
651 }
652
/**
 * Normalize the codatatype constant `n`.
 *
 * Steps:
 * 1. collectRef turns `n` into a symbolic term `s` in which De Bruijn loop
 *    references are replaced by fresh bound variables. `rf` maps each such
 *    variable to the term it refers to. `terms` collects the distinct
 *    subterms, and `cdts` records which of them are codatatype terms.
 * 2. The subterms are partitioned: codatatype terms by constructor
 *    operator, every other term in its own class.
 * 3. Classes are refined by the classes of their children until a fixed
 *    point is reached (DFA minimization on the term structure).
 * 4. Each loop variable is given the class of the term it refers to.
 * 5. The term is rebuilt top-down by normalizeCodatatypeConstantEqc.
 *
 * Returns Node::null() if collectRef fails on `n`.
 */
Node DatatypesRewriter::normalizeCodatatypeConstant(Node n)
{
  Trace("dt-nconst") << "Normalize " << n << std::endl;
  std::map<Node, Node> rf;
  std::vector<Node> sk;
  std::vector<Node> rf_pending;
  std::vector<Node> terms;
  std::map<Node, bool> cdts;
  Node s = collectRef(n, sk, rf, rf_pending, terms, cdts);
  if (!s.isNull())
  {
    Trace("dt-nconst") << "...symbolic normalized is : " << s << std::endl;
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
    {
      Trace("dt-nconst") << " " << it->first << " = " << it->second
                         << std::endl;
    }
    // now run DFA minimization on term structure
    Trace("dt-nconst") << " " << terms.size()
                       << " total subterms :" << std::endl;
    // eqc: term -> class id; eqc_nodes: class id -> members;
    // eqc_op_map: constructor operator -> its initial class id
    int eqc_count = 0;
    std::map<Node, int> eqc_op_map;
    std::map<Node, int> eqc;
    std::map<int, std::vector<Node> > eqc_nodes;
    // partition based on top symbol
    for (unsigned j = 0, size = terms.size(); j < size; j++)
    {
      Node t = terms[j];
      Trace("dt-nconst") << " " << t << ", cdt=" << cdts[t] << std::endl;
      int e;
      if (cdts[t])
      {
        // codatatype terms sharing a constructor start in the same class
        Assert(t.getKind() == kind::APPLY_CONSTRUCTOR);
        Node op = t.getOperator();
        std::map<Node, int>::iterator it = eqc_op_map.find(op);
        if (it == eqc_op_map.end())
        {
          e = eqc_count;
          eqc_op_map[op] = eqc_count;
          eqc_count++;
        }
        else
        {
          e = it->second;
        }
      }
      else
      {
        // any other term gets a class of its own
        e = eqc_count;
        eqc_count++;
      }
      eqc[t] = e;
      eqc_nodes[e].push_back(t);
    }
    // partition until fixed point
    // Each pass processes the classes that exist when it starts. A class
    // with more than one member is split on the first child position whose
    // child classes differ, and its members move into fresh class ids.
    // Another pass runs only if some class was split.
    int eqc_curr = 0;
    bool success = true;
    while (success)
    {
      success = false;
      int eqc_end = eqc_count;
      while (eqc_curr < eqc_end)
      {
        if (eqc_nodes[eqc_curr].size() > 1)
        {
          // look at all nodes in this equivalence class
          unsigned nchildren = eqc_nodes[eqc_curr][0].getNumChildren();
          std::map<int, std::vector<Node> > prt;
          for (unsigned j = 0; j < nchildren; j++)
          {
            prt.clear();
            // partition based on children : for the first child that causes a
            // split, break
            for (unsigned k = 0, size = eqc_nodes[eqc_curr].size(); k < size;
                 k++)
            {
              Node t = eqc_nodes[eqc_curr][k];
              Assert(t.getNumChildren() == nchildren);
              Node tc = t[j];
              // refer to loops
              std::map<Node, Node>::iterator itr = rf.find(tc);
              if (itr != rf.end())
              {
                tc = itr->second;
              }
              Assert(eqc.find(tc) != eqc.end());
              prt[eqc[tc]].push_back(t);
            }
            if (prt.size() > 1)
            {
              success = true;
              break;
            }
          }
          // move into new eqc(s)
          // (this also happens when prt has a single part; if prt is empty,
          // i.e. nchildren == 0, the members keep their current class id)
          for (const std::pair<const int, std::vector<Node> >& p : prt)
          {
            int e = eqc_count;
            for (unsigned j = 0, size = p.second.size(); j < size; j++)
            {
              Node t = p.second[j];
              eqc[t] = e;
              eqc_nodes[e].push_back(t);
            }
            eqc_count++;
          }
        }
        eqc_nodes.erase(eqc_curr);
        eqc_curr++;
      }
    }
    // add in already occurring loop variables
    for (std::map<Node, Node>::iterator it = rf.begin(); it != rf.end(); ++it)
    {
      Trace("dt-nconst-debug") << "Mapping equivalence class of " << it->first
                               << " -> " << it->second << std::endl;
      Assert(eqc.find(it->second) != eqc.end());
      eqc[it->first] = eqc[it->second];
    }
    // we now have a partition of equivalent terms
    Trace("dt-nconst") << "Computed equivalence classes ids : " << std::endl;
    for (std::map<Node, int>::iterator it = eqc.begin(); it != eqc.end(); ++it)
    {
      Trace("dt-nconst") << " " << it->first << " -> " << it->second
                         << std::endl;
    }
    // traverse top-down based on equivalence class
    std::map<int, int> eqc_stack;
    return normalizeCodatatypeConstantEqc(s, eqc_stack, eqc, 0);
  }
  Trace("dt-nconst") << "...invalid." << std::endl;
  return Node::null();
}
786
787 // normalize constant : apply to top-level codatatype constants
normalizeConstant(Node n)788 Node DatatypesRewriter::normalizeConstant(Node n)
789 {
790 TypeNode tn = n.getType();
791 if (tn.isDatatype())
792 {
793 if (tn.isCodatatype())
794 {
795 return normalizeCodatatypeConstant(n);
796 }
797 else
798 {
799 std::vector<Node> children;
800 bool childrenChanged = false;
801 for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
802 {
803 Node nc = normalizeConstant(n[i]);
804 children.push_back(nc);
805 childrenChanged = childrenChanged || nc != n[i];
806 }
807 if (childrenChanged)
808 {
809 return NodeManager::currentNM()->mkNode(n.getKind(), children);
810 }
811 }
812 }
813 return n;
814 }
815
/**
 * Build the symbolic form of the codatatype constant `n`, as used by
 * normalizeCodatatypeConstant.
 *
 * @param n          constant to process
 * @param sk         stack of the codatatype constructor applications
 *                   enclosing the current position (innermost last)
 * @param rf         receives: fresh bound variable -> rebuilt term it stands
 *                   for
 * @param rf_pending stack parallel to `sk`; the entry for an ancestor
 *                   becomes a fresh bound variable once a loop refers to it
 * @param terms      receives each distinct returned (non-loop) term, in the
 *                   order it is first returned (children before parents)
 * @param cdts       receives, for each term in `terms`, whether it is a
 *                   codatatype term
 * @return the symbolic term, or Node::null() if a loop index points past
 *         the enclosing constructor applications. On failure, `sk` and
 *         `rf_pending` are not unwound; the null result is propagated by
 *         the recursive callers.
 */
Node DatatypesRewriter::collectRef(Node n,
                                   std::vector<Node>& sk,
                                   std::map<Node, Node>& rf,
                                   std::vector<Node>& rf_pending,
                                   std::vector<Node>& terms,
                                   std::map<Node, bool>& cdts)
{
  Assert(n.isConst());
  TypeNode tn = n.getType();
  Node ret = n;
  bool isCdt = false;
  if (tn.isDatatype())
  {
    if (!tn.isCodatatype())
    {
      // nested datatype within codatatype : can be normalized independently
      // since all loops should be self-contained
      ret = normalizeConstant(n);
    }
    else
    {
      isCdt = true;
      if (n.getKind() == kind::APPLY_CONSTRUCTOR)
      {
        // push this constructor application so loops below can refer to it
        sk.push_back(n);
        rf_pending.push_back(Node::null());
        std::vector<Node> children;
        children.push_back(n.getOperator());
        bool childChanged = false;
        for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
        {
          Node nc = collectRef(n[i], sk, rf, rf_pending, terms, cdts);
          if (nc.isNull())
          {
            return Node::null();
          }
          childChanged = nc != n[i] || childChanged;
          children.push_back(nc);
        }
        sk.pop_back();
        if (childChanged)
        {
          ret = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR,
                                                 children);
          // if some loop referred to this term, record what its variable
          // stands for
          if (!rf_pending.back().isNull())
          {
            rf[rf_pending.back()] = ret;
          }
        }
        else
        {
          Assert(rf_pending.back().isNull());
        }
        rf_pending.pop_back();
      }
      else
      {
        // a loop
        // Its index counts enclosing constructor applications outward:
        // index 0 refers to the innermost entry of sk.
        const Integer& i = n.getConst<UninterpretedConstant>().getIndex();
        uint32_t index = i.toUnsignedInt();
        if (index >= sk.size())
        {
          return Node::null();
        }
        Assert(sk.size() == rf_pending.size());
        Node r = rf_pending[rf_pending.size() - 1 - index];
        if (r.isNull())
        {
          r = NodeManager::currentNM()->mkBoundVar(
              sk[rf_pending.size() - 1 - index].getType());
          rf_pending[rf_pending.size() - 1 - index] = r;
        }
        // loop variables are returned directly and not added to terms
        return r;
      }
    }
  }
  Trace("dt-nconst-debug") << "Return term : " << ret << " from " << n
                           << std::endl;
  if (std::find(terms.begin(), terms.end(), ret) == terms.end())
  {
    terms.push_back(ret);
    Assert(ret.getType() == tn);
    cdts[ret] = isCdt;
  }
  return ret;
}
902 // eqc_stack stores depth
normalizeCodatatypeConstantEqc(Node n,std::map<int,int> & eqc_stack,std::map<Node,int> & eqc,int depth)903 Node DatatypesRewriter::normalizeCodatatypeConstantEqc(
904 Node n, std::map<int, int>& eqc_stack, std::map<Node, int>& eqc, int depth)
905 {
906 Trace("dt-nconst-debug") << "normalizeCodatatypeConstantEqc: " << n
907 << " depth=" << depth << std::endl;
908 if (eqc.find(n) != eqc.end())
909 {
910 int e = eqc[n];
911 std::map<int, int>::iterator it = eqc_stack.find(e);
912 if (it != eqc_stack.end())
913 {
914 int debruijn = depth - it->second - 1;
915 return NodeManager::currentNM()->mkConst(
916 UninterpretedConstant(n.getType().toType(), debruijn));
917 }
918 std::vector<Node> children;
919 bool childChanged = false;
920 eqc_stack[e] = depth;
921 for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
922 {
923 Node nc = normalizeCodatatypeConstantEqc(n[i], eqc_stack, eqc, depth + 1);
924 children.push_back(nc);
925 childChanged = childChanged || nc != n[i];
926 }
927 eqc_stack.erase(e);
928 if (childChanged)
929 {
930 Assert(n.getKind() == kind::APPLY_CONSTRUCTOR);
931 children.insert(children.begin(), n.getOperator());
932 return NodeManager::currentNM()->mkNode(n.getKind(), children);
933 }
934 }
935 return n;
936 }
937
replaceDebruijn(Node n,Node orig,TypeNode orig_tn,unsigned depth)938 Node DatatypesRewriter::replaceDebruijn(Node n,
939 Node orig,
940 TypeNode orig_tn,
941 unsigned depth)
942 {
943 if (n.getKind() == kind::UNINTERPRETED_CONSTANT && n.getType() == orig_tn)
944 {
945 unsigned index =
946 n.getConst<UninterpretedConstant>().getIndex().toUnsignedInt();
947 if (index == depth)
948 {
949 return orig;
950 }
951 }
952 else if (n.getNumChildren() > 0)
953 {
954 std::vector<Node> children;
955 bool childChanged = false;
956 for (unsigned i = 0, size = n.getNumChildren(); i < size; i++)
957 {
958 Node nc = replaceDebruijn(n[i], orig, orig_tn, depth + 1);
959 children.push_back(nc);
960 childChanged = childChanged || nc != n[i];
961 }
962 if (childChanged)
963 {
964 if (n.hasOperator())
965 {
966 children.insert(children.begin(), n.getOperator());
967 }
968 return NodeManager::currentNM()->mkNode(n.getKind(), children);
969 }
970 }
971 return n;
972 }
973
974 } /* CVC4::theory::datatypes namespace */
975 } /* CVC4::theory namespace */
976 } /* CVC4 namespace */
977