1 /*********************                                                        */
2 /*! \file sort_inference.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Andrew Reynolds, Paul Meng, Kshitij Bansal
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** \brief Sort inference module
13  **
14  ** This class implements sort inference, based on a simple algorithm:
15  ** First, we assume all functions and predicates have distinct uninterpreted types.
16  ** One pass is made through the input assertions, while a union-find data structure
17  ** maintains necessary information regarding constraints on these types.
18  **/
19 
#include "theory/sort_inference.h"

#include <algorithm>
#include <map>
#include <sstream>
#include <vector>

#include "options/quantifiers_options.h"
#include "options/smt_options.h"
#include "options/uf_options.h"
#include "proof/proof_manager.h"
#include "theory/rewriter.h"
#include "theory/quantifiers/quant_util.h"
30 
31 using namespace CVC4;
32 using namespace std;
33 
34 namespace CVC4 {
35 
print(const char * c)36 void SortInference::UnionFind::print(const char * c){
37   for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){
38     Trace(c) << "s_" << it->first << " = s_" << it->second << ", ";
39   }
40   for( unsigned i=0; i<d_deq.size(); i++ ){
41     Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", ";
42   }
43   Trace(c) << std::endl;
44 }
set(UnionFind & c)45 void SortInference::UnionFind::set( UnionFind& c ) {
46   clear();
47   for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){
48     d_eqc[ it->first ] = it->second;
49   }
50   d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() );
51 }
getRepresentative(int t)52 int SortInference::UnionFind::getRepresentative( int t ){
53   std::map< int, int >::iterator it = d_eqc.find( t );
54   if( it==d_eqc.end() || it->second==t ){
55     return t;
56   }else{
57     int rt = getRepresentative( it->second );
58     d_eqc[t] = rt;
59     return rt;
60   }
61 }
setEqual(int t1,int t2)62 void SortInference::UnionFind::setEqual( int t1, int t2 ){
63   if( t1!=t2 ){
64     int rt1 = getRepresentative( t1 );
65     int rt2 = getRepresentative( t2 );
66     if( rt1>rt2 ){
67       d_eqc[rt1] = rt2;
68     }else{
69       d_eqc[rt2] = rt1;
70     }
71   }
72 }
isValid()73 bool SortInference::UnionFind::isValid() {
74   for( unsigned i=0; i<d_deq.size(); i++ ){
75     if( areEqual( d_deq[i].first, d_deq[i].second ) ){
76       return false;
77     }
78   }
79   return true;
80 }
81 
82 
recordSubsort(TypeNode tn,int s)83 void SortInference::recordSubsort( TypeNode tn, int s ){
84   s = d_type_union_find.getRepresentative( s );
85   if( std::find( d_sub_sorts.begin(), d_sub_sorts.end(), s )==d_sub_sorts.end() ){
86     d_sub_sorts.push_back( s );
87     d_type_sub_sorts[tn].push_back( s );
88   }
89 }
90 
printSort(const char * c,int t)91 void SortInference::printSort( const char* c, int t ){
92   int rt = d_type_union_find.getRepresentative( t );
93   if( d_type_types.find( rt )!=d_type_types.end() ){
94     Trace(c) << d_type_types[rt];
95   }else{
96     Trace(c) << "s_" << rt;
97   }
98 }
99 
reset()100 void SortInference::reset() {
101   d_sub_sorts.clear();
102   d_non_monotonic_sorts.clear();
103   d_type_sub_sorts.clear();
104   //reset info
105   d_sortCount = 1;
106   d_type_union_find.clear();
107   d_type_types.clear();
108   d_id_for_types.clear();
109   d_op_return_types.clear();
110   d_op_arg_types.clear();
111   d_var_types.clear();
112   //for rewriting
113   d_symbol_map.clear();
114   d_const_map.clear();
115 }
116 
initialize(const std::vector<Node> & assertions)117 void SortInference::initialize(const std::vector<Node>& assertions)
118 {
119   Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
120   // process all assertions
121   std::map<Node, int> visited;
122   for (const Node& a : assertions)
123   {
124     Trace("sort-inference-debug") << "Process " << a << std::endl;
125     std::map<Node, Node> var_bound;
126     process(a, var_bound, visited);
127   }
128   Trace("sort-inference-proc") << "...done" << std::endl;
129   for (const std::pair<const Node, int>& rt : d_op_return_types)
130   {
131     Trace("sort-inference") << rt.first << " : ";
132     TypeNode retTn = rt.first.getType();
133     if (!d_op_arg_types[rt.first].empty())
134     {
135       Trace("sort-inference") << "( ";
136       for (size_t i = 0; i < d_op_arg_types[rt.first].size(); i++)
137       {
138         recordSubsort(retTn[i], d_op_arg_types[rt.first][i]);
139         printSort("sort-inference", d_op_arg_types[rt.first][i]);
140         Trace("sort-inference") << " ";
141       }
142       Trace("sort-inference") << ") -> ";
143       retTn = retTn[(int)retTn.getNumChildren() - 1];
144     }
145     recordSubsort(retTn, rt.second);
146     printSort("sort-inference", rt.second);
147     Trace("sort-inference") << std::endl;
148   }
149   for (std::pair<const Node, std::map<Node, int> >& vt : d_var_types)
150   {
151     Trace("sort-inference")
152         << "Quantified formula : " << vt.first << " : " << std::endl;
153     for (const Node& v : vt.first[0])
154     {
155       recordSubsort(v.getType(), vt.second[v]);
156       printSort("sort-inference", vt.second[v]);
157       Trace("sort-inference") << std::endl;
158     }
159     Trace("sort-inference") << std::endl;
160   }
161 
162   // determine monotonicity of sorts
163   Trace("sort-inference-proc")
164       << "Calculating monotonicty for subsorts..." << std::endl;
165   std::map<Node, std::map<int, bool> > visitedm;
166   for (const Node& a : assertions)
167   {
168     Trace("sort-inference-debug")
169         << "Process monotonicity for " << a << std::endl;
170     std::map<Node, Node> var_bound;
171     processMonotonic(a, true, true, var_bound, visitedm);
172   }
173   Trace("sort-inference-proc") << "...done" << std::endl;
174 
175   Trace("sort-inference") << "We have " << d_sub_sorts.size()
176                           << " sub-sorts : " << std::endl;
177   for (unsigned i = 0, size = d_sub_sorts.size(); i < size; i++)
178   {
179     printSort("sort-inference", d_sub_sorts[i]);
180     if (d_type_types.find(d_sub_sorts[i]) != d_type_types.end())
181     {
182       Trace("sort-inference") << " is interpreted." << std::endl;
183     }
184     else if (d_non_monotonic_sorts.find(d_sub_sorts[i])
185              == d_non_monotonic_sorts.end())
186     {
187       Trace("sort-inference") << " is monotonic." << std::endl;
188     }
189     else
190     {
191       Trace("sort-inference") << " is not monotonic." << std::endl;
192     }
193   }
194 }
195 
simplify(Node n,std::map<Node,Node> & model_replace_f,std::map<Node,std::map<TypeNode,Node>> & visited)196 Node SortInference::simplify(Node n,
197                              std::map<Node, Node>& model_replace_f,
198                              std::map<Node, std::map<TypeNode, Node> >& visited)
199 {
200   Trace("sort-inference-debug") << "Simplify " << n << std::endl;
201   std::map<Node, Node> var_bound;
202   TypeNode tnn;
203   Node ret = simplifyNode(n, var_bound, tnn, model_replace_f, visited);
204   ret = theory::Rewriter::rewrite(ret);
205   return ret;
206 }
207 
getNewAssertions(std::vector<Node> & new_asserts)208 void SortInference::getNewAssertions(std::vector<Node>& new_asserts)
209 {
210   NodeManager* nm = NodeManager::currentNM();
211   // now, ensure constants are distinct
212   for (const std::pair<const TypeNode, std::map<Node, Node> >& cm : d_const_map)
213   {
214     std::vector<Node> consts;
215     for (const std::pair<const Node, Node>& c : cm.second)
216     {
217       Assert(c.first.isConst());
218       consts.push_back(c.second);
219     }
220     // add lemma enforcing introduced constants to be distinct
221     if (consts.size() > 1)
222     {
223       Node distinct_const = nm->mkNode(kind::DISTINCT, consts);
224       Trace("sort-inference-rewrite")
225           << "Add the constant distinctness lemma: " << std::endl;
226       Trace("sort-inference-rewrite") << "  " << distinct_const << std::endl;
227       new_asserts.push_back(distinct_const);
228     }
229   }
230 
231   // enforce constraints based on monotonicity
232   Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;
233 
234   for (const std::pair<const TypeNode, std::vector<int> >& tss :
235        d_type_sub_sorts)
236   {
237     int nmonSort = -1;
238     unsigned nsorts = tss.second.size();
239     for (unsigned i = 0; i < nsorts; i++)
240     {
241       if (d_non_monotonic_sorts.find(tss.second[i])
242           != d_non_monotonic_sorts.end())
243       {
244         nmonSort = tss.second[i];
245         break;
246       }
247     }
248     if (nmonSort != -1)
249     {
250       std::vector<Node> injections;
251       TypeNode base_tn = getOrCreateTypeForId(nmonSort, tss.first);
252       for (unsigned i = 0; i < nsorts; i++)
253       {
254         if (tss.second[i] != nmonSort)
255         {
256           TypeNode new_tn = getOrCreateTypeForId(tss.second[i], tss.first);
257           // make injection to nmonSort
258           Node a1 = mkInjection(new_tn, base_tn);
259           injections.push_back(a1);
260           if (d_non_monotonic_sorts.find(tss.second[i])
261               != d_non_monotonic_sorts.end())
262           {
263             // also must make injection from nmonSort to this
264             Node a2 = mkInjection(base_tn, new_tn);
265             injections.push_back(a2);
266           }
267         }
268       }
269       if (Trace.isOn("sort-inference-rewrite"))
270       {
271         Trace("sort-inference-rewrite")
272             << "Add the following injections for " << tss.first
273             << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
274         for (const Node& i : injections)
275         {
276           Trace("sort-inference-rewrite") << "   " << i << std::endl;
277         }
278       }
279       new_asserts.insert(
280           new_asserts.end(), injections.begin(), injections.end());
281     }
282   }
283   Trace("sort-inference-proc") << "...done" << std::endl;
284   // no sub-sort information is stored
285   reset();
286   Trace("sort-inference-debug") << "Finished sort inference" << std::endl;
287 }
288 
computeMonotonicity(const std::vector<Node> & assertions)289 void SortInference::computeMonotonicity(const std::vector<Node>& assertions)
290 {
291   std::map<Node, std::map<int, bool> > visitedmt;
292   Trace("sort-inference-proc")
293       << "Calculating monotonicty for types..." << std::endl;
294   for (const Node& a : assertions)
295   {
296     Trace("sort-inference-debug")
297         << "Process type monotonicity for " << a << std::endl;
298     std::map<Node, Node> var_bound;
299     processMonotonic(a, true, true, var_bound, visitedmt, true);
300   }
301   Trace("sort-inference-proc") << "...done" << std::endl;
302 }
303 
setEqual(int t1,int t2)304 void SortInference::setEqual( int t1, int t2 ){
305   if( t1!=t2 ){
306     int rt1 = d_type_union_find.getRepresentative( t1 );
307     int rt2 = d_type_union_find.getRepresentative( t2 );
308     if( rt1!=rt2 ){
309       Trace("sort-inference-debug") << "Set equal : ";
310       printSort( "sort-inference-debug", rt1 );
311       Trace("sort-inference-debug") << " ";
312       printSort( "sort-inference-debug", rt2 );
313       Trace("sort-inference-debug") << std::endl;
314       /*
315       d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() );
316       d_type_eq_class[rt2].clear();
317       Trace("sort-inference-debug") << "EqClass : { ";
318       for( int i=0; i<(int)d_type_eq_class[rt1].size(); i++ ){
319         Trace("sort-inference-debug") << d_type_eq_class[rt1][i] << ", ";
320       }
321       Trace("sort-inference-debug") << "}" << std::endl;
322       */
323       if( rt2>rt1 ){
324         //swap
325         int swap = rt1;
326         rt1 = rt2;
327         rt2 = swap;
328       }
329       std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
330       if( it1!=d_type_types.end() ){
331         if( d_type_types.find( rt2 )==d_type_types.end() ){
332           d_type_types[rt2] = it1->second;
333           d_type_types.erase( rt1 );
334         }else{
335           Trace("sort-inference-debug") << "...fail : associated with types " << d_type_types[rt1] << " and " << d_type_types[rt2] << std::endl;
336           return;
337         }
338       }
339       d_type_union_find.d_eqc[rt1] = rt2;
340     }
341   }
342 }
343 
getIdForType(TypeNode tn)344 int SortInference::getIdForType( TypeNode tn ){
345   //register the return type
346   std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
347   if( it==d_id_for_types.end() ){
348     int sc = d_sortCount;
349     d_type_types[d_sortCount] = tn;
350     d_id_for_types[tn] = d_sortCount;
351     d_sortCount++;
352     return sc;
353   }else{
354     return it->second;
355   }
356 }
357 
// Assign a sort id to term n (and, recursively, to its subterms). Along the
// way, record in d_type_union_find (via setEqual) the equalities between sort
// ids that n imposes.
//   var_bound : maps each variable bound by an enclosing quantified formula
//               that is currently being processed to that formula.
//   visited   : cache from already-processed terms to their sort id. The
//               children of a quantified formula are processed with a fresh
//               cache (new_visited) that is local to this call.
// Returns the sort id of n.
int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited ){
  std::map< Node, int >::iterator itv = visited.find( n );
  if( itv!=visited.end() ){
    return itv->second;
  }else{
    //add to variable bindings
    bool use_new_visited = false;
    std::map< Node, int > new_visited;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      if( d_var_types.find( n )!=d_var_types.end() ){
        // The variables of this quantified formula already have sort ids.
        // Return the id of n's own type without re-processing the body; note
        // that this result is not added to visited.
        return getIdForType( n.getType() );
      }else{
        //apply sort inference to quantified variables
        for( size_t i=0; i<n[0].getNumChildren(); i++ ){
          TypeNode nitn = n[0][i].getType();
          if( !nitn.isSort() )
          {
            // If the variable is of an interpreted sort, we assume that
            // the sort of the variable will stay the same sort.
            d_var_types[n][n[0][i]] = getIdForType( nitn );
          }
          else
          {
            // If it is of an uninterpreted sort, infer subsorts.
            d_var_types[n][n[0][i]] = d_sortCount;
            d_sortCount++;
          }
          var_bound[ n[0][i] ] = n;
        }
      }
      // The results for the body depend on the bindings above, so the body
      // is not cached in the caller's visited map.
      use_new_visited = true;
    }

    //process children
    std::vector< Node > children;
    std::vector< int > child_types;
    for( size_t i=0; i<n.getNumChildren(); i++ ){
      bool processChild = true;
      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
        // Never process the bound variable list (child 0). If user patterns
        // are ignored, only process the body (child 1); otherwise process
        // every child from index 1 onward.
        processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
      }
      if( processChild ){
        children.push_back( n[i] );
        child_types.push_back( process( n[i], var_bound, use_new_visited ? new_visited : visited ) );
      }
    }

    //remove from variable bindings
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //erase from variable bound
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        var_bound.erase( n[0][i] );
      }
    }
    Trace("sort-inference-debug") << "...Process " << n << std::endl;

    int retType;
    if( n.getKind()==kind::EQUAL && !n[0].getType().isBoolean() ){
      Trace("sort-inference-debug") << "For equality " << n << ", set equal types from : " << n[0].getType() << " " << n[1].getType() << std::endl;
      //if original types are mixed (e.g. Int/Real), don't commit type equality in either direction
      if( n[0].getType()!=n[1].getType() ){
        //for now, assume the original types
        for( unsigned i=0; i<2; i++ ){
          int ct = getIdForType( n[i].getType() );
          setEqual( child_types[i], ct );
        }
      }else{
        //we only require that the left and right hand side must be equal
        setEqual( child_types[0], child_types[1] );
      }
      // Remember the sort id of the left side. simplifyNode uses it as the
      // type context for the equality's children.
      d_equality_types[n] = child_types[0];
      retType = getIdForType( n.getType() );
    }else if( n.getKind()==kind::APPLY_UF ){
      Node op = n.getOperator();
      TypeNode tn_op = op.getType();
      if( d_op_return_types.find( op )==d_op_return_types.end() ){
        if( n.getType().isBoolean() ){
          //use booleans
          d_op_return_types[op] = getIdForType( n.getType() );
        }else{
          //assign arbitrary sort for return type
          d_op_return_types[op] = d_sortCount;
          d_sortCount++;
        }
        // d_type_eq_class[d_sortCount].push_back( op );
        // assign arbitrary sort for argument types
        for( size_t i=0; i<n.getNumChildren(); i++ ){
          d_op_arg_types[op].push_back(d_sortCount);
          d_sortCount++;
        }
      }
      for( size_t i=0; i<n.getNumChildren(); i++ ){
        //the argument of the operator must match the return type of the subterm
        if( n[i].getType()!=tn_op[i] ){
          //if type mismatch, assume original types
          Trace("sort-inference-debug") << "Argument " << i << " of " << op << " " << n[i] << " has type " << n[i].getType();
          Trace("sort-inference-debug") << ", while operator arg has type " << tn_op[i] << std::endl;
          int ct1 = getIdForType( n[i].getType() );
          setEqual( child_types[i], ct1 );
          int ct2 = getIdForType( tn_op[i] );
          setEqual( d_op_arg_types[op][i], ct2 );
        }else{
          setEqual( child_types[i], d_op_arg_types[op][i] );
        }
      }
      //return type is the return type
      retType = d_op_return_types[op];
    }else{
      std::map< Node, Node >::iterator it = var_bound.find( n );
      if( it!=var_bound.end() ){
        Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
        //the return type was specified while binding
        retType = d_var_types[it->second][n];
      }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
        Trace("sort-inference-debug") << n << " is a variable." << std::endl;
        if( d_op_return_types.find( n )==d_op_return_types.end() ){
          //assign arbitrary sort
          d_op_return_types[n] = d_sortCount;
          d_sortCount++;
          // d_type_eq_class[d_sortCount].push_back( n );
        }
        retType = d_op_return_types[n];
      }else if( n.isConst() ){
        Trace("sort-inference-debug") << n << " is a constant." << std::endl;
        //can be any type we want
        retType = d_sortCount;
        d_sortCount++;
      }else{
        Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
        //it is an interpreted term
        for( size_t i=0; i<children.size(); i++ ){
          Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
          //must enforce the actual type of the operator on the children
          int ct = getIdForType( children[i].getType() );
          setEqual( child_types[i], ct );
        }
        //return type must be the actual return type
        retType = getIdForType( n.getType() );
      }
    }
    Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
    printSort("sort-inference-debug", retType );
    Trace("sort-inference-debug") << std::endl;
    visited[n] = retType;
    return retType;
  }
}
505 
processMonotonic(Node n,bool pol,bool hasPol,std::map<Node,Node> & var_bound,std::map<Node,std::map<int,bool>> & visited,bool typeMode)506 void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode ) {
507   int pindex = hasPol ? ( pol ? 1 : -1 ) : 0;
508   if( visited[n].find( pindex )==visited[n].end() ){
509     visited[n][pindex] = true;
510     Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
511     if( n.getKind()==kind::FORALL ){
512       //only consider variables universally if it is possible this quantified formula is asserted positively
513       if( !hasPol || pol ){
514         for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
515           var_bound[n[0][i]] = n;
516         }
517       }
518       processMonotonic( n[1], pol, hasPol, var_bound, visited, typeMode );
519       if( !hasPol || pol ){
520         for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
521           var_bound.erase( n[0][i] );
522         }
523       }
524       return;
525     }else if( n.getKind()==kind::EQUAL ){
526       if( !hasPol || pol ){
527         for( unsigned i=0; i<2; i++ ){
528           if( var_bound.find( n[i] )!=var_bound.end() ){
529             if( !typeMode ){
530               int sid = getSortId( var_bound[n[i]], n[i] );
531               d_non_monotonic_sorts[sid] = true;
532             }else{
533               d_non_monotonic_sorts_orig[n[i].getType()] = true;
534             }
535             break;
536           }
537         }
538       }
539     }
540     for( unsigned i=0; i<n.getNumChildren(); i++ ){
541       bool npol;
542       bool nhasPol;
543       theory::QuantPhaseReq::getPolarity( n, i, hasPol, pol, nhasPol, npol );
544       processMonotonic( n[i], npol, nhasPol, var_bound, visited, typeMode );
545     }
546   }
547 }
548 
549 
getOrCreateTypeForId(int t,TypeNode pref)550 TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
551   int rt = d_type_union_find.getRepresentative( t );
552   if( d_type_types.find( rt )!=d_type_types.end() ){
553     return d_type_types[rt];
554   }else{
555     TypeNode retType;
556     // See if we can assign pref. This is an optimization for reusing an
557     // uninterpreted sort as the first subsort, so that fewer symbols needed
558     // to be rewritten in the sort-inferred signature. Notice we only assign
559     // pref here if it is an uninterpreted sort.
560     if (!pref.isNull() && d_id_for_types.find(pref) == d_id_for_types.end()
561         && pref.isSort())
562     {
563       retType = pref;
564     }else{
565       //must create new type
566       std::stringstream ss;
567       ss << "it_" << t << "_" << pref;
568       retType = NodeManager::currentNM()->mkSort( ss.str() );
569     }
570     Trace("sort-inference") << "-> Make type " << retType << " to correspond to ";
571     printSort("sort-inference", t );
572     Trace("sort-inference") << std::endl;
573     d_id_for_types[ retType ] = rt;
574     d_type_types[ rt ] = retType;
575     return retType;
576   }
577 }
578 
getTypeForId(int t)579 TypeNode SortInference::getTypeForId( int t ){
580   int rt = d_type_union_find.getRepresentative( t );
581   if( d_type_types.find( rt )!=d_type_types.end() ){
582     return d_type_types[rt];
583   }else{
584     return TypeNode::null();
585   }
586 }
587 
getNewSymbol(Node old,TypeNode tn)588 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
589   // if no sort was inferred for this node, return original
590   if( tn.isNull() || tn.isComparableTo( old.getType() ) ){
591     return old;
592   }else if( old.isConst() ){
593     //must make constant of type tn
594     if( d_const_map[tn].find( old )==d_const_map[tn].end() ){
595       std::stringstream ss;
596       ss << "ic_" << tn << "_" << old;
597       d_const_map[tn][ old ] = NodeManager::currentNM()->mkSkolem( ss.str(), tn, "constant created during sort inference" );  //use mkConst???
598     }
599     return d_const_map[tn][ old ];
600   }else if( old.getKind()==kind::BOUND_VARIABLE ){
601     std::stringstream ss;
602     ss << "b_" << old;
603     return NodeManager::currentNM()->mkBoundVar( ss.str(), tn );
604   }else{
605     std::stringstream ss;
606     ss << "i_" << old;
607     return NodeManager::currentNM()->mkSkolem( ss.str(), tn, "created during sort inference" );
608   }
609 }
610 
// Rebuild term n over the sort-inferred signature.
//   var_bound       : maps each variable bound by an enclosing quantifier to
//                     its replacement variable.
//   tnn             : the type expected for n by its context (null if none).
//                     It is only used when n is a constant, which then becomes
//                     a symbol of that type via getNewSymbol.
//   model_replace_f : receives old -> new for each function symbol replaced by
//                     a fresh operator.
//   visited         : cache keyed by (term, type context). The children of a
//                     quantified formula use a fresh cache local to this call.
// Returns the rebuilt term.
Node SortInference::simplifyNode(
    Node n,
    std::map<Node, Node>& var_bound,
    TypeNode tnn,
    std::map<Node, Node>& model_replace_f,
    std::map<Node, std::map<TypeNode, Node> >& visited)
{
  std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn );
  if( itv!=visited[n].end() ){
    return itv->second;
  }else{
    Trace("sort-inference-debug2") << "Simplify " << n << ", type context=" << tnn << std::endl;
    std::vector< Node > children;
    std::map< Node, std::map< TypeNode, Node > > new_visited;
    bool use_new_visited = false;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //recreate based on types of variables
      std::vector< Node > new_children;
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
        Node v = getNewSymbol( n[0][i], tn );
        Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
        new_children.push_back( v );
        var_bound[ n[0][i] ] = v;
      }
      // children[0] becomes the new bound variable list.
      children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
      use_new_visited = true;
    }

    //process children
    if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
      children.push_back( n.getOperator() );
    }
    Node op;
    if( n.hasOperator() ){
      op = n.getOperator();
    }
    bool childChanged = false;
    // Type context handed to children. It is declared outside the loop on
    // purpose: for EQUAL it is set at i==0 and reused for i==1, so both sides
    // are simplified in the same context. For kinds other than APPLY_UF and
    // EQUAL it stays null.
    TypeNode tnnc;
    for( size_t i=0; i<n.getNumChildren(); i++ ){
      bool processChild = true;
      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
        // Same child selection as in process(): skip the variable list, and
        // skip children after the body when user patterns are ignored.
        processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
      }
      if( processChild ){
        if( n.getKind()==kind::APPLY_UF ){
          Assert( d_op_arg_types.find( op )!=d_op_arg_types.end() );
          tnnc = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
          Assert( !tnnc.isNull() );
        }else if( n.getKind()==kind::EQUAL && i==0 ){
          Assert( d_equality_types.find( n )!=d_equality_types.end() );
          tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() );
          Assert( !tnnc.isNull() );
        }
        Node nc = simplifyNode(n[i],
                               var_bound,
                               tnnc,
                               model_replace_f,
                               use_new_visited ? new_visited : visited);
        Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
        children.push_back( nc );
        childChanged = childChanged || nc!=n[i];
      }
    }

    //remove from variable bindings
    Node ret;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //erase from variable bound
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
        var_bound.erase( n[0][i] );
      }
      ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
    }else if( n.getKind()==kind::EQUAL ){
      // Sanity check: the two rebuilt sides must have related types.
      TypeNode tn1 = children[0].getType();
      TypeNode tn2 = children[1].getType();
      if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){
        Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
        Trace("sort-inference-warn") << "  Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
        Assert( false );
      }
      ret = NodeManager::currentNM()->mkNode( kind::EQUAL, children );
    }else if( n.getKind()==kind::APPLY_UF ){
      if( d_symbol_map.find( op )==d_symbol_map.end() ){
        //make the new operator if necessary
        bool opChanged = false;
        std::vector< TypeNode > argTypes;
        for( size_t i=0; i<n.getNumChildren(); i++ ){
          TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
          argTypes.push_back( tn );
          if( tn!=n[i].getType() ){
            opChanged = true;
          }
        }
        TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
        if( retType!=n.getType() ){
          opChanged = true;
        }
        if( opChanged ){
          std::stringstream ss;
          ss << "io_" << op;
          TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
          d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
          Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
          model_replace_f[op] = d_symbol_map[op];
        }else{
          d_symbol_map[op] = op;
        }
      }
      // children[0] is the operator, pushed above because APPLY_UF is
      // parameterized.
      children[0] = d_symbol_map[op];
      // make sure all children have been given proper types
      for (size_t i = 0, size = n.getNumChildren(); i < size; i++)
      {
        TypeNode tn = children[i+1].getType();
        TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
        if (!tn.isSubtypeOf(tna))
        {
          Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
          Assert( false );
        }
      }
      ret = NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
    }else{
      std::map< Node, Node >::iterator it = var_bound.find( n );
      if( it!=var_bound.end() ){
        ret = it->second;
      }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
        if( d_symbol_map.find( n )==d_symbol_map.end() ){
          TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
          d_symbol_map[n] = getNewSymbol( n, tn );
        }
        ret = d_symbol_map[n];
      }else if( n.isConst() ){
        //type is determined by context
        ret = getNewSymbol( n, tnn );
      }else if( childChanged ){
        ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
      }else{
        ret = n;
      }
    }
    visited[n][tnn] = ret;
    return ret;
  }
}
757 
mkInjection(TypeNode tn1,TypeNode tn2)758 Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) {
759   std::vector< TypeNode > tns;
760   tns.push_back( tn1 );
761   TypeNode typ = NodeManager::currentNM()->mkFunctionType( tns, tn2 );
762   Node f = NodeManager::currentNM()->mkSkolem( "inj", typ, "injection for monotonicity constraint" );
763   Trace("sort-inference") << "-> Make injection " << f << " from " << tn1 << " to " << tn2 << std::endl;
764   Node v1 = NodeManager::currentNM()->mkBoundVar( "?x", tn1 );
765   Node v2 = NodeManager::currentNM()->mkBoundVar( "?y", tn1 );
766   Node ret = NodeManager::currentNM()->mkNode( kind::FORALL,
767                NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, v1, v2 ),
768                NodeManager::currentNM()->mkNode( kind::OR,
769                  NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v1 ).eqNode( NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v2 ) ).negate(),
770                  v1.eqNode( v2 ) ) );
771   ret = theory::Rewriter::rewrite( ret );
772   return ret;
773 }
774 
getSortId(Node n)775 int SortInference::getSortId( Node n ) {
776   Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n;
777   if( d_op_return_types.find( op )!=d_op_return_types.end() ){
778     return d_type_union_find.getRepresentative( d_op_return_types[op] );
779   }else{
780     return 0;
781   }
782 }
783 
getSortId(Node f,Node v)784 int SortInference::getSortId( Node f, Node v ) {
785   if( d_var_types.find( f )!=d_var_types.end() ){
786     return d_type_union_find.getRepresentative( d_var_types[f][v] );
787   }else{
788     return 0;
789   }
790 }
791 
setSkolemVar(Node f,Node v,Node sk)792 void SortInference::setSkolemVar( Node f, Node v, Node sk ){
793   Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl;
794   if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
795     //calculate the sort for variables if not done so already
796     std::map< Node, Node > var_bound;
797     std::map< Node, int > visited;
798     process( f, var_bound, visited );
799   }
800   d_op_return_types[sk] = getSortId( f, v );
801   Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
802 }
803 
isWellSortedFormula(Node n)804 bool SortInference::isWellSortedFormula( Node n ) {
805   if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
806     for( unsigned i=0; i<n.getNumChildren(); i++ ){
807       if( !isWellSortedFormula( n[i] ) ){
808         return false;
809       }
810     }
811     return true;
812   }else{
813     return isWellSorted( n );
814   }
815 }
816 
isWellSorted(Node n)817 bool SortInference::isWellSorted( Node n ) {
818   if( getSortId( n )==0 ){
819     return false;
820   }else{
821     if( n.getKind()==kind::APPLY_UF ){
822       for( unsigned i=0; i<n.getNumChildren(); i++ ){
823         int s1 = getSortId( n[i] );
824         int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
825         if( s1!=s2 ){
826           return false;
827         }
828         if( !isWellSorted( n[i] ) ){
829           return false;
830         }
831       }
832     }
833     return true;
834   }
835 }
836 
getSortConstraints(Node n,UnionFind & uf)837 void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
838   if( n.getKind()==kind::APPLY_UF ){
839     for( unsigned i=0; i<n.getNumChildren(); i++ ){
840       getSortConstraints( n[i], uf );
841       uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
842     }
843   }
844 }
845 
isMonotonic(TypeNode tn)846 bool SortInference::isMonotonic( TypeNode tn ) {
847   Assert( tn.isSort() );
848   return d_non_monotonic_sorts_orig.find( tn )==d_non_monotonic_sorts_orig.end();
849 }
850 
851 }/* CVC4 namespace */
852