1 /*********************                                                        */
2 /*! \file sygus_unif_strat.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Andrew Reynolds, Haniel Barbosa
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** \brief Implementation of sygus_unif_strat
13  **/
14 
15 #include "theory/quantifiers/sygus/sygus_unif_strat.h"
16 
17 #include "theory/datatypes/datatypes_rewriter.h"
18 #include "theory/quantifiers/sygus/sygus_unif.h"
19 #include "theory/quantifiers/sygus/term_database_sygus.h"
20 #include "theory/quantifiers/term_util.h"
21 
22 using namespace std;
23 using namespace CVC4::kind;
24 
25 namespace CVC4 {
26 namespace theory {
27 namespace quantifiers {
28 
operator <<(std::ostream & os,EnumRole r)29 std::ostream& operator<<(std::ostream& os, EnumRole r)
30 {
31   switch (r)
32   {
33     case enum_invalid: os << "INVALID"; break;
34     case enum_io: os << "IO"; break;
35     case enum_ite_condition: os << "CONDITION"; break;
36     case enum_concat_term: os << "CTERM"; break;
37     default: os << "enum_" << static_cast<unsigned>(r); break;
38   }
39   return os;
40 }
41 
operator <<(std::ostream & os,NodeRole r)42 std::ostream& operator<<(std::ostream& os, NodeRole r)
43 {
44   switch (r)
45   {
46     case role_equal: os << "equal"; break;
47     case role_string_prefix: os << "string_prefix"; break;
48     case role_string_suffix: os << "string_suffix"; break;
49     case role_ite_condition: os << "ite_condition"; break;
50     default: os << "role_" << static_cast<unsigned>(r); break;
51   }
52   return os;
53 }
54 
getEnumeratorRoleForNodeRole(NodeRole r)55 EnumRole getEnumeratorRoleForNodeRole(NodeRole r)
56 {
57   switch (r)
58   {
59     case role_equal: return enum_io; break;
60     case role_string_prefix: return enum_concat_term; break;
61     case role_string_suffix: return enum_concat_term; break;
62     case role_ite_condition: return enum_ite_condition; break;
63     default: break;
64   }
65   return enum_invalid;
66 }
67 
operator <<(std::ostream & os,StrategyType st)68 std::ostream& operator<<(std::ostream& os, StrategyType st)
69 {
70   switch (st)
71   {
72     case strat_ITE: os << "ITE"; break;
73     case strat_CONCAT_PREFIX: os << "CONCAT_PREFIX"; break;
74     case strat_CONCAT_SUFFIX: os << "CONCAT_SUFFIX"; break;
75     case strat_ID: os << "ID"; break;
76     default: os << "strat_" << static_cast<unsigned>(st); break;
77   }
78   return os;
79 }
80 
initialize(QuantifiersEngine * qe,Node f,std::vector<Node> & enums)81 void SygusUnifStrategy::initialize(QuantifiersEngine* qe,
82                                    Node f,
83                                    std::vector<Node>& enums)
84 {
85   Assert(d_candidate.isNull());
86   d_candidate = f;
87   d_root = f.getType();
88   d_qe = qe;
89 
90   // collect the enumerator types and form the strategy
91   buildStrategyGraph(d_root, role_equal);
92   // add the enumerators
93   enums.insert(enums.end(), d_esym_list.begin(), d_esym_list.end());
94   // finish the initialization of the strategy
95   // this computes if each node is conditional
96   std::map<Node, std::map<NodeRole, bool> > visited;
97   finishInit(getRootEnumerator(), role_equal, visited, false);
98 }
99 
initializeType(TypeNode tn)100 void SygusUnifStrategy::initializeType(TypeNode tn)
101 {
102   d_tinfo[tn].d_this_type = tn;
103 }
104 
getRootEnumerator() const105 Node SygusUnifStrategy::getRootEnumerator() const
106 {
107   std::map<TypeNode, EnumTypeInfo>::const_iterator itt = d_tinfo.find(d_root);
108   Assert(itt != d_tinfo.end());
109   std::map<EnumRole, Node>::const_iterator it =
110       itt->second.d_enum.find(enum_io);
111   Assert(it != itt->second.d_enum.end());
112   return it->second;
113 }
114 
getEnumInfo(Node e)115 EnumInfo& SygusUnifStrategy::getEnumInfo(Node e)
116 {
117   std::map<Node, EnumInfo>::iterator it = d_einfo.find(e);
118   Assert(it != d_einfo.end());
119   return it->second;
120 }
121 
getEnumTypeInfo(TypeNode tn)122 EnumTypeInfo& SygusUnifStrategy::getEnumTypeInfo(TypeNode tn)
123 {
124   std::map<TypeNode, EnumTypeInfo>::iterator it = d_tinfo.find(tn);
125   Assert(it != d_tinfo.end());
126   return it->second;
127 }
128 // ----------------------------- establishing enumeration types
129 
registerStrategyPoint(Node et,TypeNode tn,EnumRole enum_role,bool inSearch)130 void SygusUnifStrategy::registerStrategyPoint(Node et,
131                                            TypeNode tn,
132                                            EnumRole enum_role,
133                                            bool inSearch)
134 {
135   if (d_einfo.find(et) == d_einfo.end())
136   {
137     Trace("sygus-unif-debug")
138         << "...register " << et << " for "
139         << static_cast<DatatypeType>(tn.toType()).getDatatype().getName();
140     Trace("sygus-unif-debug") << ", role = " << enum_role
141                               << ", in search = " << inSearch << std::endl;
142     d_einfo[et].initialize(enum_role);
143     // if we are actually enumerating this (could be a compound node in the
144     // strategy)
145     if (inSearch)
146     {
147       std::map<TypeNode, Node>::iterator itn = d_master_enum.find(tn);
148       if (itn == d_master_enum.end())
149       {
150         // use this for the search
151         d_master_enum[tn] = et;
152         d_esym_list.push_back(et);
153         d_einfo[et].d_enum_slave.push_back(et);
154       }
155       else
156       {
157         Trace("sygus-unif-debug") << "Make " << et << " a slave of "
158                                   << itn->second << std::endl;
159         d_einfo[itn->second].d_enum_slave.push_back(et);
160       }
161     }
162   }
163 }
164 
buildStrategyGraph(TypeNode tn,NodeRole nrole)165 void SygusUnifStrategy::buildStrategyGraph(TypeNode tn, NodeRole nrole)
166 {
167   NodeManager* nm = NodeManager::currentNM();
168   if (d_tinfo.find(tn) == d_tinfo.end())
169   {
170     // register type
171     Trace("sygus-unif") << "Register enumerating type : " << tn << std::endl;
172     initializeType(tn);
173   }
174   EnumTypeInfo& eti = d_tinfo[tn];
175   std::map<NodeRole, StrategyNode>::iterator itsn = eti.d_snodes.find(nrole);
176   if (itsn != eti.d_snodes.end())
177   {
178     // already initialized
179     return;
180   }
181   StrategyNode& snode = eti.d_snodes[nrole];
182 
183   // get the enumerator for this
184   EnumRole erole = getEnumeratorRoleForNodeRole(nrole);
185 
186   Node ee;
187   std::map<EnumRole, Node>::iterator iten = eti.d_enum.find(erole);
188   if (iten == eti.d_enum.end())
189   {
190     ee = nm->mkSkolem("ee", tn);
191     eti.d_enum[erole] = ee;
192     Trace("sygus-unif-debug")
193         << "...enumerator " << ee << " for "
194         << static_cast<DatatypeType>(tn.toType()).getDatatype().getName()
195         << ", role = " << erole << std::endl;
196   }
197   else
198   {
199     ee = iten->second;
200   }
201 
202   // roles that we do not recurse on
203   if (nrole == role_ite_condition)
204   {
205     Trace("sygus-unif-debug") << "...this register (non-io)" << std::endl;
206     registerStrategyPoint(ee, tn, erole, true);
207     return;
208   }
209 
210   // look at information on how we will construct solutions for this type
211   // we know this is a sygus datatype since it is either the top-level type
212   // in the strategy graph, or was recursed by a strategy we inferred.
213   Assert(tn.isDatatype());
214   const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype();
215   Assert(dt.isSygus());
216 
217   std::map<Node, std::vector<StrategyType> > cop_to_strat;
218   std::map<Node, unsigned> cop_to_cindex;
219   std::map<Node, std::map<unsigned, Node> > cop_to_child_templ;
220   std::map<Node, std::map<unsigned, Node> > cop_to_child_templ_arg;
221   std::map<Node, std::vector<unsigned> > cop_to_carg_list;
222   std::map<Node, std::vector<TypeNode> > cop_to_child_types;
223   std::map<Node, std::vector<Node> > cop_to_sks;
224 
225   // whether we will enumerate the current type
226   bool search_this = false;
227   for (unsigned j = 0, ncons = dt.getNumConstructors(); j < ncons; j++)
228   {
229     Node cop = Node::fromExpr(dt[j].getConstructor());
230     Node op = Node::fromExpr(dt[j].getSygusOp());
231     Trace("sygus-unif-debug") << "--- Infer strategy from " << cop
232                               << " with sygus op " << op << "..." << std::endl;
233 
234     // expand the evaluation to see if this constuctor induces a strategy
235     std::vector<Node> utchildren;
236     utchildren.push_back(cop);
237     std::vector<Node> sks;
238     std::vector<TypeNode> sktns;
239     for (unsigned k = 0, nargs = dt[j].getNumArgs(); k < nargs; k++)
240     {
241       Type t = dt[j][k].getRangeType();
242       TypeNode ttn = TypeNode::fromType(t);
243       Node kv = nm->mkSkolem("ut", ttn);
244       sks.push_back(kv);
245       cop_to_sks[cop].push_back(kv);
246       sktns.push_back(ttn);
247       utchildren.push_back(kv);
248     }
249     Node ut = nm->mkNode(APPLY_CONSTRUCTOR, utchildren);
250     std::vector<Node> echildren;
251     echildren.push_back(ut);
252     Node sbvl = Node::fromExpr(dt.getSygusVarList());
253     for (const Node& sbv : sbvl)
254     {
255       echildren.push_back(sbv);
256     }
257     Node eut = nm->mkNode(DT_SYGUS_EVAL, echildren);
258     Trace("sygus-unif-debug2") << "  Test evaluation of " << eut << "..."
259                                << std::endl;
260     eut = d_qe->getTermDatabaseSygus()->unfold(eut);
261     Trace("sygus-unif-debug2") << "  ...got " << eut;
262     Trace("sygus-unif-debug2") << ", type : " << eut.getType() << std::endl;
263 
264     // candidate strategy
265     if (eut.getKind() == ITE)
266     {
267       cop_to_strat[cop].push_back(strat_ITE);
268     }
269     else if (eut.getKind() == STRING_CONCAT)
270     {
271       if (nrole != role_string_suffix)
272       {
273         cop_to_strat[cop].push_back(strat_CONCAT_PREFIX);
274       }
275       if (nrole != role_string_prefix)
276       {
277         cop_to_strat[cop].push_back(strat_CONCAT_SUFFIX);
278       }
279     }
280     else if (dt[j].isSygusIdFunc())
281     {
282       cop_to_strat[cop].push_back(strat_ID);
283     }
284 
285     // the kinds for which there is a strategy
286     if (cop_to_strat.find(cop) != cop_to_strat.end())
287     {
288       // infer an injection from the arguments of the datatype
289       std::map<unsigned, unsigned> templ_injection;
290       std::vector<Node> vs;
291       std::vector<Node> ss;
292       std::map<Node, unsigned> templ_var_index;
293       for (unsigned k = 0, sksize = sks.size(); k < sksize; k++)
294       {
295         Assert(sks[k].getType().isDatatype());
296         echildren[0] = sks[k];
297         Trace("sygus-unif-debug2") << "...set eval dt to " << sks[k]
298                                    << std::endl;
299         Node esk = nm->mkNode(DT_SYGUS_EVAL, echildren);
300         vs.push_back(esk);
301         Node tvar = nm->mkSkolem("templ", esk.getType());
302         templ_var_index[tvar] = k;
303         Trace("sygus-unif-debug2") << "* template inference : looking for "
304                                    << tvar << " for arg " << k << std::endl;
305         ss.push_back(tvar);
306         Trace("sygus-unif-debug2") << "* substitute : " << esk << " -> " << tvar
307                                    << std::endl;
308       }
309       eut = eut.substitute(vs.begin(), vs.end(), ss.begin(), ss.end());
310       Trace("sygus-unif-debug2") << "Constructor " << j << ", base term is "
311                                  << eut << std::endl;
312       std::map<unsigned, Node> test_args;
313       if (dt[j].isSygusIdFunc())
314       {
315         test_args[0] = eut;
316       }
317       else
318       {
319         for (unsigned k = 0, size = eut.getNumChildren(); k < size; k++)
320         {
321           test_args[k] = eut[k];
322         }
323       }
324 
325       // TODO : prefix grouping prefix/suffix
326       bool isAssoc = TermUtil::isAssoc(eut.getKind());
327       Trace("sygus-unif-debug2") << eut.getKind() << " isAssoc = " << isAssoc
328                                  << std::endl;
329       std::map<unsigned, std::vector<unsigned> > assoc_combine;
330       std::vector<unsigned> assoc_waiting;
331       int assoc_last_valid_index = -1;
332       for (std::pair<const unsigned, Node>& ta : test_args)
333       {
334         unsigned k = ta.first;
335         Node eut_c = ta.second;
336         // success if we can find a injection from args to sygus args
337         if (!inferTemplate(k, eut_c, templ_var_index, templ_injection))
338         {
339           Trace("sygus-unif-debug")
340               << "...fail: could not find injection (range)." << std::endl;
341           cop_to_strat.erase(cop);
342           break;
343         }
344         std::map<unsigned, unsigned>::iterator itti = templ_injection.find(k);
345         if (itti != templ_injection.end())
346         {
347           // if associative, combine arguments if it is the same variable
348           if (isAssoc && assoc_last_valid_index >= 0
349               && itti->second == templ_injection[assoc_last_valid_index])
350           {
351             templ_injection.erase(k);
352             assoc_combine[assoc_last_valid_index].push_back(k);
353           }
354           else
355           {
356             assoc_last_valid_index = (int)k;
357             if (!assoc_waiting.empty())
358             {
359               assoc_combine[k].insert(assoc_combine[k].end(),
360                                       assoc_waiting.begin(),
361                                       assoc_waiting.end());
362               assoc_waiting.clear();
363             }
364             assoc_combine[k].push_back(k);
365           }
366         }
367         else
368         {
369           // a ground argument
370           if (!isAssoc)
371           {
372             Trace("sygus-unif-debug")
373                 << "...fail: could not find injection (functional)."
374                 << std::endl;
375             cop_to_strat.erase(cop);
376             break;
377           }
378           else
379           {
380             if (assoc_last_valid_index >= 0)
381             {
382               assoc_combine[assoc_last_valid_index].push_back(k);
383             }
384             else
385             {
386               assoc_waiting.push_back(k);
387             }
388           }
389         }
390       }
391       if (cop_to_strat.find(cop) != cop_to_strat.end())
392       {
393         // construct the templates
394         if (!assoc_waiting.empty())
395         {
396           // could not find a way to fit some arguments into injection
397           cop_to_strat.erase(cop);
398         }
399         else
400         {
401           for (std::pair<const unsigned, Node>& ta : test_args)
402           {
403             unsigned k = ta.first;
404             Trace("sygus-unif-debug2") << "- processing argument " << k << "..."
405                                        << std::endl;
406             if (templ_injection.find(k) != templ_injection.end())
407             {
408               unsigned sk_index = templ_injection[k];
409               if (std::find(cop_to_carg_list[cop].begin(),
410                             cop_to_carg_list[cop].end(),
411                             sk_index)
412                   == cop_to_carg_list[cop].end())
413               {
414                 cop_to_carg_list[cop].push_back(sk_index);
415               }
416               else
417               {
418                 Trace("sygus-unif-debug") << "...fail: duplicate argument used"
419                                           << std::endl;
420                 cop_to_strat.erase(cop);
421                 break;
422               }
423               // also store the template information, if necessary
424               Node teut;
425               if (isAssoc)
426               {
427                 std::vector<unsigned>& ac = assoc_combine[k];
428                 Assert(!ac.empty());
429                 std::vector<Node> children;
430                 for (unsigned ack = 0, size_ac = ac.size(); ack < size_ac;
431                      ack++)
432                 {
433                   children.push_back(eut[ac[ack]]);
434                 }
435                 teut = children.size() == 1
436                            ? children[0]
437                            : nm->mkNode(eut.getKind(), children);
438                 teut = Rewriter::rewrite(teut);
439               }
440               else
441               {
442                 teut = ta.second;
443               }
444 
445               if (!teut.isVar())
446               {
447                 cop_to_child_templ[cop][k] = teut;
448                 cop_to_child_templ_arg[cop][k] = ss[sk_index];
449                 Trace("sygus-unif-debug")
450                     << "  Arg " << k << " (template : " << teut << " arg "
451                     << ss[sk_index] << "), index " << sk_index << std::endl;
452               }
453               else
454               {
455                 Trace("sygus-unif-debug") << "  Arg " << k << ", index "
456                                           << sk_index << std::endl;
457                 Assert(teut == ss[sk_index]);
458               }
459             }
460             else
461             {
462               Assert(isAssoc);
463             }
464           }
465         }
466       }
467     }
468 
469     std::map<Node, std::vector<StrategyType> >::iterator itcs = cop_to_strat.find(cop);
470     if (itcs != cop_to_strat.end())
471     {
472       Trace("sygus-unif") << "-> constructor " << cop
473                           << " matches strategy for " << eut.getKind() << "..."
474                           << std::endl;
475       // collect children types
476       for (unsigned k = 0, size = cop_to_carg_list[cop].size(); k < size; k++)
477       {
478         TypeNode ctn = sktns[cop_to_carg_list[cop][k]];
479         Trace("sygus-unif-debug")
480             << "   Child type " << k << " : "
481             << static_cast<DatatypeType>(ctn.toType()).getDatatype().getName()
482             << std::endl;
483         cop_to_child_types[cop].push_back(ctn);
484       }
485       // if there are checks on the consistency of child types wrt strategies,
486       // these should be enforced here. We currently have none.
487     }
488     if (cop_to_strat.find(cop) == cop_to_strat.end())
489     {
490       Trace("sygus-unif") << "...constructor " << cop
491                           << " does not correspond to a strategy." << std::endl;
492       search_this = true;
493     }
494   }
495 
496   // check whether we should also enumerate the current type
497   Trace("sygus-unif-debug2") << "  register this strategy ..." << std::endl;
498   registerStrategyPoint(ee, tn, erole, search_this);
499 
500   if (cop_to_strat.empty())
501   {
502     Trace("sygus-unif") << "...consider " << dt.getName() << " a basic type"
503                         << std::endl;
504   }
505   else
506   {
507     for (std::pair<const Node, std::vector<StrategyType> >& cstr : cop_to_strat)
508     {
509       Node cop = cstr.first;
510       Trace("sygus-unif-debug") << "Constructor " << cop << " has "
511                                 << cstr.second.size() << " strategies..."
512                                 << std::endl;
513       for (unsigned s = 0, ssize = cstr.second.size(); s < ssize; s++)
514       {
515         EnumTypeInfoStrat* cons_strat = new EnumTypeInfoStrat;
516         StrategyType strat = cstr.second[s];
517 
518         cons_strat->d_this = strat;
519         cons_strat->d_cons = cop;
520         Trace("sygus-unif-debug") << "Process strategy #" << s
521                                   << " for operator : " << cop << " : " << strat
522                                   << std::endl;
523         Assert(cop_to_child_types.find(cop) != cop_to_child_types.end());
524         std::vector<TypeNode>& childTypes = cop_to_child_types[cop];
525         Assert(cop_to_carg_list.find(cop) != cop_to_carg_list.end());
526         std::vector<unsigned>& cargList = cop_to_carg_list[cop];
527 
528         std::vector<Node> sol_templ_children;
529         sol_templ_children.resize(cop_to_sks[cop].size());
530 
531         for (unsigned j = 0, csize = childTypes.size(); j < csize; j++)
532         {
533           // calculate if we should allocate a new enumerator : should be true
534           // if we have a new role
535           NodeRole nrole_c = nrole;
536           if (strat == strat_ITE)
537           {
538             if (j == 0)
539             {
540               nrole_c = role_ite_condition;
541             }
542           }
543           else if (strat == strat_CONCAT_PREFIX)
544           {
545             if ((j + 1) < childTypes.size())
546             {
547               nrole_c = role_string_prefix;
548             }
549           }
550           else if (strat == strat_CONCAT_SUFFIX)
551           {
552             if (j > 0)
553             {
554               nrole_c = role_string_suffix;
555             }
556           }
557           // in all other cases, role is same as parent
558 
559           // register the child type
560           TypeNode ct = childTypes[j];
561           Node csk = cop_to_sks[cop][cargList[j]];
562           cons_strat->d_sol_templ_args.push_back(csk);
563           sol_templ_children[cargList[j]] = csk;
564 
565           EnumRole erole_c = getEnumeratorRoleForNodeRole(nrole_c);
566           // make the enumerator
567           Node et;
568           if (cop_to_child_templ[cop].find(j) != cop_to_child_templ[cop].end())
569           {
570             // it is templated, allocate a fresh variable
571             et = nm->mkSkolem("et", ct);
572             Trace("sygus-unif-debug")
573                 << "...enumerate " << et << " of type "
574                 << ((DatatypeType)ct.toType()).getDatatype().getName();
575             Trace("sygus-unif-debug") << " for arg " << j << " of "
576                                       << static_cast<DatatypeType>(tn.toType())
577                                              .getDatatype()
578                                              .getName()
579                                       << std::endl;
580             registerStrategyPoint(et, ct, erole_c, true);
581             d_einfo[et].d_template = cop_to_child_templ[cop][j];
582             d_einfo[et].d_template_arg = cop_to_child_templ_arg[cop][j];
583             Assert(!d_einfo[et].d_template.isNull());
584             Assert(!d_einfo[et].d_template_arg.isNull());
585           }
586           else
587           {
588             Trace("sygus-unif-debug")
589                 << "...child type enumerate "
590                 << ((DatatypeType)ct.toType()).getDatatype().getName()
591                 << ", node role = " << nrole_c << std::endl;
592             buildStrategyGraph(ct, nrole_c);
593             // otherwise use the previous
594             Assert(d_tinfo[ct].d_enum.find(erole_c)
595                    != d_tinfo[ct].d_enum.end());
596             et = d_tinfo[ct].d_enum[erole_c];
597           }
598           Trace("sygus-unif-debug") << "Register child enumerator " << et
599                                     << ", arg " << j << " of " << cop
600                                     << ", role = " << erole_c << std::endl;
601           Assert(!et.isNull());
602           cons_strat->d_cenum.push_back(std::pair<Node, NodeRole>(et, nrole_c));
603         }
604         // children that are unused in the strategy can be arbitrary
605         for (unsigned j = 0, stsize = sol_templ_children.size(); j < stsize;
606              j++)
607         {
608           if (sol_templ_children[j].isNull())
609           {
610             sol_templ_children[j] = cop_to_sks[cop][j].getType().mkGroundTerm();
611           }
612         }
613         sol_templ_children.insert(sol_templ_children.begin(), cop);
614         cons_strat->d_sol_templ =
615             nm->mkNode(APPLY_CONSTRUCTOR, sol_templ_children);
616         if (strat == strat_CONCAT_SUFFIX)
617         {
618           std::reverse(cons_strat->d_cenum.begin(), cons_strat->d_cenum.end());
619           std::reverse(cons_strat->d_sol_templ_args.begin(),
620                        cons_strat->d_sol_templ_args.end());
621         }
622         if (Trace.isOn("sygus-unif"))
623         {
624           Trace("sygus-unif") << "Initialized strategy " << strat;
625           Trace("sygus-unif")
626               << " for "
627               << static_cast<DatatypeType>(tn.toType()).getDatatype().getName()
628               << ", operator " << cop;
629           Trace("sygus-unif") << ", #children = " << cons_strat->d_cenum.size()
630                               << ", solution template = (lambda ( ";
631           for (const Node& targ : cons_strat->d_sol_templ_args)
632           {
633             Trace("sygus-unif") << targ << " ";
634           }
635           Trace("sygus-unif") << ") " << cons_strat->d_sol_templ << ")";
636           Trace("sygus-unif") << std::endl;
637         }
638         // make the strategy
639         snode.d_strats.push_back(cons_strat);
640       }
641     }
642   }
643 }
644 
inferTemplate(unsigned k,Node n,std::map<Node,unsigned> & templ_var_index,std::map<unsigned,unsigned> & templ_injection)645 bool SygusUnifStrategy::inferTemplate(
646     unsigned k,
647     Node n,
648     std::map<Node, unsigned>& templ_var_index,
649     std::map<unsigned, unsigned>& templ_injection)
650 {
651   if (n.getNumChildren() == 0)
652   {
653     std::map<Node, unsigned>::iterator itt = templ_var_index.find(n);
654     if (itt != templ_var_index.end())
655     {
656       unsigned kk = itt->second;
657       std::map<unsigned, unsigned>::iterator itti = templ_injection.find(k);
658       if (itti == templ_injection.end())
659       {
660         Trace("sygus-unif-debug") << "...set template injection " << k << " -> "
661                                   << kk << std::endl;
662         templ_injection[k] = kk;
663       }
664       else if (itti->second != kk)
665       {
666         // two distinct variables in this term, we fail
667         return false;
668       }
669     }
670     return true;
671   }
672   else
673   {
674     for (unsigned i = 0; i < n.getNumChildren(); i++)
675     {
676       if (!inferTemplate(k, n[i], templ_var_index, templ_injection))
677       {
678         return false;
679       }
680     }
681   }
682   return true;
683 }
684 
staticLearnRedundantOps(std::map<Node,std::vector<Node>> & strategy_lemmas)685 void SygusUnifStrategy::staticLearnRedundantOps(
686     std::map<Node, std::vector<Node>>& strategy_lemmas)
687 {
688   StrategyRestrictions restrictions;
689   staticLearnRedundantOps(strategy_lemmas, restrictions);
690 }
691 
staticLearnRedundantOps(std::map<Node,std::vector<Node>> & strategy_lemmas,StrategyRestrictions & restrictions)692 void SygusUnifStrategy::staticLearnRedundantOps(
693     std::map<Node, std::vector<Node>>& strategy_lemmas,
694     StrategyRestrictions& restrictions)
695 {
696   for (unsigned i = 0; i < d_esym_list.size(); i++)
697   {
698     Node e = d_esym_list[i];
699     std::map<Node, EnumInfo>::iterator itn = d_einfo.find(e);
700     Assert(itn != d_einfo.end());
701     // see if there is anything we can eliminate
702     Trace("sygus-unif")
703         << "* Search enumerator #" << i << " : type "
704         << ((DatatypeType)e.getType().toType()).getDatatype().getName()
705         << " : ";
706     Trace("sygus-unif") << e << " has " << itn->second.d_enum_slave.size()
707                         << " slaves:" << std::endl;
708     for (unsigned j = 0; j < itn->second.d_enum_slave.size(); j++)
709     {
710       Node es = itn->second.d_enum_slave[j];
711       std::map<Node, EnumInfo>::iterator itns = d_einfo.find(es);
712       Assert(itns != d_einfo.end());
713       Trace("sygus-unif") << "  " << es << ", role = " << itns->second.getRole()
714                           << std::endl;
715     }
716   }
717   Trace("sygus-unif") << std::endl;
718   Trace("sygus-unif") << "Strategy for candidate " << d_candidate
719                       << " is : " << std::endl;
720   debugPrint("sygus-unif");
721   std::map<Node, std::map<NodeRole, bool> > visited;
722   std::map<Node, std::map<unsigned, bool> > needs_cons;
723   staticLearnRedundantOps(
724       getRootEnumerator(), role_equal, visited, needs_cons, restrictions);
725   // now, check the needs_cons map
726   for (std::pair<const Node, std::map<unsigned, bool> >& nce : needs_cons)
727   {
728     Node em = nce.first;
729     const Datatype& dt =
730         static_cast<DatatypeType>(em.getType().toType()).getDatatype();
731     std::vector<Node> lemmas;
732     for (std::pair<const unsigned, bool>& nc : nce.second)
733     {
734       Assert(nc.first < dt.getNumConstructors());
735       if (!nc.second)
736       {
737         Node tst =
738             datatypes::DatatypesRewriter::mkTester(em, nc.first, dt).negate();
739 
740         if (std::find(lemmas.begin(), lemmas.end(), tst) == lemmas.end())
741         {
742           Trace("sygus-unif") << "...can exclude based on  : " << tst
743                               << std::endl;
744           lemmas.push_back(tst);
745         }
746       }
747     }
748     if (!lemmas.empty())
749     {
750       strategy_lemmas[em] = lemmas;
751     }
752   }
753 }
754 
debugPrint(const char * c)755 void SygusUnifStrategy::debugPrint(const char* c)
756 {
757   if (Trace.isOn(c))
758   {
759     std::map<Node, std::map<NodeRole, bool> > visited;
760     debugPrint(c, getRootEnumerator(), role_equal, visited, 0);
761   }
762 }
763 
staticLearnRedundantOps(Node e,NodeRole nrole,std::map<Node,std::map<NodeRole,bool>> & visited,std::map<Node,std::map<unsigned,bool>> & needs_cons,StrategyRestrictions & restrictions)764 void SygusUnifStrategy::staticLearnRedundantOps(
765     Node e,
766     NodeRole nrole,
767     std::map<Node, std::map<NodeRole, bool>>& visited,
768     std::map<Node, std::map<unsigned, bool>>& needs_cons,
769     StrategyRestrictions& restrictions)
770 {
771   if (visited[e].find(nrole) != visited[e].end())
772   {
773     return;
774   }
775   Trace("sygus-strat-slearn") << "Learn redundant operators " << e << " "
776                               << nrole << "..." << std::endl;
777   visited[e][nrole] = true;
778   EnumInfo& ei = getEnumInfo(e);
779   if (ei.isTemplated())
780   {
781     return;
782   }
783   TypeNode etn = e.getType();
784   EnumTypeInfo& tinfo = getEnumTypeInfo(etn);
785   StrategyNode& snode = tinfo.getStrategyNode(nrole);
786   // the constructors of the current strategy point we need
787   std::map<unsigned, bool> needs_cons_curr;
788   // get the unused strategies
789   std::map<Node, std::unordered_set<unsigned>>::iterator itus =
790       restrictions.d_unused_strategies.find(e);
791   std::unordered_set<unsigned> unused_strats;
792   if (itus != restrictions.d_unused_strategies.end())
793   {
794     unused_strats.insert(itus->second.begin(), itus->second.end());
795   }
796   for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++)
797   {
798     // if we are not using this strategy, there is nothing to do
799     if (unused_strats.find(j) != unused_strats.end())
800     {
801       continue;
802     }
803     EnumTypeInfoStrat* etis = snode.d_strats[j];
804     unsigned cindex = datatypes::DatatypesRewriter::indexOf(etis->d_cons);
805     // constructors that correspond to strategies are not needed
806     // the intuition is that the strategy itself is responsible for constructing
807     // all terms that use the given constructor
808     Trace("sygus-strat-slearn") << "...by strategy, can exclude operator "
809                                 << etis->d_cons << std::endl;
810     needs_cons_curr[cindex] = false;
811     // try to eliminate from etn's datatype all operators except TRUE/FALSE if
812     // arguments of ITE are the same BOOL type
813     if (restrictions.d_iteReturnBoolConst)
814     {
815       const Datatype& dt =
816           static_cast<DatatypeType>(etn.toType()).getDatatype();
817       Node op = Node::fromExpr(dt[cindex].getSygusOp());
818       TypeNode sygus_tn = TypeNode::fromType(dt.getSygusType());
819       if (op.getKind() == kind::BUILTIN
820           && NodeManager::operatorToKind(op) == ITE
821           && sygus_tn.isBoolean()
822           && (TypeNode::fromType(dt[cindex].getArgType(1))
823               == TypeNode::fromType(dt[cindex].getArgType(2))))
824       {
825         unsigned ncons = dt.getNumConstructors(), indexT = ncons,
826                  indexF = ncons;
827         for (unsigned k = 0; k < ncons; ++k)
828         {
829           Node op_arg = Node::fromExpr(dt[k].getSygusOp());
830           if (dt[k].getNumArgs() > 0 || !op_arg.isConst())
831           {
832             continue;
833           }
834           if (op_arg.getConst<bool>())
835           {
836             indexT = k;
837           }
838           else
839           {
840             indexF = k;
841           }
842         }
843         if (indexT < ncons && indexF < ncons)
844         {
845           Trace("sygus-strat-slearn")
846               << "...for ite boolean arg, can exclude all operators but T/F\n";
847           for (unsigned k = 0; k < ncons; ++k)
848           {
849             needs_cons_curr[k] = false;
850           }
851           needs_cons_curr[indexT] = true;
852           needs_cons_curr[indexF] = true;
853         }
854       }
855     }
856     for (std::pair<Node, NodeRole>& cec : etis->d_cenum)
857     {
858       staticLearnRedundantOps(
859           cec.first, cec.second, visited, needs_cons, restrictions);
860     }
861   }
862   // get the current datatype
863   const Datatype& dt = static_cast<DatatypeType>(etn.toType()).getDatatype();
864   // do not use recursive Boolean connectives for conditions of ITEs
865   if (nrole == role_ite_condition && restrictions.d_iteCondOnlyAtoms)
866   {
867     TypeNode sygus_tn = TypeNode::fromType(dt.getSygusType());
868     for (unsigned j = 0, size = dt.getNumConstructors(); j < size; j++)
869     {
870       Node op = Node::fromExpr(dt[j].getSygusOp());
871       Trace("sygus-strat-slearn")
872           << "...for ite condition, look at operator : " << op << std::endl;
873       if (op.isConst() && dt[j].getNumArgs() == 0)
874       {
875         Trace("sygus-strat-slearn")
876             << "...for ite condition, can exclude Boolean constant " << op
877             << std::endl;
878         needs_cons_curr[j] = false;
879         continue;
880       }
881       if (op.getKind() == kind::BUILTIN)
882       {
883         Kind k = NodeManager::operatorToKind(op);
884         if (k == NOT || k == OR || k == AND || k == ITE)
885         {
886           // can eliminate if their argument types are simple loops to this type
887           bool type_ok = true;
888           for (unsigned k = 0, nargs = dt[j].getNumArgs(); k < nargs; k++)
889           {
890             TypeNode tn = TypeNode::fromType(dt[j].getArgType(k));
891             if (tn != etn)
892             {
893               type_ok = false;
894               break;
895             }
896           }
897           if (type_ok)
898           {
899             Trace("sygus-strat-slearn")
900                 << "...for ite condition, can exclude Boolean connective : "
901                 << op << std::endl;
902             needs_cons_curr[j] = false;
903           }
904         }
905       }
906     }
907   }
908   // all other constructors are needed
909   for (unsigned j = 0, size = dt.getNumConstructors(); j < size; j++)
910   {
911     if (needs_cons_curr.find(j) == needs_cons_curr.end())
912     {
913       needs_cons_curr[j] = true;
914     }
915   }
916   // update the constructors that the master enumerator needs
917   if (needs_cons.find(e) == needs_cons.end())
918   {
919     needs_cons[e] = needs_cons_curr;
920   }
921   else
922   {
923     for (unsigned j = 0, size = dt.getNumConstructors(); j < size; j++)
924     {
925       needs_cons[e][j] = needs_cons[e][j] || needs_cons_curr[j];
926     }
927   }
928 }
929 
finishInit(Node e,NodeRole nrole,std::map<Node,std::map<NodeRole,bool>> & visited,bool isCond)930 void SygusUnifStrategy::finishInit(
931     Node e,
932     NodeRole nrole,
933     std::map<Node, std::map<NodeRole, bool> >& visited,
934     bool isCond)
935 {
936   EnumInfo& ei = getEnumInfo(e);
937   if (visited[e].find(nrole) != visited[e].end()
938       && (!isCond || ei.isConditional()))
939   {
940     return;
941   }
942   visited[e][nrole] = true;
943   // set conditional
944   if (isCond)
945   {
946     ei.setConditional();
947   }
948   if (ei.isTemplated())
949   {
950     return;
951   }
952   TypeNode etn = e.getType();
953   EnumTypeInfo& tinfo = getEnumTypeInfo(etn);
954   StrategyNode& snode = tinfo.getStrategyNode(nrole);
955   for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++)
956   {
957     EnumTypeInfoStrat* etis = snode.d_strats[j];
958     StrategyType strat = etis->d_this;
959     bool newIsCond = isCond || strat == strat_ITE;
960     for (std::pair<Node, NodeRole>& cec : etis->d_cenum)
961     {
962       finishInit(cec.first, cec.second, visited, newIsCond);
963     }
964   }
965 }
966 
debugPrint(const char * c,Node e,NodeRole nrole,std::map<Node,std::map<NodeRole,bool>> & visited,int ind)967 void SygusUnifStrategy::debugPrint(
968     const char* c,
969     Node e,
970     NodeRole nrole,
971     std::map<Node, std::map<NodeRole, bool> >& visited,
972     int ind)
973 {
974   if (visited[e].find(nrole) != visited[e].end())
975   {
976     indent(c, ind);
977     Trace(c) << e << " :: node role : " << nrole << std::endl;
978     return;
979   }
980   visited[e][nrole] = true;
981   EnumInfo& ei = getEnumInfo(e);
982 
983   TypeNode etn = e.getType();
984 
985   indent(c, ind);
986   Trace(c) << e << " :: node role : " << nrole;
987   Trace(c) << ", type : "
988            << static_cast<DatatypeType>(etn.toType()).getDatatype().getName();
989   if (ei.isConditional())
990   {
991     Trace(c) << ", conditional";
992   }
993   Trace(c) << ", enum role : " << ei.getRole();
994 
995   if (ei.isTemplated())
996   {
997     Trace(c) << ", templated : (lambda " << ei.d_template_arg << " "
998              << ei.d_template << ")" << std::endl;
999     return;
1000   }
1001   Trace(c) << std::endl;
1002 
1003   EnumTypeInfo& tinfo = getEnumTypeInfo(etn);
1004   StrategyNode& snode = tinfo.getStrategyNode(nrole);
1005   for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++)
1006   {
1007     EnumTypeInfoStrat* etis = snode.d_strats[j];
1008     StrategyType strat = etis->d_this;
1009     indent(c, ind + 1);
1010     Trace(c) << "Strategy : " << strat << ", from cons : " << etis->d_cons
1011              << std::endl;
1012     for (std::pair<Node, NodeRole>& cec : etis->d_cenum)
1013     {
1014       // recurse
1015       debugPrint(c, cec.first, cec.second, visited, ind + 2);
1016     }
1017   }
1018 }
1019 
initialize(EnumRole role)1020 void EnumInfo::initialize(EnumRole role) { d_role = role; }
1021 
getStrategyNode(NodeRole nrole)1022 StrategyNode& EnumTypeInfo::getStrategyNode(NodeRole nrole)
1023 {
1024   std::map<NodeRole, StrategyNode>::iterator it = d_snodes.find(nrole);
1025   Assert(it != d_snodes.end());
1026   return it->second;
1027 }
1028 
isValid(UnifContext & x)1029 bool EnumTypeInfoStrat::isValid(UnifContext& x)
1030 {
1031   if ((x.getCurrentRole() == role_string_prefix
1032        && d_this == strat_CONCAT_SUFFIX)
1033       || (x.getCurrentRole() == role_string_suffix
1034           && d_this == strat_CONCAT_PREFIX))
1035   {
1036     return false;
1037   }
1038   return true;
1039 }
1040 
~StrategyNode()1041 StrategyNode::~StrategyNode()
1042 {
1043   for (unsigned j = 0, size = d_strats.size(); j < size; j++)
1044   {
1045     delete d_strats[j];
1046   }
1047   d_strats.clear();
1048 }
1049 
indent(const char * c,int ind)1050 void SygusUnifStrategy::indent(const char* c, int ind)
1051 {
1052   if (Trace.isOn(c))
1053   {
1054     for (int i = 0; i < ind; i++)
1055     {
1056       Trace(c) << "  ";
1057     }
1058   }
1059 }
1060 
1061 } /* CVC4::theory::quantifiers namespace */
1062 } /* CVC4::theory namespace */
1063 } /* CVC4 namespace */
1064