1 /*********************                                                        */
2 /*! \file sygus_unif_rl.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Haniel Barbosa, Andrew Reynolds
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** \brief Implementation of sygus_unif_rl
13  **/
14 
15 #include "theory/quantifiers/sygus/sygus_unif_rl.h"
16 
17 #include "options/base_options.h"
18 #include "options/quantifiers_options.h"
19 #include "printer/printer.h"
20 #include "theory/datatypes/datatypes_rewriter.h"
21 #include "theory/quantifiers/sygus/synth_conjecture.h"
22 #include "theory/quantifiers/sygus/term_database_sygus.h"
23 #include "util/random.h"
24 
25 #include <math.h>
26 
27 using namespace CVC4::kind;
28 
29 namespace CVC4 {
30 namespace theory {
31 namespace quantifiers {
32 
SygusUnifRl(SynthConjecture * p)33 SygusUnifRl::SygusUnifRl(SynthConjecture* p) : d_parent(p) {}
~SygusUnifRl()34 SygusUnifRl::~SygusUnifRl() {}
initializeCandidate(QuantifiersEngine * qe,Node f,std::vector<Node> & enums,std::map<Node,std::vector<Node>> & strategy_lemmas)35 void SygusUnifRl::initializeCandidate(
36     QuantifiersEngine* qe,
37     Node f,
38     std::vector<Node>& enums,
39     std::map<Node, std::vector<Node>>& strategy_lemmas)
40 {
41   // initialize
42   std::vector<Node> all_enums;
43   SygusUnif::initializeCandidate(qe, f, all_enums, strategy_lemmas);
44   // based on the strategy inferred for each function, determine if we are
45   // using a unification strategy that is compatible our approach.
46   StrategyRestrictions restrictions;
47   if (options::sygusBoolIteReturnConst())
48   {
49     restrictions.d_iteReturnBoolConst = true;
50   }
51   // register the strategy
52   registerStrategy(f, enums, restrictions.d_unused_strategies);
53   d_strategy[f].staticLearnRedundantOps(strategy_lemmas, restrictions);
54   // Copy candidates and check whether CegisUnif for any of them
55   if (d_unif_candidates.find(f) != d_unif_candidates.end())
56   {
57     d_hd_to_pt[f].clear();
58     d_cand_to_eval_hds[f].clear();
59     d_cand_to_hd_count[f] = 0;
60   }
61 }
62 
notifyEnumeration(Node e,Node v,std::vector<Node> & lemmas)63 void SygusUnifRl::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
64 {
65   // we do not use notify enumeration
66   Assert(false);
67 }
68 
purifyLemma(Node n,bool ensureConst,std::vector<Node> & model_guards,BoolNodePairMap & cache)69 Node SygusUnifRl::purifyLemma(Node n,
70                               bool ensureConst,
71                               std::vector<Node>& model_guards,
72                               BoolNodePairMap& cache)
73 {
74   Trace("sygus-unif-rl-purify") << "PurifyLemma : " << n << "\n";
75   BoolNodePairMap::const_iterator it = cache.find(BoolNodePair(ensureConst, n));
76   if (it != cache.end())
77   {
78     Trace("sygus-unif-rl-purify-debug") << "... already visited " << n << "\n";
79     return it->second;
80   }
81   // Recurse
82   unsigned size = n.getNumChildren();
83   Kind k = n.getKind();
84   // We retrive model value now because purified node may not have a value
85   Node nv = n;
86   // Whether application of a function-to-synthesize
87   bool fapp = (n.getKind() == DT_SYGUS_EVAL);
88   bool u_fapp = false;
89   bool nu_fapp = false;
90   if (fapp)
91   {
92     Assert(std::find(d_candidates.begin(), d_candidates.end(), n[0])
93            != d_candidates.end());
94     // Whether application of a (non-)unification function-to-synthesize
95     u_fapp = usingUnif(n[0]);
96     nu_fapp = !usingUnif(n[0]);
97     // get model value of non-top level applications of functions-to-synthesize
98     // occurring under a unification function-to-synthesize
99     if (ensureConst)
100     {
101       std::map<Node, Node>::iterator it = d_cand_to_sol.find(n[0]);
102       // if function-to-synthesize, retrieve its built solution to replace in
103       // the application before computing the model value
104       AlwaysAssert(!u_fapp || it != d_cand_to_sol.end());
105       if (it != d_cand_to_sol.end())
106       {
107         TNode cand = n[0];
108         Node tmp = n.substitute(cand, it->second);
109         nv = d_tds->evaluateWithUnfolding(tmp);
110         Trace("sygus-unif-rl-purify")
111             << "PurifyLemma : model value for " << tmp << " is " << nv << "\n";
112       }
113       else
114       {
115         nv = d_parent->getModelValue(n);
116         Trace("sygus-unif-rl-purify")
117             << "PurifyLemma : model value for " << n << " is " << nv << "\n";
118       }
119       Assert(n != nv);
120     }
121   }
122   // Travese to purify
123   bool childChanged = false;
124   std::vector<Node> children;
125   NodeManager* nm = NodeManager::currentNM();
126   for (unsigned i = 0; i < size; ++i)
127   {
128     if (i == 0 && fapp)
129     {
130       children.push_back(n[i]);
131       continue;
132     }
133     // Arguments of non-unif functions do not need to be constant
134     Node child = purifyLemma(
135         n[i], !nu_fapp && (ensureConst || u_fapp), model_guards, cache);
136     children.push_back(child);
137     childChanged = childChanged || child != n[i];
138   }
139   Node nb;
140   if (childChanged)
141   {
142     if (n.getMetaKind() == metakind::PARAMETERIZED)
143     {
144       Trace("sygus-unif-rl-purify-debug")
145           << "Node " << n << " is parameterized\n";
146       children.insert(children.begin(), n.getOperator());
147     }
148     if (Trace.isOn("sygus-unif-rl-purify-debug"))
149     {
150       Trace("sygus-unif-rl-purify-debug")
151           << "...rebuilding " << n << " with kind " << k << " and children:\n";
152       for (const Node& child : children)
153       {
154         Trace("sygus-unif-rl-purify-debug") << "...... " << child << "\n";
155       }
156     }
157     nb = NodeManager::currentNM()->mkNode(k, children);
158     Trace("sygus-unif-rl-purify")
159         << "PurifyLemma : transformed " << n << " into " << nb << "\n";
160   }
161   else
162   {
163     nb = n;
164   }
165   // Map to point enumerator every unification function-to-synthesize
166   if (u_fapp)
167   {
168     Node np;
169     std::map<Node, Node>::const_iterator it = d_app_to_purified.find(nb);
170     if (it == d_app_to_purified.end())
171     {
172       // Build purified head with fresh skolem and recreate node
173       std::stringstream ss;
174       ss << nb[0] << "_" << d_cand_to_hd_count[nb[0]]++;
175       Node new_f = nm->mkSkolem(ss.str(),
176                                 nb[0].getType(),
177                                 "head of unif evaluation point",
178                                 NodeManager::SKOLEM_EXACT_NAME);
179       // Adds new enumerator to map from candidate
180       Trace("sygus-unif-rl-purify")
181           << "...new enum " << new_f << " for candidate " << nb[0] << "\n";
182       d_cand_to_eval_hds[nb[0]].push_back(new_f);
183       // Maps new enumerator to its respective tuple of arguments
184       d_hd_to_pt[new_f] =
185           std::vector<Node>(children.begin() + 1, children.end());
186       if (Trace.isOn("sygus-unif-rl-purify-debug"))
187       {
188         Trace("sygus-unif-rl-purify-debug") << "...[" << new_f << "] --> ( ";
189         for (const Node& pt_i : d_hd_to_pt[new_f])
190         {
191           Trace("sygus-unif-rl-purify-debug") << pt_i << " ";
192         }
193         Trace("sygus-unif-rl-purify-debug") << ")\n";
194       }
195       // replace first child and rebulid node
196       Assert(children.size() > 0);
197       children[0] = new_f;
198       Trace("sygus-unif-rl-purify-debug")
199           << "Make sygus eval app " << children << std::endl;
200       np = nm->mkNode(DT_SYGUS_EVAL, children);
201       d_app_to_purified[nb] = np;
202     }
203     else
204     {
205       np = it->second;
206     }
207     Trace("sygus-unif-rl-purify")
208         << "PurifyLemma : purified head and transformed " << nb << " into "
209         << np << "\n";
210     nb = np;
211   }
212   // Add equality between purified fapp and model value
213   if (ensureConst && fapp)
214   {
215     model_guards.push_back(
216         NodeManager::currentNM()->mkNode(EQUAL, nv, nb).negate());
217     nb = nv;
218     Trace("sygus-unif-rl-purify")
219         << "PurifyLemma : adding model eq " << model_guards.back() << "\n";
220   }
221   nb = Rewriter::rewrite(nb);
222   // every non-top level application of function-to-synthesize must be reduced
223   // to a concrete constant
224   Assert(!ensureConst || nb.isConst());
225   Trace("sygus-unif-rl-purify-debug")
226       << "... caching [" << n << "] = " << nb << "\n";
227   cache[BoolNodePair(ensureConst, n)] = nb;
228   return nb;
229 }
230 
addRefLemma(Node lemma,std::map<Node,std::vector<Node>> & eval_hds)231 Node SygusUnifRl::addRefLemma(Node lemma,
232                               std::map<Node, std::vector<Node>>& eval_hds)
233 {
234   Trace("sygus-unif-rl-purify")
235       << "Registering lemma at SygusUnif : " << lemma << "\n";
236   std::vector<Node> model_guards;
237   BoolNodePairMap cache;
238   // cache previous sizes
239   std::map<Node, unsigned> prev_n_eval_hds;
240   for (const std::pair<const Node, std::vector<Node>>& cp : d_cand_to_eval_hds)
241   {
242     prev_n_eval_hds[cp.first] = cp.second.size();
243   }
244 
245   // Make the purified lemma which will guide the unification utility.
246   Node plem = purifyLemma(lemma, false, model_guards, cache);
247   if (!model_guards.empty())
248   {
249     model_guards.push_back(plem);
250     plem = NodeManager::currentNM()->mkNode(OR, model_guards);
251   }
252   plem = Rewriter::rewrite(plem);
253   Trace("sygus-unif-rl-purify") << "Purified lemma : " << plem << "\n";
254 
255   Trace("sygus-unif-rl-purify") << "Collect new evaluation points...\n";
256   for (const std::pair<const Node, std::vector<Node>>& cp : d_cand_to_eval_hds)
257   {
258     Node c = cp.first;
259     unsigned prevn = 0;
260     std::map<Node, unsigned>::iterator itp = prev_n_eval_hds.find(c);
261     if (itp != prev_n_eval_hds.end())
262     {
263       prevn = itp->second;
264     }
265     for (unsigned j = prevn, size = cp.second.size(); j < size; j++)
266     {
267       eval_hds[c].push_back(cp.second[j]);
268       // Add new point to respective decision trees
269       Assert(d_cand_cenums.find(c) != d_cand_cenums.end());
270       for (const Node& cenum : d_cand_cenums[c])
271       {
272         Assert(d_cenum_to_stratpt.find(cenum) != d_cenum_to_stratpt.end());
273         for (const Node& stratpt : d_cenum_to_stratpt[cenum])
274         {
275           Assert(d_stratpt_to_dt.find(stratpt) != d_stratpt_to_dt.end());
276           Trace("sygus-unif-rl-dt")
277               << "Register point with head " << cp.second[j]
278               << " to strategy point " << stratpt << "\n";
279           // Register new point from new head
280           d_stratpt_to_dt[stratpt].d_hds.push_back(cp.second[j]);
281         }
282       }
283     }
284   }
285 
286   return plem;
287 }
288 
initializeConstructSol()289 void SygusUnifRl::initializeConstructSol() {}
initializeConstructSolFor(Node f)290 void SygusUnifRl::initializeConstructSolFor(Node f) {}
constructSolution(std::vector<Node> & sols,std::vector<Node> & lemmas)291 bool SygusUnifRl::constructSolution(std::vector<Node>& sols,
292                                     std::vector<Node>& lemmas)
293 {
294   initializeConstructSol();
295   bool successful = true;
296   for (const Node& c : d_candidates)
297   {
298     if (!usingUnif(c))
299     {
300       Node v = d_parent->getModelValue(c);
301       sols.push_back(v);
302       continue;
303     }
304     initializeConstructSolFor(c);
305     Node v = constructSol(
306         c, d_strategy[c].getRootEnumerator(), role_equal, 0, lemmas);
307     if (v.isNull())
308     {
309       // we continue trying to build solutions to accumulate potentitial
310       // separation conditions from other decision trees
311       successful = false;
312       continue;
313     }
314     sols.push_back(v);
315     d_cand_to_sol[c] = v;
316   }
317   return successful;
318 }
319 
constructSol(Node f,Node e,NodeRole nrole,int ind,std::vector<Node> & lemmas)320 Node SygusUnifRl::constructSol(
321     Node f, Node e, NodeRole nrole, int ind, std::vector<Node>& lemmas)
322 {
323   indent("sygus-unif-sol", ind);
324   Trace("sygus-unif-sol") << "ConstructSol: SygusRL : " << e << std::endl;
325   // retrieve strategy information
326   TypeNode etn = e.getType();
327   EnumTypeInfo& tinfo = d_strategy[f].getEnumTypeInfo(etn);
328   StrategyNode& snode = tinfo.getStrategyNode(nrole);
329   if (nrole != role_equal)
330   {
331     return Node::null();
332   }
333   // is there a decision tree strategy?
334   std::map<Node, DecisionTreeInfo>::iterator itd = d_stratpt_to_dt.find(e);
335   // for now only considering simple case of sole "ITE(cond, e, e)" strategy
336   if (itd == d_stratpt_to_dt.end())
337   {
338     return Node::null();
339   }
340   indent("sygus-unif-sol", ind);
341   Trace("sygus-unif-sol") << "...it has a decision tree strategy.\n";
342   // whether empty set of points
343   if (d_cand_to_eval_hds[f].empty())
344   {
345     Trace("sygus-unif-sol") << "...... no points, return root enum value "
346                             << d_parent->getModelValue(e) << "\n";
347     return d_parent->getModelValue(e);
348   }
349   EnumTypeInfoStrat* etis = snode.d_strats[itd->second.getStrategyIndex()];
350   Node sol = itd->second.buildSol(etis->d_cons, lemmas);
351   Assert(options::sygusUnifCondIndependent() || !sol.isNull()
352          || !lemmas.empty());
353   return sol;
354 }
355 
usingUnif(Node f) const356 bool SygusUnifRl::usingUnif(Node f) const
357 {
358   return d_unif_candidates.find(f) != d_unif_candidates.end();
359 }
360 
getConditionForEvaluationPoint(Node e) const361 Node SygusUnifRl::getConditionForEvaluationPoint(Node e) const
362 {
363   std::map<Node, DecisionTreeInfo>::const_iterator it = d_stratpt_to_dt.find(e);
364   Assert(it != d_stratpt_to_dt.end());
365   return it->second.getConditionEnumerator();
366 }
367 
setConditions(Node e,Node guard,const std::vector<Node> & enums,const std::vector<Node> & conds)368 void SygusUnifRl::setConditions(Node e,
369                                 Node guard,
370                                 const std::vector<Node>& enums,
371                                 const std::vector<Node>& conds)
372 {
373   std::map<Node, DecisionTreeInfo>::iterator it = d_stratpt_to_dt.find(e);
374   Assert(it != d_stratpt_to_dt.end());
375   // set the conditions for the appropriate tree
376   it->second.setConditions(guard, enums, conds);
377 }
378 
getEvalPointHeads(Node c)379 std::vector<Node> SygusUnifRl::getEvalPointHeads(Node c)
380 {
381   std::map<Node, std::vector<Node>>::iterator it = d_cand_to_eval_hds.find(c);
382   if (it == d_cand_to_eval_hds.end())
383   {
384     return std::vector<Node>();
385   }
386   return it->second;
387 }
388 
registerStrategy(Node f,std::vector<Node> & enums,std::map<Node,std::unordered_set<unsigned>> & unused_strats)389 void SygusUnifRl::registerStrategy(
390     Node f,
391     std::vector<Node>& enums,
392     std::map<Node, std::unordered_set<unsigned>>& unused_strats)
393 {
394   if (Trace.isOn("sygus-unif-rl-strat"))
395   {
396     Trace("sygus-unif-rl-strat")
397         << "Strategy for " << f << " is : " << std::endl;
398     d_strategy[f].debugPrint("sygus-unif-rl-strat");
399   }
400   Trace("sygus-unif-rl-strat") << "Register..." << std::endl;
401   Node e = d_strategy[f].getRootEnumerator();
402   std::map<Node, std::map<NodeRole, bool>> visited;
403   registerStrategyNode(f, e, role_equal, visited, enums, unused_strats);
404 }
405 
registerStrategyNode(Node f,Node e,NodeRole nrole,std::map<Node,std::map<NodeRole,bool>> & visited,std::vector<Node> & enums,std::map<Node,std::unordered_set<unsigned>> & unused_strats)406 void SygusUnifRl::registerStrategyNode(
407     Node f,
408     Node e,
409     NodeRole nrole,
410     std::map<Node, std::map<NodeRole, bool>>& visited,
411     std::vector<Node>& enums,
412     std::map<Node, std::unordered_set<unsigned>>& unused_strats)
413 {
414   Trace("sygus-unif-rl-strat") << "  register node " << e << std::endl;
415   if (visited[e].find(nrole) != visited[e].end())
416   {
417     return;
418   }
419   visited[e][nrole] = true;
420   TypeNode etn = e.getType();
421   EnumTypeInfo& tinfo = d_strategy[f].getEnumTypeInfo(etn);
422   StrategyNode& snode = tinfo.getStrategyNode(nrole);
423   for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++)
424   {
425     EnumTypeInfoStrat* etis = snode.d_strats[j];
426     StrategyType strat = etis->d_this;
427     // is this a simple recursive ITE strategy?
428     bool success = false;
429     if (strat == strat_ITE && nrole == role_equal)
430     {
431       success = true;
432       for (unsigned c = 1; c <= 2; c++)
433       {
434         std::pair<Node, NodeRole> child = etis->d_cenum[c];
435         if (child.first != e || child.second != nrole)
436         {
437           success = false;
438           break;
439         }
440       }
441       if (success)
442       {
443         Node cond = etis->d_cenum[0].first;
444         Assert(etis->d_cenum[0].second == role_ite_condition);
445         Trace("sygus-unif-rl-strat")
446             << "  ...detected recursive ITE strategy, condition enumerator : "
447             << cond << std::endl;
448         // indicate that we will be enumerating values for cond
449         registerConditionalEnumerator(f, e, cond, j);
450         // we will be using a strategy for e
451         enums.push_back(e);
452       }
453     }
454     if (!success)
455     {
456       unused_strats[e].insert(j);
457     }
458     // TODO: recurse? for (std::pair<Node, NodeRole>& cec : etis->d_cenum)
459   }
460 }
461 
registerConditionalEnumerator(Node f,Node e,Node cond,unsigned strategy_index)462 void SygusUnifRl::registerConditionalEnumerator(Node f,
463                                                 Node e,
464                                                 Node cond,
465                                                 unsigned strategy_index)
466 {
467   // only allow one decision tree per strategy point
468   if (d_stratpt_to_dt.find(e) != d_stratpt_to_dt.end())
469   {
470     return;
471   }
472   // we will do unification for this candidate
473   d_unif_candidates.insert(f);
474   // add to the list of all conditional enumerators
475   if (std::find(d_cond_enums.begin(), d_cond_enums.end(), cond)
476       == d_cond_enums.end())
477   {
478     d_cond_enums.push_back(cond);
479     d_cand_cenums[f].push_back(cond);
480     d_cenum_to_stratpt[cond].clear();
481   }
482   // register that this strategy node has a decision tree construction
483   d_stratpt_to_dt[e].initialize(cond, this, &d_strategy[f], strategy_index);
484   // associate conditional enumerator with strategy node
485   d_cenum_to_stratpt[cond].push_back(e);
486 }
487 
initialize(Node cond_enum,SygusUnifRl * unif,SygusUnifStrategy * strategy,unsigned strategy_index)488 void SygusUnifRl::DecisionTreeInfo::initialize(Node cond_enum,
489                                                SygusUnifRl* unif,
490                                                SygusUnifStrategy* strategy,
491                                                unsigned strategy_index)
492 {
493   d_cond_enum = cond_enum;
494   d_unif = unif;
495   d_strategy = strategy;
496   d_strategy_index = strategy_index;
497   d_true = NodeManager::currentNM()->mkConst(true);
498   d_false = NodeManager::currentNM()->mkConst(false);
499   // Retrieve template
500   EnumInfo& eiv = d_strategy->getEnumInfo(d_cond_enum);
501   d_template = NodePair(eiv.d_template, eiv.d_template_arg);
502   // Initialize classifier
503   d_pt_sep.initialize(this);
504 }
505 
setConditions(Node guard,const std::vector<Node> & enums,const std::vector<Node> & conds)506 void SygusUnifRl::DecisionTreeInfo::setConditions(
507     Node guard, const std::vector<Node>& enums, const std::vector<Node>& conds)
508 {
509   Assert(enums.size() == conds.size());
510   // set the guard
511   d_guard = guard;
512   // clear old condition values
513   d_enums.clear();
514   d_conds.clear();
515   // set new condition values
516   d_enums.insert(d_enums.end(), enums.begin(), enums.end());
517   d_conds.insert(d_conds.end(), conds.begin(), conds.end());
518   // add to condition pool
519   if (options::sygusUnifCondIndependent())
520   {
521     d_cond_mvs.insert(conds.begin(), conds.end());
522     if (Trace.isOn("sygus-unif-cond-pool"))
523     {
524       for (const Node& condv : conds)
525       {
526         if (d_cond_mvs.find(condv) == d_cond_mvs.end())
527         {
528           Trace("sygus-unif-cond-pool")
529               << "  ...adding to condition pool : "
530               << d_unif->d_tds->sygusToBuiltin(condv, condv.getType()) << "\n";
531         }
532       }
533     }
534   }
535 }
536 
getStrategyIndex() const537 unsigned SygusUnifRl::DecisionTreeInfo::getStrategyIndex() const
538 {
539   return d_strategy_index;
540 }
541 
buildSol(Node cons,std::vector<Node> & lemmas)542 Node SygusUnifRl::DecisionTreeInfo::buildSol(Node cons,
543                                              std::vector<Node>& lemmas)
544 {
545   if (!d_template.first.isNull())
546   {
547     Trace("sygus-unif-sol") << "...templated conditions unsupported\n";
548     return Node::null();
549   }
550   Trace("sygus-unif-sol") << "Decision::buildSol with " << d_hds.size()
551                           << " evaluation heads and " << d_conds.size()
552                           << " conditions..." << std::endl;
553   // reset the trie
554   d_pt_sep.d_trie.clear();
555   return options::sygusUnifCondIndependent() ? buildSolAllCond(cons, lemmas)
556                                              : buildSolMinCond(cons, lemmas);
557 }
558 
buildSolAllCond(Node cons,std::vector<Node> & lemmas)559 Node SygusUnifRl::DecisionTreeInfo::buildSolAllCond(Node cons,
560                                                     std::vector<Node>& lemmas)
561 {
562   // model values for evaluation heads
563   std::map<Node, Node> hd_mv;
564   // add conditions
565   d_conds.clear();
566   d_conds.insert(d_conds.end(), d_cond_mvs.begin(), d_cond_mvs.end());
567   // shuffle conditions before bulding DT
568   //
569   // this does not impact whether it's possible to build a solution, but it does
570   // impact the potential size of the resulting solution (can make it smaller,
571   // bigger, or have no impact) and which conditions will be present in the DT,
572   // which influences the "quality" of the solution for cases not covered in the
573   // current data points
574   if (options::sygusUnifShuffleCond())
575   {
576     std::shuffle(d_conds.begin(), d_conds.end(), Random::getRandom());
577   }
578   unsigned num_conds = d_conds.size();
579   for (unsigned i = 0; i < num_conds; ++i)
580   {
581     d_pt_sep.d_trie.addClassifier(&d_pt_sep, i);
582   }
583   // add heads
584   for (const Node& e : d_hds)
585   {
586     Node v = d_unif->d_parent->getModelValue(e);
587     hd_mv[e] = v;
588     Node er = d_pt_sep.d_trie.add(e, &d_pt_sep, num_conds);
589     // are we in conflict?
590     if (er == e)
591     {
592       // new separation class, no conflict
593       continue;
594     }
595     Assert(hd_mv.find(er) != hd_mv.end());
596     // merged into separation class with same model value, no conflict
597     if (hd_mv[e] == hd_mv[er])
598     {
599       continue;
600     }
601     // conflict. Explanation?
602     Trace("sygus-unif-sol")
603         << "  ...can't separate " << e << " from " << er << std::endl;
604     return Node::null();
605   }
606   Trace("sygus-unif-sol") << "...ready to build solution from DT\n";
607   Node sol = extractSol(cons, hd_mv);
608   // repeated solution
609   if (options::sygusUnifCondIndNoRepeatSol()
610       && d_sols.find(sol) != d_sols.end())
611   {
612     return Node::null();
613   }
614   d_sols.insert(sol);
615   return sol;
616 }
617 
buildSolMinCond(Node cons,std::vector<Node> & lemmas)618 Node SygusUnifRl::DecisionTreeInfo::buildSolMinCond(Node cons,
619                                                     std::vector<Node>& lemmas)
620 {
621   NodeManager* nm = NodeManager::currentNM();
622   // model values for evaluation heads
623   std::map<Node, Node> hd_mv;
624   // the current explanation of why there has not yet been a separation conflict
625   std::vector<Node> exp;
626   // is the above explanation ready to be sent out as a lemma?
627   bool exp_conflict = false;
628   // the index of the head we are considering
629   unsigned hd_counter = 0;
630   // the index of the condition we are considering
631   unsigned c_counter = 0;
632   // do we need to resolve a separation conflict?
633   bool needs_sep_resolve = false;
634   // This loop simultaneously builds the solution in terms of a lazy trie
635   // (LazyTrieMulti), and checks whether a separation conflict exists. We
636   // enforce that the separation conflicts we encounter while building
637   // this solution are resolved, in order, by the condition enumerators.
638   // If not, then we add a (conflict) lemma stating that the current model
639   // value of the condition enumerator must be different. We also call this
640   // a "separation lemma".
641   //
642   // As a simple example, say we have:
643   //   evalution heads: (eval e1 0 0), (eval e2 1 2)
644   //   conditions: c1
645   // where M(e1) = x, M(e2) = y, and M(c1) = x>1. After adding e1 and e2, we are
646   // in conflict since { e1, e2 } form a separation class, M(e1)!=M(e2), and
647   // M(c1) does not separate e1 and e2 since:
648   //   (x>1){x->0,y->0} = (x>1){x->1,y->2} = false
649   // Hence, we would fail to build a solution in this case, and instead send a
650   // separation lemma of the form:
651   //   ~( e1 != e2 ^ c1 = [x<1] )
652   //
653   // Say we have:
654   //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
655   //   conditions: c1 c2
656   // where M(e1) = x, M(e2) = y, M(e3) = x+1, M(c1) = x>0 and M(c2) = x<0.
657   // After adding e1 and e2, { e1, e2 } form a separation class, M(e1)!=M(e2),
658   // but M(c1) separates e1 and e2 since
659   //   (x>0){x->0,y->0} = false, and
660   //   (x>1){x->1,y->2} = true
661   // Hence, we get new separation classes { e1 } and { e2 }, and afterwards
662   // add e3. We then get { e2, e3 } as a separation class, which is also a
663   // conflict since M(e2)!=M(e3). We check if M(c2) resolves this conflict.
664   // It does not, since (x<1){x->0,y->0} = (x<1){x->1,y->2} = false. Hence,
665   // we get a separation lemma:
666   //  ~( c1 = [x>1] ^ e2 != e3 ^ c2 = [x<1] )
667   //
668   // Say we have:
669   //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
670   //   conditions: c1
671   // where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = x>0.
672   // After adding e1 and e2, we have separation class { e1, e2 }. This is not a
673   // conflict since M(e1)=M(e2). We then add e3, obtaining separation class
674   // { e1, e2, e3 }, which is in conflict since M(e3)!=M(e1), and the condition
675   // c1 does not separate e3 and the representative of this class, e1. Hence we
676   // get a separation lemma of the form:
677   //  ~( e1 = e2 ^ e1 != e3 ^ c1 = [x>0] )
678   //
679   // It also may be the case that we exhaust the pool of condition enumerators.
680   // Say we have:
681   //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
682   //   conditions: c1
683   // where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = y>0. After adding e1, e2,
684   // and e3, we have a separation class { e1, e2, e3 } that is in conflict
685   // since M(e3)!=M(e1). We add the condition c1, which separates into new
686   // equivalence classes { e1 }, { e2, e3 }. We are still in separation conflict
687   // since M(e3)!=M(e2). However, we do not have any further conditions to use
688   // to resolve this conflict. Thus, we add the separation lemma:
689   //  ~( e1 = e2 ^ e1 != e3 ^ e2 != e3 ^ c1 = [y>0] ^ G_1 )
690   // where G_1 is a guard stating that we use at most 1 condition.
691   Node e;
692   Node er;
693   while (hd_counter < d_hds.size() || needs_sep_resolve)
694   {
695     if (!needs_sep_resolve)
696     {
697       // add the head to the trie
698       e = d_hds[hd_counter];
699       hd_mv[e] = d_unif->d_parent->getModelValue(e);
700       if (Trace.isOn("sygus-unif-sol"))
701       {
702         std::stringstream ss;
703         Printer::getPrinter(options::outputLanguage())
704             ->toStreamSygus(ss, hd_mv[e]);
705         Trace("sygus-unif-sol")
706             << "  add evaluation head (" << hd_counter << "/" << d_hds.size()
707             << "): " << e << " -> " << ss.str() << std::endl;
708       }
709       hd_counter++;
710       // get the representative of the trie
711       er = d_pt_sep.d_trie.add(e, &d_pt_sep, c_counter);
712       Trace("sygus-unif-sol") << "  ...separation class " << er << std::endl;
713       // are we in conflict?
714       if (er == e)
715       {
716         // new separation class, no conflict
717         continue;
718       }
719       Assert(hd_mv.find(er) != hd_mv.end());
720       if (hd_mv[er] == hd_mv[e])
721       {
722         // merged into separation class with same model value, no conflict
723         // add to explanation
724         // this states that it mattered that (er = e) at the time that e was
725         // added to the trie. Notice that er and e may become separated later,
726         // but to ensure the overall invariant, this equality must persist in
727         // the explanation.
728         exp.push_back(er.eqNode(e));
729         Trace("sygus-unif-sol") << "  ...equal model values " << std::endl;
730         Trace("sygus-unif-sol")
731             << "  ...add to explanation " << er.eqNode(e) << std::endl;
732         continue;
733       }
734     }
735     // must include in the explanation that we hit a conflict at this point in
736     // the construction
737     exp.push_back(e.eqNode(er).negate());
738     // we are in separation conflict, does the next condition resolve this?
739     // check whether we have have exhausted our condition pool. If so, we
740     // are in conflict and this conflict depends on the guard.
741     if (c_counter >= d_conds.size())
742     {
743       // truncated separation lemma
744       Assert(!d_guard.isNull());
745       exp.push_back(d_guard);
746       exp_conflict = true;
747       break;
748     }
749     Assert(c_counter < d_conds.size());
750     Node ce = d_enums[c_counter];
751     Node cv = d_conds[c_counter];
752     Assert(ce.getType() == cv.getType());
753     if (Trace.isOn("sygus-unif-sol"))
754     {
755       std::stringstream ss;
756       Printer::getPrinter(options::outputLanguage())->toStreamSygus(ss, cv);
757       Trace("sygus-unif-sol")
758           << "  add condition (" << c_counter << "/" << d_conds.size()
759           << "): " << ce << " -> " << ss.str() << std::endl;
760     }
761     // cache the separation class
762     std::vector<Node> prev_sep_c = d_pt_sep.d_trie.d_rep_to_class[er];
763     // add new classifier
764     d_pt_sep.d_trie.addClassifier(&d_pt_sep, c_counter);
765     c_counter++;
766     // add to explanation
767     // c_exp is a conjunction of testers applied to shared selector chains
768     Node c_exp = d_unif->d_tds->getExplain()->getExplanationForEquality(ce, cv);
769     exp.push_back(c_exp);
770     std::map<Node, std::vector<Node>>::iterator itr =
771         d_pt_sep.d_trie.d_rep_to_class.find(e);
772     // since e is last in its separation class, if it becomes a representative,
773     // then it is separated from all values in prev_sep_c
774     if (itr != d_pt_sep.d_trie.d_rep_to_class.end())
775     {
776       Trace("sygus-unif-sol")
777           << "  ...resolves separation conflict with all" << std::endl;
778       needs_sep_resolve = false;
779       continue;
780     }
781     itr = d_pt_sep.d_trie.d_rep_to_class.find(er);
782     // since er is first in its separation class, it remains a representative
783     Assert(itr != d_pt_sep.d_trie.d_rep_to_class.end());
784     // is e still in the separation class of er?
785     if (std::find(itr->second.begin(), itr->second.end(), e)
786         != itr->second.end())
787     {
788       Trace("sygus-unif-sol")
789           << "  ...does not resolve separation conflict with current"
790           << std::endl;
791       // the condition does not separate e and er
792       // this violates the invariant that the i^th conditional enumerator
793       // resolves the i^th separation conflict
794       exp_conflict = true;
795       break;
796     }
797     Trace("sygus-unif-sol")
798         << "  ...resolves separation conflict with current, but not all"
799         << std::endl;
800     // find the new term to resolve a separation
801     Node new_er = Node::null();
802     // scan the previous list and find the representative of the class that e is
803     // now in
804     for (unsigned i = 0, size = prev_sep_c.size(); i < size; i++)
805     {
806       Node check_er = prev_sep_c[i];
807       if (check_er != er && check_er != e)
808       {
809         itr = d_pt_sep.d_trie.d_rep_to_class.find(check_er);
810         if (itr != d_pt_sep.d_trie.d_rep_to_class.end())
811         {
812           if (std::find(itr->second.begin(), itr->second.end(), e)
813               != itr->second.end())
814           {
815             new_er = check_er;
816             break;
817           }
818         }
819       }
820     }
821     // should find exactly one
822     Assert(!new_er.isNull());
823     er = new_er;
824     needs_sep_resolve = true;
825   }
826   if (exp_conflict)
827   {
828     Node lemma = exp.size() == 1 ? exp[0] : nm->mkNode(AND, exp);
829     lemma = lemma.negate();
830     Trace("sygus-unif-sol") << "  ......conflict is " << lemma << std::endl;
831     lemmas.push_back(lemma);
832     return Node::null();
833   }
834 
835   Trace("sygus-unif-sol") << "...ready to build solution from DT\n";
836   return extractSol(cons, hd_mv);
837 }
838 
extractSol(Node cons,std::map<Node,Node> & hd_mv)839 Node SygusUnifRl::DecisionTreeInfo::extractSol(Node cons,
840                                                std::map<Node, Node>& hd_mv)
841 {
842   // rebuild decision tree using heuristic learning
843   if (options::sygusUnifBooleanHeuristicDt())
844   {
845     recomputeSolHeuristically(hd_mv);
846   }
847   return d_pt_sep.extractSol(cons, hd_mv);
848 }
849 
extractSol(Node cons,std::map<Node,Node> & hd_mv)850 Node SygusUnifRl::DecisionTreeInfo::PointSeparator::extractSol(
851     Node cons, std::map<Node, Node>& hd_mv)
852 {
853   // Traverse trie and build ITE with cons
854   NodeManager* nm = NodeManager::currentNM();
855   std::map<IndTriePair, Node> cache;
856   std::map<IndTriePair, Node>::iterator it;
857   std::vector<IndTriePair> visit;
858   unsigned index = 0;
859   LazyTrie* trie;
860   IndTriePair root = IndTriePair(0, &d_trie.d_trie);
861   visit.push_back(root);
862   while (!visit.empty())
863   {
864     index = visit.back().first;
865     trie = visit.back().second;
866     visit.pop_back();
867     IndTriePair cur = IndTriePair(index, trie);
868     it = cache.find(cur);
869     // traverse children so results are saved to build node for parent
870     if (it == cache.end())
871     {
872       // leaf
873       if (trie->d_children.empty())
874       {
875         Assert(hd_mv.find(trie->d_lazy_child) != hd_mv.end());
876         cache[cur] = hd_mv[trie->d_lazy_child];
877         Trace("sygus-unif-sol-debug") << "......leaf, build "
878                                       << d_dt->d_unif->d_tds->sygusToBuiltin(
879                                              cache[cur], cache[cur].getType())
880                                       << "\n";
881         continue;
882       }
883       cache[cur] = Node::null();
884       visit.push_back(cur);
885       for (std::pair<const Node, LazyTrie>& p_nt : trie->d_children)
886       {
887         visit.push_back(IndTriePair(index + 1, &p_nt.second));
888       }
889       continue;
890     }
891     // retrieve terms of children and build result
892     Assert(it->second.isNull());
893     Assert(trie->d_children.size() == 1 || trie->d_children.size() == 2);
894     std::vector<Node> children(4);
895     children[0] = cons;
896     children[1] = d_dt->d_conds[index];
897     unsigned i = 0;
898     for (std::pair<const Node, LazyTrie>& p_nt : trie->d_children)
899     {
900       i = p_nt.first.getConst<bool>() ? 2 : 3;
901       Assert(cache.find(IndTriePair(index + 1, &p_nt.second)) != cache.end());
902       children[i] = cache[IndTriePair(index + 1, &p_nt.second)];
903       Assert(!children[i].isNull());
904     }
905     // condition is useless or result children are equal, no no need for ITE
906     if (trie->d_children.size() == 1 || children[2] == children[3])
907     {
908       cache[cur] = children[i];
909       Trace("sygus-unif-sol-debug")
910           << "......no need for cond "
911           << d_dt->d_unif->d_tds->sygusToBuiltin(d_dt->d_conds[index],
912                                                  d_dt->d_conds[index].getType())
913           << ", build "
914           << d_dt->d_unif->d_tds->sygusToBuiltin(cache[cur],
915                                                  cache[cur].getType())
916           << "\n";
917       continue;
918     }
919     Assert(trie->d_children.size() == 2);
920     cache[cur] = nm->mkNode(APPLY_CONSTRUCTOR, children);
921     Trace("sygus-unif-sol-debug")
922         << "......build node "
923         << d_dt->d_unif->d_tds->sygusToBuiltin(cache[cur], cache[cur].getType())
924         << "\n";
925   }
926   Assert(cache.find(root) != cache.end());
927   Assert(!cache.find(root)->second.isNull());
928   return cache[root];
929 }
930 
recomputeSolHeuristically(std::map<Node,Node> & hd_mv)931 void SygusUnifRl::DecisionTreeInfo::recomputeSolHeuristically(
932     std::map<Node, Node>& hd_mv)
933 {
934   // reset the trie
935   d_pt_sep.d_trie.clear();
936   // TODO workaround and not really sure this is the last condition, since I put
937   // a set here. Maybe make d_cond_mvs into a vector
938   Node backup_last_cond = d_conds.back();
939   d_conds.clear();
940   for (const Node& e : d_hds)
941   {
942     d_pt_sep.d_trie.add(e, &d_pt_sep, 0);
943   }
944   // init vector of conds
945   std::vector<Node> conds;
946   conds.insert(conds.end(), d_cond_mvs.begin(), d_cond_mvs.end());
947 
948   // recursively build trie by picking best condition for respective points
949   buildDtInfoGain(d_hds, conds, hd_mv, 1);
950   // if no condition was added (i.e. points are already classified at root
951   // level), use last condition as candidate
952   if (d_conds.empty())
953   {
954     Trace("sygus-unif-dt") << "......using last condition "
955                            << d_unif->d_tds->sygusToBuiltin(
956                                   backup_last_cond, backup_last_cond.getType())
957                            << " as candidate\n";
958     d_conds.push_back(backup_last_cond);
959     d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1);
960   }
961 }
962 
buildDtInfoGain(std::vector<Node> & hds,std::vector<Node> conds,std::map<Node,Node> & hd_mv,int ind)963 void SygusUnifRl::DecisionTreeInfo::buildDtInfoGain(std::vector<Node>& hds,
964                                                     std::vector<Node> conds,
965                                                     std::map<Node, Node>& hd_mv,
966                                                     int ind)
967 {
968   // test if fully classified
969   if (hds.size() < 2)
970   {
971     indent("sygus-unif-dt", ind);
972     Trace("sygus-unif-dt") << "..set fully classified: "
973                            << (hds.empty() ? "empty" : "unary") << "\n";
974     return;
975   }
976   Node v1 = hd_mv[hds[0]];
977   unsigned i = 1, size = hds.size();
978   for (; i < size; ++i)
979   {
980     if (hd_mv[hds[i]] != v1)
981     {
982       break;
983     }
984   }
985   if (i == size)
986   {
987     indent("sygus-unif-dt", ind);
988     Trace("sygus-unif-dt") << "..set fully classified: " << hds.size() << " "
989                            << (d_unif->d_tds->sygusToBuiltin(v1, v1.getType())
990                                        == d_true
991                                    ? "good"
992                                    : "bad")
993                            << " points\n";
994     return;
995   }
996   // pick condition to further classify
997   double maxgain = -1;
998   unsigned picked_cond = 0;
999   std::vector<std::pair<std::vector<Node>, std::vector<Node>>> splits;
1000   double current_set_entropy = getEntropy(hds, hd_mv, ind);
1001   for (unsigned i = 0, size = conds.size(); i < size; ++i)
1002   {
1003     std::pair<std::vector<Node>, std::vector<Node>> split =
1004         evaluateCond(hds, conds[i]);
1005     splits.push_back(split);
1006     Assert(hds.size() == split.first.size() + split.second.size());
1007     double gain =
1008         current_set_entropy
1009         - (split.first.size() * getEntropy(split.first, hd_mv, ind)
1010            + split.second.size() * getEntropy(split.second, hd_mv, ind))
1011               / hds.size();
1012     indent("sygus-unif-dt-debug", ind);
1013     Trace("sygus-unif-dt-debug")
1014         << "..gain of "
1015         << d_unif->d_tds->sygusToBuiltin(conds[i], conds[i].getType()) << " is "
1016         << gain << "\n";
1017     if (gain > maxgain)
1018     {
1019       maxgain = gain;
1020       picked_cond = i;
1021     }
1022   }
1023   // add picked condition
1024   indent("sygus-unif-dt", ind);
1025   Trace("sygus-unif-dt") << "..picked condition "
1026                          << d_unif->d_tds->sygusToBuiltin(
1027                                 conds[picked_cond],
1028                                 conds[picked_cond].getType())
1029                          << "\n";
1030   d_conds.push_back(conds[picked_cond]);
1031   conds.erase(conds.begin() + picked_cond);
1032   d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1);
1033   // recurse
1034   buildDtInfoGain(splits[picked_cond].first, conds, hd_mv, ind + 1);
1035   buildDtInfoGain(splits[picked_cond].second, conds, hd_mv, ind + 1);
1036 }
1037 
1038 std::pair<std::vector<Node>, std::vector<Node>>
evaluateCond(std::vector<Node> & pts,Node cond)1039 SygusUnifRl::DecisionTreeInfo::evaluateCond(std::vector<Node>& pts, Node cond)
1040 {
1041   std::vector<Node> good, bad;
1042   for (const Node& pt : pts)
1043   {
1044     if (d_pt_sep.computeCond(cond, pt) == d_true)
1045     {
1046       good.push_back(pt);
1047       continue;
1048     }
1049     Assert(d_pt_sep.computeCond(cond, pt) == d_false);
1050     bad.push_back(pt);
1051   }
1052   return std::pair<std::vector<Node>, std::vector<Node>>(good, bad);
1053 }
1054 
getEntropy(const std::vector<Node> & hds,std::map<Node,Node> & hd_mv,int ind)1055 double SygusUnifRl::DecisionTreeInfo::getEntropy(const std::vector<Node>& hds,
1056                                                  std::map<Node, Node>& hd_mv,
1057                                                  int ind)
1058 {
1059   double p = 0, n = 0;
1060   TermDbSygus* tds = d_unif->d_tds;
1061   // get number of points evaluated positively and negatively with feature
1062   for (const Node& e : hds)
1063   {
1064     if (tds->sygusToBuiltin(hd_mv[e]) == d_true)
1065     {
1066       p++;
1067       continue;
1068     }
1069     Assert(tds->sygusToBuiltin(hd_mv[e]) == d_false);
1070     n++;
1071   }
1072   // compute entropy
1073   return p == 0 || n == 0 ? 0
1074                           : ((-p / (p + n)) * log2(p / (p + n)))
1075                                 - ((n / (p + n)) * log2(n / (p + n)));
1076 }
1077 
initialize(DecisionTreeInfo * dt)1078 void SygusUnifRl::DecisionTreeInfo::PointSeparator::initialize(
1079     DecisionTreeInfo* dt)
1080 {
1081   d_dt = dt;
1082 }
1083 
evaluate(Node n,unsigned index)1084 Node SygusUnifRl::DecisionTreeInfo::PointSeparator::evaluate(Node n,
1085                                                              unsigned index)
1086 {
1087   Assert(index < d_dt->d_conds.size());
1088   // Retrieve respective built_in condition
1089   Node cond = d_dt->d_conds[index];
1090   return computeCond(cond, n);
1091 }
1092 
computeCond(Node cond,Node hd)1093 Node SygusUnifRl::DecisionTreeInfo::PointSeparator::computeCond(Node cond,
1094                                                                 Node hd)
1095 {
1096   std::pair<Node, Node> cond_hd = std::pair<Node, Node>(cond, hd);
1097   std::map<std::pair<Node, Node>, Node>::iterator it =
1098       d_eval_cond_hd.find(cond_hd);
1099   if (it != d_eval_cond_hd.end())
1100   {
1101     return it->second;
1102   }
1103   TypeNode tn = cond.getType();
1104   Node builtin_cond = d_dt->d_unif->d_tds->sygusToBuiltin(cond, tn);
1105   // Retrieve evaluation point
1106   Assert(d_dt->d_unif->d_hd_to_pt.find(hd) != d_dt->d_unif->d_hd_to_pt.end());
1107   std::vector<Node> pt = d_dt->d_unif->d_hd_to_pt[hd];
1108   // compute the result
1109   if (Trace.isOn("sygus-unif-rl-sep"))
1110   {
1111     Trace("sygus-unif-rl-sep")
1112         << "Evaluate cond " << builtin_cond << " on pt " << hd << " ( ";
1113     for (const Node& pti : pt)
1114     {
1115       Trace("sygus-unif-rl-sep") << pti << " ";
1116     }
1117     Trace("sygus-unif-rl-sep") << ")\n";
1118   }
1119   Node res = d_dt->d_unif->d_tds->evaluateBuiltin(tn, builtin_cond, pt);
1120   Trace("sygus-unif-rl-sep") << "...got res = " << res << "\n";
1121   // If condition is templated, recompute result accordingly
1122   Node templ = d_dt->d_template.first;
1123   TNode templ_var = d_dt->d_template.second;
1124   if (!templ.isNull())
1125   {
1126     res = templ.substitute(templ_var, res);
1127     res = Rewriter::rewrite(res);
1128     Trace("sygus-unif-rl-sep")
1129         << "...after template res = " << res << std::endl;
1130   }
1131   Assert(res.isConst());
1132   d_eval_cond_hd[cond_hd] = res;
1133   return res;
1134 }
1135 
1136 }  // namespace quantifiers
1137 }  // namespace theory
1138 }  // namespace CVC4
1139