1 /*********************                                                        */
2 /*! \file term_util.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Andrew Reynolds, Morgan Deters, Andres Noetzli
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** \brief Implementation of term utilities class
13  **/
14 
15 #include "theory/quantifiers/term_util.h"
16 
17 #include "expr/datatype.h"
18 #include "expr/node_algorithm.h"
19 #include "options/base_options.h"
20 #include "options/datatypes_options.h"
21 #include "options/quantifiers_options.h"
22 #include "options/uf_options.h"
23 #include "theory/arith/arith_msum.h"
24 #include "theory/bv/theory_bv_utils.h"
25 #include "theory/quantifiers/term_database.h"
26 #include "theory/quantifiers/term_enumeration.h"
27 #include "theory/quantifiers_engine.h"
28 #include "theory/theory_engine.h"
29 
30 using namespace std;
31 using namespace CVC4::kind;
32 using namespace CVC4::context;
33 using namespace CVC4::theory::inst;
34 
35 namespace CVC4 {
36 namespace theory {
37 namespace quantifiers {
38 
TermUtil(QuantifiersEngine * qe)39 TermUtil::TermUtil(QuantifiersEngine* qe) : d_quantEngine(qe)
40 {
41   d_true = NodeManager::currentNM()->mkConst(true);
42   d_false = NodeManager::currentNM()->mkConst(false);
43   d_zero = NodeManager::currentNM()->mkConst(Rational(0));
44   d_one = NodeManager::currentNM()->mkConst(Rational(1));
45 }
46 
~TermUtil()47 TermUtil::~TermUtil(){
48 
49 }
50 
registerQuantifier(Node q)51 void TermUtil::registerQuantifier( Node q ){
52   if( d_inst_constants.find( q )==d_inst_constants.end() ){
53     Debug("quantifiers-engine") << "Instantiation constants for " << q << " : " << std::endl;
54     for( unsigned i=0; i<q[0].getNumChildren(); i++ ){
55       d_vars[q].push_back( q[0][i] );
56       d_var_num[q][q[0][i]] = i;
57       //make instantiation constants
58       Node ic = NodeManager::currentNM()->mkInstConstant( q[0][i].getType() );
59       d_inst_constants_map[ic] = q;
60       d_inst_constants[ q ].push_back( ic );
61       Debug("quantifiers-engine") << "  " << ic << std::endl;
62       //set the var number attribute
63       InstVarNumAttribute ivna;
64       ic.setAttribute( ivna, i );
65       InstConstantAttribute ica;
66       ic.setAttribute( ica, q );
67     }
68   }
69 }
70 
getBoundVars2(Node n,std::vector<Node> & vars,std::map<Node,bool> & visited)71 void TermUtil::getBoundVars2( Node n, std::vector< Node >& vars, std::map< Node, bool >& visited ) {
72   if( visited.find( n )==visited.end() ){
73     visited[n] = true;
74     if( n.getKind()==BOUND_VARIABLE ){
75       if( std::find( vars.begin(), vars.end(), n )==vars.end() ) {
76         vars.push_back( n );
77       }
78     }
79     for( unsigned i=0; i<n.getNumChildren(); i++ ){
80       getBoundVars2( n[i], vars, visited );
81     }
82   }
83 }
84 
getRemoveQuantifiers2(Node n,std::map<Node,Node> & visited)85 Node TermUtil::getRemoveQuantifiers2( Node n, std::map< Node, Node >& visited ) {
86   std::map< Node, Node >::iterator it = visited.find( n );
87   if( it!=visited.end() ){
88     return it->second;
89   }else{
90     Node ret = n;
91     if( n.getKind()==FORALL ){
92       ret = getRemoveQuantifiers2( n[1], visited );
93     }else if( n.getNumChildren()>0 ){
94       std::vector< Node > children;
95       bool childrenChanged = false;
96       for( unsigned i=0; i<n.getNumChildren(); i++ ){
97         Node ni = getRemoveQuantifiers2( n[i], visited );
98         childrenChanged = childrenChanged || ni!=n[i];
99         children.push_back( ni );
100       }
101       if( childrenChanged ){
102         if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
103           children.insert( children.begin(), n.getOperator() );
104         }
105         ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
106       }
107     }
108     visited[n] = ret;
109     return ret;
110   }
111 }
112 
getInstConstAttr(Node n)113 Node TermUtil::getInstConstAttr( Node n ) {
114   if (!n.hasAttribute(InstConstantAttribute()) ){
115     Node q;
116     if (n.hasOperator())
117     {
118       q = getInstConstAttr(n.getOperator());
119     }
120     if (q.isNull())
121     {
122       for (const Node& nc : n)
123       {
124         q = getInstConstAttr(nc);
125         if (!q.isNull())
126         {
127           break;
128         }
129       }
130     }
131     InstConstantAttribute ica;
132     n.setAttribute(ica, q);
133   }
134   return n.getAttribute(InstConstantAttribute());
135 }
136 
hasInstConstAttr(Node n)137 bool TermUtil::hasInstConstAttr( Node n ) {
138   return !getInstConstAttr(n).isNull();
139 }
140 
getBoundVarAttr(Node n)141 Node TermUtil::getBoundVarAttr( Node n ) {
142   if (!n.hasAttribute(BoundVarAttribute()) ){
143     Node bv;
144     if( n.getKind()==BOUND_VARIABLE ){
145       bv = n;
146     }else{
147       for( unsigned i=0; i<n.getNumChildren(); i++ ){
148         bv = getBoundVarAttr(n[i]);
149         if( !bv.isNull() ){
150           break;
151         }
152       }
153     }
154     BoundVarAttribute bva;
155     n.setAttribute(bva, bv);
156   }
157   return n.getAttribute(BoundVarAttribute());
158 }
159 
hasBoundVarAttr(Node n)160 bool TermUtil::hasBoundVarAttr( Node n ) {
161   return !getBoundVarAttr(n).isNull();
162 }
163 
getBoundVars(Node n,std::vector<Node> & vars)164 void TermUtil::getBoundVars( Node n, std::vector< Node >& vars ) {
165   std::map< Node, bool > visited;
166   return getBoundVars2( n, vars, visited );
167 }
168 
169 //remove quantifiers
getRemoveQuantifiers(Node n)170 Node TermUtil::getRemoveQuantifiers( Node n ) {
171   std::map< Node, Node > visited;
172   return getRemoveQuantifiers2( n, visited );
173 }
174 
175 //quantified simplify
getQuantSimplify(Node n)176 Node TermUtil::getQuantSimplify( Node n ) {
177   std::vector< Node > bvs;
178   getBoundVars( n, bvs );
179   if( bvs.empty() ) {
180     return Rewriter::rewrite( n );
181   }else{
182     Node q = NodeManager::currentNM()->mkNode( FORALL, NodeManager::currentNM()->mkNode( BOUND_VAR_LIST, bvs ), n );
183     q = Rewriter::rewrite( q );
184     return getRemoveQuantifiers( q );
185   }
186 }
187 
188 /** get the i^th instantiation constant of q */
getInstantiationConstant(Node q,int i) const189 Node TermUtil::getInstantiationConstant( Node q, int i ) const {
190   std::map< Node, std::vector< Node > >::const_iterator it = d_inst_constants.find( q );
191   if( it!=d_inst_constants.end() ){
192     return it->second[i];
193   }else{
194     return Node::null();
195   }
196 }
197 
198 /** get number of instantiation constants for q */
getNumInstantiationConstants(Node q) const199 unsigned TermUtil::getNumInstantiationConstants( Node q ) const {
200   std::map< Node, std::vector< Node > >::const_iterator it = d_inst_constants.find( q );
201   if( it!=d_inst_constants.end() ){
202     return it->second.size();
203   }else{
204     return 0;
205   }
206 }
207 
getInstConstantBody(Node q)208 Node TermUtil::getInstConstantBody( Node q ){
209   std::map< Node, Node >::iterator it = d_inst_const_body.find( q );
210   if( it==d_inst_const_body.end() ){
211     Node n = substituteBoundVariablesToInstConstants(q[1], q);
212     d_inst_const_body[ q ] = n;
213     return n;
214   }else{
215     return it->second;
216   }
217 }
218 
getCounterexampleLiteral(Node q)219 Node TermUtil::getCounterexampleLiteral( Node q ){
220   if( d_ce_lit.find( q )==d_ce_lit.end() ){
221     /*
222     Node ceBody = getInstConstantBody( f );
223     //check if any variable are of bad types, and fail if so
224     for( size_t i=0; i<d_inst_constants[f].size(); i++ ){
225       if( d_inst_constants[f][i].getType().isBoolean() ){
226         d_ce_lit[ f ] = Node::null();
227         return Node::null();
228       }
229     }
230     */
231     Node g = NodeManager::currentNM()->mkSkolem( "g", NodeManager::currentNM()->booleanType() );
232     //otherwise, ensure literal
233     Node ceLit = d_quantEngine->getValuation().ensureLiteral( g );
234     d_ce_lit[ q ] = ceLit;
235   }
236   return d_ce_lit[ q ];
237 }
238 
substituteBoundVariablesToInstConstants(Node n,Node q)239 Node TermUtil::substituteBoundVariablesToInstConstants(Node n, Node q)
240 {
241   registerQuantifier( q );
242   return n.substitute( d_vars[q].begin(), d_vars[q].end(), d_inst_constants[q].begin(), d_inst_constants[q].end() );
243 }
244 
substituteInstConstantsToBoundVariables(Node n,Node q)245 Node TermUtil::substituteInstConstantsToBoundVariables(Node n, Node q)
246 {
247   registerQuantifier( q );
248   return n.substitute( d_inst_constants[q].begin(), d_inst_constants[q].end(), d_vars[q].begin(), d_vars[q].end() );
249 }
250 
substituteBoundVariables(Node n,Node q,std::vector<Node> & terms)251 Node TermUtil::substituteBoundVariables(Node n,
252                                         Node q,
253                                         std::vector<Node>& terms)
254 {
255   registerQuantifier(q);
256   Assert( d_vars[q].size()==terms.size() );
257   return n.substitute( d_vars[q].begin(), d_vars[q].end(), terms.begin(), terms.end() );
258 }
259 
substituteInstConstants(Node n,Node q,std::vector<Node> & terms)260 Node TermUtil::substituteInstConstants(Node n, Node q, std::vector<Node>& terms)
261 {
262   registerQuantifier(q);
263   Assert(d_inst_constants[q].size() == terms.size());
264   return n.substitute(d_inst_constants[q].begin(),
265                       d_inst_constants[q].end(),
266                       terms.begin(),
267                       terms.end());
268 }
269 
computeInstConstContains(Node n,std::vector<Node> & ics)270 void TermUtil::computeInstConstContains(Node n, std::vector<Node>& ics)
271 {
272   computeVarContainsInternal(n, INST_CONSTANT, ics);
273 }
274 
computeVarContains(Node n,std::vector<Node> & vars)275 void TermUtil::computeVarContains(Node n, std::vector<Node>& vars)
276 {
277   computeVarContainsInternal(n, BOUND_VARIABLE, vars);
278 }
279 
computeQuantContains(Node n,std::vector<Node> & quants)280 void TermUtil::computeQuantContains(Node n, std::vector<Node>& quants)
281 {
282   computeVarContainsInternal(n, FORALL, quants);
283 }
284 
computeVarContainsInternal(Node n,Kind k,std::vector<Node> & vars)285 void TermUtil::computeVarContainsInternal(Node n,
286                                           Kind k,
287                                           std::vector<Node>& vars)
288 {
289   std::unordered_set<TNode, TNodeHashFunction> visited;
290   std::unordered_set<TNode, TNodeHashFunction>::iterator it;
291   std::vector<TNode> visit;
292   TNode cur;
293   visit.push_back(n);
294   do
295   {
296     cur = visit.back();
297     visit.pop_back();
298     it = visited.find(cur);
299 
300     if (it == visited.end())
301     {
302       visited.insert(cur);
303       if (cur.getKind() == k)
304       {
305         if (std::find(vars.begin(), vars.end(), cur) == vars.end())
306         {
307           vars.push_back(cur);
308         }
309       }
310       else
311       {
312         if (cur.hasOperator())
313         {
314           visit.push_back(cur.getOperator());
315         }
316         for (const Node& cn : cur)
317         {
318           visit.push_back(cn);
319         }
320       }
321     }
322   } while (!visit.empty());
323 }
324 
computeInstConstContainsForQuant(Node q,Node n,std::vector<Node> & vars)325 void TermUtil::computeInstConstContainsForQuant(Node q,
326                                                 Node n,
327                                                 std::vector<Node>& vars)
328 {
329   std::vector<Node> ics;
330   computeInstConstContains(n, ics);
331   for (const Node& v : ics)
332   {
333     if (v.getAttribute(InstConstantAttribute()) == q)
334     {
335       if (std::find(vars.begin(), vars.end(), v) == vars.end())
336       {
337         vars.push_back(v);
338       }
339     }
340   }
341 }
342 
getVtsTerms(std::vector<Node> & t,bool isFree,bool create,bool inc_delta)343 void TermUtil::getVtsTerms( std::vector< Node >& t, bool isFree, bool create, bool inc_delta ) {
344   if( inc_delta ){
345     Node delta = getVtsDelta( isFree, create );
346     if( !delta.isNull() ){
347       t.push_back( delta );
348     }
349   }
350   for( unsigned r=0; r<2; r++ ){
351     Node inf = getVtsInfinityIndex( r, isFree, create );
352     if( !inf.isNull() ){
353       t.push_back( inf );
354     }
355   }
356 }
357 
getVtsDelta(bool isFree,bool create)358 Node TermUtil::getVtsDelta( bool isFree, bool create ) {
359   if( create ){
360     if( d_vts_delta_free.isNull() ){
361       d_vts_delta_free = NodeManager::currentNM()->mkSkolem( "delta_free", NodeManager::currentNM()->realType(), "free delta for virtual term substitution" );
362       Node delta_lem = NodeManager::currentNM()->mkNode( GT, d_vts_delta_free, d_zero );
363       d_quantEngine->getOutputChannel().lemma( delta_lem );
364     }
365     if( d_vts_delta.isNull() ){
366       d_vts_delta = NodeManager::currentNM()->mkSkolem( "delta", NodeManager::currentNM()->realType(), "delta for virtual term substitution" );
367       //mark as a virtual term
368       VirtualTermSkolemAttribute vtsa;
369       d_vts_delta.setAttribute(vtsa,true);
370     }
371   }
372   return isFree ? d_vts_delta_free : d_vts_delta;
373 }
374 
getVtsInfinity(TypeNode tn,bool isFree,bool create)375 Node TermUtil::getVtsInfinity( TypeNode tn, bool isFree, bool create ) {
376   if( create ){
377     if( d_vts_inf_free[tn].isNull() ){
378       d_vts_inf_free[tn] = NodeManager::currentNM()->mkSkolem( "inf_free", tn, "free infinity for virtual term substitution" );
379     }
380     if( d_vts_inf[tn].isNull() ){
381       d_vts_inf[tn] = NodeManager::currentNM()->mkSkolem( "inf", tn, "infinity for virtual term substitution" );
382       //mark as a virtual term
383       VirtualTermSkolemAttribute vtsa;
384       d_vts_inf[tn].setAttribute(vtsa,true);
385     }
386   }
387   return isFree ? d_vts_inf_free[tn] : d_vts_inf[tn];
388 }
389 
getVtsInfinityIndex(int i,bool isFree,bool create)390 Node TermUtil::getVtsInfinityIndex( int i, bool isFree, bool create ) {
391   if( i==0 ){
392     return getVtsInfinity( NodeManager::currentNM()->realType(), isFree, create );
393   }else if( i==1 ){
394     return getVtsInfinity( NodeManager::currentNM()->integerType(), isFree, create );
395   }else{
396     Assert( false );
397     return Node::null();
398   }
399 }
400 
substituteVtsFreeTerms(Node n)401 Node TermUtil::substituteVtsFreeTerms( Node n ) {
402   std::vector< Node > vars;
403   getVtsTerms( vars, false, false );
404   std::vector< Node > vars_free;
405   getVtsTerms( vars_free, true, false );
406   Assert( vars.size()==vars_free.size() );
407   if( !vars.empty() ){
408     return n.substitute( vars.begin(), vars.end(), vars_free.begin(), vars_free.end() );
409   }else{
410     return n;
411   }
412 }
413 
rewriteVtsSymbols(Node n)414 Node TermUtil::rewriteVtsSymbols( Node n ) {
415   if( ( n.getKind()==EQUAL || n.getKind()==GEQ ) ){
416     Trace("quant-vts-debug") << "VTS : process " << n << std::endl;
417     Node rew_vts_inf;
418     bool rew_delta = false;
419     //rewriting infinity always takes precedence over rewriting delta
420     for( unsigned r=0; r<2; r++ ){
421       Node inf = getVtsInfinityIndex( r, false, false );
422       if (!inf.isNull() && expr::hasSubterm(n, inf))
423       {
424         if( rew_vts_inf.isNull() ){
425           rew_vts_inf = inf;
426         }else{
427           //for mixed int/real with multiple infinities
428           Trace("quant-vts-debug") << "Multiple infinities...equate " << inf << " = " << rew_vts_inf << std::endl;
429           std::vector< Node > subs_lhs;
430           subs_lhs.push_back( inf );
431           std::vector< Node > subs_rhs;
432           subs_lhs.push_back( rew_vts_inf );
433           n = n.substitute( subs_lhs.begin(), subs_lhs.end(), subs_rhs.begin(), subs_rhs.end() );
434           n = Rewriter::rewrite( n );
435           // may have cancelled
436           if (!expr::hasSubterm(n, rew_vts_inf))
437           {
438             rew_vts_inf = Node::null();
439           }
440         }
441       }
442     }
443     if (rew_vts_inf.isNull())
444     {
445       if (!d_vts_delta.isNull() && expr::hasSubterm(n, d_vts_delta))
446       {
447         rew_delta = true;
448       }
449     }
450     if( !rew_vts_inf.isNull()  || rew_delta ){
451       std::map< Node, Node > msum;
452       if (ArithMSum::getMonomialSumLit(n, msum))
453       {
454         if( Trace.isOn("quant-vts-debug") ){
455           Trace("quant-vts-debug") << "VTS got monomial sum : " << std::endl;
456           ArithMSum::debugPrintMonomialSum(msum, "quant-vts-debug");
457         }
458         Node vts_sym = !rew_vts_inf.isNull() ? rew_vts_inf : d_vts_delta;
459         Assert( !vts_sym.isNull() );
460         Node iso_n;
461         Node nlit;
462         int res = ArithMSum::isolate(vts_sym, msum, iso_n, n.getKind(), true);
463         if( res!=0 ){
464           Trace("quant-vts-debug") << "VTS isolated :  -> " << iso_n << ", res = " << res << std::endl;
465           Node slv = iso_n[res==1 ? 1 : 0];
466           //ensure the vts terms have been eliminated
467           if( containsVtsTerm( slv ) ){
468             Trace("quant-vts-warn") << "Bad vts literal : " << n << ", contains " << vts_sym << " but bad solved form " << slv << "." << std::endl;
469             nlit = substituteVtsFreeTerms( n );
470             Trace("quant-vts-debug") << "...return " << nlit << std::endl;
471             //Assert( false );
472             //safe case: just convert to free symbols
473             return nlit;
474           }else{
475             if( !rew_vts_inf.isNull() ){
476               nlit = ( n.getKind()==GEQ && res==1 ) ? d_true : d_false;
477             }else{
478               Assert( iso_n[res==1 ? 0 : 1]==d_vts_delta );
479               if( n.getKind()==EQUAL ){
480                 nlit = d_false;
481               }else if( res==1 ){
482                 nlit = NodeManager::currentNM()->mkNode( GEQ, d_zero, slv );
483               }else{
484                 nlit = NodeManager::currentNM()->mkNode( GT, slv, d_zero );
485               }
486             }
487           }
488           Trace("quant-vts-debug") << "Return " << nlit << std::endl;
489           return nlit;
490         }else{
491           Trace("quant-vts-warn") << "Bad vts literal : " << n << ", contains " << vts_sym << " but could not isolate." << std::endl;
492           //safe case: just convert to free symbols
493           nlit = substituteVtsFreeTerms( n );
494           Trace("quant-vts-debug") << "...return " << nlit << std::endl;
495           //Assert( false );
496           return nlit;
497         }
498       }
499     }
500     return n;
501   }else if( n.getKind()==FORALL ){
502     //cannot traverse beneath quantifiers
503     return substituteVtsFreeTerms( n );
504   }else{
505     bool childChanged = false;
506     std::vector< Node > children;
507     for( unsigned i=0; i<n.getNumChildren(); i++ ){
508       Node nn = rewriteVtsSymbols( n[i] );
509       children.push_back( nn );
510       childChanged = childChanged || nn!=n[i];
511     }
512     if( childChanged ){
513       if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
514         children.insert( children.begin(), n.getOperator() );
515       }
516       Node ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
517       Trace("quant-vts-debug") << "...make node " << ret << std::endl;
518       return ret;
519     }else{
520       return n;
521     }
522   }
523 }
524 
containsVtsTerm(Node n,bool isFree)525 bool TermUtil::containsVtsTerm( Node n, bool isFree ) {
526   std::vector< Node > t;
527   getVtsTerms( t, isFree, false );
528   return containsTerms( n, t );
529 }
530 
containsVtsTerm(std::vector<Node> & n,bool isFree)531 bool TermUtil::containsVtsTerm( std::vector< Node >& n, bool isFree ) {
532   std::vector< Node > t;
533   getVtsTerms( t, isFree, false );
534   if( !t.empty() ){
535     for( unsigned i=0; i<n.size(); i++ ){
536       if( containsTerms( n[i], t ) ){
537         return true;
538       }
539     }
540   }
541   return false;
542 }
543 
containsVtsInfinity(Node n,bool isFree)544 bool TermUtil::containsVtsInfinity( Node n, bool isFree ) {
545   std::vector< Node > t;
546   getVtsTerms( t, isFree, false, false );
547   return containsTerms( n, t );
548 }
549 
ensureType(Node n,TypeNode tn)550 Node TermUtil::ensureType( Node n, TypeNode tn ) {
551   TypeNode ntn = n.getType();
552   Assert( ntn.isComparableTo( tn ) );
553   if( ntn.isSubtypeOf( tn ) ){
554     return n;
555   }else{
556     if( tn.isInteger() ){
557       return NodeManager::currentNM()->mkNode( TO_INTEGER, n );
558     }
559     return Node::null();
560   }
561 }
562 
containsTerms2(Node n,std::vector<Node> & t,std::map<Node,bool> & visited)563 bool TermUtil::containsTerms2( Node n, std::vector< Node >& t, std::map< Node, bool >& visited ) {
564   if (visited.find(n) == visited.end())
565   {
566     if( std::find( t.begin(), t.end(), n )!=t.end() ){
567       return true;
568     }
569     visited[n] = true;
570     if (n.hasOperator())
571     {
572       if (containsTerms2(n.getOperator(), t, visited))
573       {
574         return true;
575       }
576     }
577     for (const Node& nc : n)
578     {
579       if (containsTerms2(nc, t, visited))
580       {
581         return true;
582       }
583     }
584   }
585   return false;
586 }
587 
containsTerms(Node n,std::vector<Node> & t)588 bool TermUtil::containsTerms( Node n, std::vector< Node >& t ) {
589   if( t.empty() ){
590     return false;
591   }else{
592     std::map< Node, bool > visited;
593     return containsTerms2( n, t, visited );
594   }
595 }
596 
getTermDepth(Node n)597 int TermUtil::getTermDepth( Node n ) {
598   if (!n.hasAttribute(TermDepthAttribute()) ){
599     int maxDepth = -1;
600     for( unsigned i=0; i<n.getNumChildren(); i++ ){
601       int depth = getTermDepth( n[i] );
602       if( depth>maxDepth ){
603         maxDepth = depth;
604       }
605     }
606     TermDepthAttribute tda;
607     n.setAttribute(tda,1+maxDepth);
608   }
609   return n.getAttribute(TermDepthAttribute());
610 }
611 
containsUninterpretedConstant(Node n)612 bool TermUtil::containsUninterpretedConstant( Node n ) {
613   if (!n.hasAttribute(ContainsUConstAttribute()) ){
614     bool ret = false;
615     if( n.getKind()==UNINTERPRETED_CONSTANT ){
616       ret = true;
617     }else{
618       for( unsigned i=0; i<n.getNumChildren(); i++ ){
619         if( containsUninterpretedConstant( n[i] ) ){
620           ret = true;
621           break;
622         }
623       }
624     }
625     ContainsUConstAttribute cuca;
626     n.setAttribute(cuca, ret ? 1 : 0);
627   }
628   return n.getAttribute(ContainsUConstAttribute())!=0;
629 }
630 
simpleNegate(Node n)631 Node TermUtil::simpleNegate( Node n ){
632   if( n.getKind()==OR || n.getKind()==AND ){
633     std::vector< Node > children;
634     for (const Node& cn : n)
635     {
636       children.push_back(simpleNegate(cn));
637     }
638     return NodeManager::currentNM()->mkNode( n.getKind()==OR ? AND : OR, children );
639   }
640   return n.negate();
641 }
642 
mkNegate(Kind notk,Node n)643 Node TermUtil::mkNegate(Kind notk, Node n)
644 {
645   if (n.getKind() == notk)
646   {
647     return n[0];
648   }
649   return NodeManager::currentNM()->mkNode(notk, n);
650 }
651 
isNegate(Kind k)652 bool TermUtil::isNegate(Kind k)
653 {
654   return k == NOT || k == BITVECTOR_NOT || k == BITVECTOR_NEG || k == UMINUS;
655 }
656 
isAssoc(Kind k,bool reqNAry)657 bool TermUtil::isAssoc(Kind k, bool reqNAry)
658 {
659   if (reqNAry)
660   {
661     if (k == UNION || k == INTERSECTION)
662     {
663       return false;
664     }
665   }
666   return k == PLUS || k == MULT || k == NONLINEAR_MULT || k == AND || k == OR
667          || k == XOR || k == BITVECTOR_PLUS || k == BITVECTOR_MULT
668          || k == BITVECTOR_AND || k == BITVECTOR_OR || k == BITVECTOR_XOR
669          || k == BITVECTOR_XNOR || k == BITVECTOR_CONCAT || k == STRING_CONCAT
670          || k == UNION || k == INTERSECTION || k == JOIN || k == PRODUCT
671          || k == SEP_STAR;
672 }
673 
isComm(Kind k,bool reqNAry)674 bool TermUtil::isComm(Kind k, bool reqNAry)
675 {
676   if (reqNAry)
677   {
678     if (k == UNION || k == INTERSECTION)
679     {
680       return false;
681     }
682   }
683   return k == EQUAL || k == PLUS || k == MULT || k == NONLINEAR_MULT || k == AND
684          || k == OR || k == XOR || k == BITVECTOR_PLUS || k == BITVECTOR_MULT
685          || k == BITVECTOR_AND || k == BITVECTOR_OR || k == BITVECTOR_XOR
686          || k == BITVECTOR_XNOR || k == UNION || k == INTERSECTION
687          || k == SEP_STAR;
688 }
689 
isNonAdditive(Kind k)690 bool TermUtil::isNonAdditive( Kind k ) {
691   return k==AND || k==OR || k==BITVECTOR_AND || k==BITVECTOR_OR;
692 }
693 
isBoolConnective(Kind k)694 bool TermUtil::isBoolConnective( Kind k ) {
695   return k==OR || k==AND || k==EQUAL || k==ITE || k==FORALL || k==NOT || k==SEP_STAR;
696 }
697 
isBoolConnectiveTerm(TNode n)698 bool TermUtil::isBoolConnectiveTerm( TNode n ) {
699   return isBoolConnective( n.getKind() ) &&
700          ( n.getKind()!=EQUAL || n[0].getType().isBoolean() ) &&
701          ( n.getKind()!=ITE || n.getType().isBoolean() );
702 }
703 
getTypeValue(TypeNode tn,int val)704 Node TermUtil::getTypeValue(TypeNode tn, int val)
705 {
706   std::unordered_map<int, Node>::iterator it = d_type_value[tn].find(val);
707   if (it == d_type_value[tn].end())
708   {
709     Node n = mkTypeValue(tn, val);
710     d_type_value[tn][val] = n;
711     return n;
712   }
713   return it->second;
714 }
715 
mkTypeValue(TypeNode tn,int val)716 Node TermUtil::mkTypeValue(TypeNode tn, int val)
717 {
718   Node n;
719   if (tn.isInteger() || tn.isReal())
720   {
721     Rational c(val);
722     n = NodeManager::currentNM()->mkConst(c);
723   }
724   else if (tn.isBitVector())
725   {
726     unsigned int uv = val;
727     BitVector bval(tn.getConst<BitVectorSize>(), uv);
728     n = NodeManager::currentNM()->mkConst<BitVector>(bval);
729   }
730   else if (tn.isBoolean())
731   {
732     if (val == 0)
733     {
734       n = NodeManager::currentNM()->mkConst(false);
735     }
736   }
737   else if (tn.isString())
738   {
739     if (val == 0)
740     {
741       n = NodeManager::currentNM()->mkConst(::CVC4::String(""));
742     }
743   }
744   return n;
745 }
746 
getTypeMaxValue(TypeNode tn)747 Node TermUtil::getTypeMaxValue(TypeNode tn)
748 {
749   std::unordered_map<TypeNode, Node, TypeNodeHashFunction>::iterator it =
750       d_type_max_value.find(tn);
751   if (it == d_type_max_value.end())
752   {
753     Node n = mkTypeMaxValue(tn);
754     d_type_max_value[tn] = n;
755     return n;
756   }
757   return it->second;
758 }
759 
mkTypeMaxValue(TypeNode tn)760 Node TermUtil::mkTypeMaxValue(TypeNode tn)
761 {
762   Node n;
763   if (tn.isBitVector())
764   {
765     n = bv::utils::mkOnes(tn.getConst<BitVectorSize>());
766   }
767   else if (tn.isBoolean())
768   {
769     n = NodeManager::currentNM()->mkConst(true);
770   }
771   return n;
772 }
773 
getTypeValueOffset(TypeNode tn,Node val,int offset,int & status)774 Node TermUtil::getTypeValueOffset(TypeNode tn,
775                                   Node val,
776                                   int offset,
777                                   int& status)
778 {
779   std::unordered_map<int, Node>::iterator it =
780       d_type_value_offset[tn][val].find(offset);
781   if (it == d_type_value_offset[tn][val].end())
782   {
783     Node val_o;
784     Node offset_val = getTypeValue(tn, offset);
785     status = -1;
786     if (!offset_val.isNull())
787     {
788       if (tn.isInteger() || tn.isReal())
789       {
790         val_o = Rewriter::rewrite(
791             NodeManager::currentNM()->mkNode(PLUS, val, offset_val));
792         status = 0;
793       }
794       else if (tn.isBitVector())
795       {
796         val_o = Rewriter::rewrite(
797             NodeManager::currentNM()->mkNode(BITVECTOR_PLUS, val, offset_val));
798         // TODO : enable?  watch for overflows
799       }
800     }
801     d_type_value_offset[tn][val][offset] = val_o;
802     d_type_value_offset_status[tn][val][offset] = status;
803     return val_o;
804   }
805   status = d_type_value_offset_status[tn][val][offset];
806   return it->second;
807 }
808 
mkTypeConst(TypeNode tn,bool pol)809 Node TermUtil::mkTypeConst(TypeNode tn, bool pol)
810 {
811   return pol ? mkTypeMaxValue(tn) : mkTypeValue(tn, 0);
812 }
813 
isAntisymmetric(Kind k,Kind & dk)814 bool TermUtil::isAntisymmetric(Kind k, Kind& dk)
815 {
816   if (k == GT)
817   {
818     dk = LT;
819     return true;
820   }
821   else if (k == GEQ)
822   {
823     dk = LEQ;
824     return true;
825   }
826   else if (k == BITVECTOR_UGT)
827   {
828     dk = BITVECTOR_ULT;
829     return true;
830   }
831   else if (k == BITVECTOR_UGE)
832   {
833     dk = BITVECTOR_ULE;
834     return true;
835   }
836   else if (k == BITVECTOR_SGT)
837   {
838     dk = BITVECTOR_SLT;
839     return true;
840   }
841   else if (k == BITVECTOR_SGE)
842   {
843     dk = BITVECTOR_SLE;
844     return true;
845   }
846   return false;
847 }
848 
isIdempotentArg(Node n,Kind ik,int arg)849 bool TermUtil::isIdempotentArg(Node n, Kind ik, int arg)
850 {
851   // these should all be binary operators
852   // Assert( ik!=DIVISION && ik!=INTS_DIVISION && ik!=INTS_MODULUS &&
853   // ik!=BITVECTOR_UDIV );
854   TypeNode tn = n.getType();
855   if (n == getTypeValue(tn, 0))
856   {
857     if (ik == PLUS || ik == OR || ik == XOR || ik == BITVECTOR_PLUS
858         || ik == BITVECTOR_OR
859         || ik == BITVECTOR_XOR
860         || ik == STRING_CONCAT)
861     {
862       return true;
863     }
864     else if (ik == MINUS || ik == BITVECTOR_SHL || ik == BITVECTOR_LSHR
865              || ik == BITVECTOR_ASHR
866              || ik == BITVECTOR_SUB
867              || ik == BITVECTOR_UREM
868              || ik == BITVECTOR_UREM_TOTAL)
869     {
870       return arg == 1;
871     }
872   }
873   else if (n == getTypeValue(tn, 1))
874   {
875     if (ik == MULT || ik == BITVECTOR_MULT)
876     {
877       return true;
878     }
879     else if (ik == DIVISION || ik == DIVISION_TOTAL || ik == INTS_DIVISION
880              || ik == INTS_DIVISION_TOTAL
881              || ik == INTS_MODULUS
882              || ik == INTS_MODULUS_TOTAL
883              || ik == BITVECTOR_UDIV_TOTAL
884              || ik == BITVECTOR_UDIV
885              || ik == BITVECTOR_SDIV)
886     {
887       return arg == 1;
888     }
889   }
890   else if (n == getTypeMaxValue(tn))
891   {
892     if (ik == EQUAL || ik == BITVECTOR_AND || ik == BITVECTOR_XNOR)
893     {
894       return true;
895     }
896   }
897   return false;
898 }
899 
isSingularArg(Node n,Kind ik,unsigned arg)900 Node TermUtil::isSingularArg(Node n, Kind ik, unsigned arg)
901 {
902   TypeNode tn = n.getType();
903   if (n == getTypeValue(tn, 0))
904   {
905     if (ik == AND || ik == MULT || ik == BITVECTOR_AND || ik == BITVECTOR_MULT)
906     {
907       return n;
908     }
909     else if (ik == BITVECTOR_SHL || ik == BITVECTOR_LSHR || ik == BITVECTOR_ASHR
910              || ik == BITVECTOR_UREM
911              || ik == BITVECTOR_UREM_TOTAL)
912     {
913       if (arg == 0)
914       {
915         return n;
916       }
917     }
918     else if (ik == BITVECTOR_UDIV_TOTAL || ik == BITVECTOR_UDIV
919              || ik == BITVECTOR_SDIV)
920     {
921       if (arg == 0)
922       {
923         return n;
924       }
925       else if (arg == 1)
926       {
927         return getTypeMaxValue(tn);
928       }
929     }
930     else if (ik == DIVISION || ik == DIVISION_TOTAL || ik == INTS_DIVISION
931              || ik == INTS_DIVISION_TOTAL
932              || ik == INTS_MODULUS
933              || ik == INTS_MODULUS_TOTAL)
934     {
935       if (arg == 0)
936       {
937         return n;
938       }
939     }
940     else if (ik == STRING_SUBSTR)
941     {
942       if (arg == 0)
943       {
944         return n;
945       }
946       else if (arg == 2)
947       {
948         return getTypeValue(NodeManager::currentNM()->stringType(), 0);
949       }
950     }
951     else if (ik == STRING_STRIDOF)
952     {
953       if (arg == 0 || arg == 1)
954       {
955         return getTypeValue(NodeManager::currentNM()->integerType(), -1);
956       }
957     }
958   }
959   else if (n == getTypeValue(tn, 1))
960   {
961     if (ik == BITVECTOR_UREM_TOTAL)
962     {
963       return getTypeValue(tn, 0);
964     }
965   }
966   else if (n == getTypeMaxValue(tn))
967   {
968     if (ik == OR || ik == BITVECTOR_OR)
969     {
970       return n;
971     }
972   }
973   else
974   {
975     if (n.getType().isReal() && n.getConst<Rational>().sgn() < 0)
976     {
977       // negative arguments
978       if (ik == STRING_SUBSTR || ik == STRING_CHARAT)
979       {
980         return getTypeValue(NodeManager::currentNM()->stringType(), 0);
981       }
982       else if (ik == STRING_STRIDOF)
983       {
984         Assert(arg == 2);
985         return getTypeValue(NodeManager::currentNM()->integerType(), -1);
986       }
987     }
988   }
989   return Node::null();
990 }
991 
hasOffsetArg(Kind ik,int arg,int & offset,Kind & ok)992 bool TermUtil::hasOffsetArg(Kind ik, int arg, int& offset, Kind& ok)
993 {
994   if (ik == LT)
995   {
996     Assert(arg == 0 || arg == 1);
997     offset = arg == 0 ? 1 : -1;
998     ok = LEQ;
999     return true;
1000   }
1001   else if (ik == BITVECTOR_ULT)
1002   {
1003     Assert(arg == 0 || arg == 1);
1004     offset = arg == 0 ? 1 : -1;
1005     ok = BITVECTOR_ULE;
1006     return true;
1007   }
1008   else if (ik == BITVECTOR_SLT)
1009   {
1010     Assert(arg == 0 || arg == 1);
1011     offset = arg == 0 ? 1 : -1;
1012     ok = BITVECTOR_SLE;
1013     return true;
1014   }
1015   return false;
1016 }
1017 
getHoTypeMatchPredicate(TypeNode tn)1018 Node TermUtil::getHoTypeMatchPredicate( TypeNode tn ) {
1019   std::map< TypeNode, Node >::iterator ithp = d_ho_type_match_pred.find( tn );
1020   if( ithp==d_ho_type_match_pred.end() ){
1021     TypeNode ptn = NodeManager::currentNM()->mkFunctionType( tn, NodeManager::currentNM()->booleanType() );
1022     Node k = NodeManager::currentNM()->mkSkolem( "U", ptn, "predicate to force higher-order types" );
1023     d_ho_type_match_pred[tn] = k;
1024     return k;
1025   }else{
1026     return ithp->second;
1027   }
1028 }
1029 
1030 
1031 }/* CVC4::theory::quantifiers namespace */
1032 }/* CVC4::theory namespace */
1033 }/* CVC4 namespace */
1034