1 /*********************                                                        */
2 /*! \file cegis.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Andrew Reynolds, Haniel Barbosa
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** \brief Implementation of cegis
13  **/
14 
15 #include "theory/quantifiers/sygus/cegis.h"
16 #include "expr/node_algorithm.h"
17 #include "options/base_options.h"
18 #include "options/quantifiers_options.h"
19 #include "printer/printer.h"
20 #include "theory/quantifiers/sygus/synth_conjecture.h"
21 #include "theory/quantifiers/sygus/term_database_sygus.h"
22 #include "theory/theory_engine.h"
23 
24 using namespace std;
25 using namespace CVC4::kind;
26 using namespace CVC4::context;
27 
28 namespace CVC4 {
29 namespace theory {
30 namespace quantifiers {
31 
Cegis(QuantifiersEngine * qe,SynthConjecture * p)32 Cegis::Cegis(QuantifiersEngine* qe, SynthConjecture* p)
33     : SygusModule(qe, p), d_eval_unfold(nullptr), d_using_gr_repair(false)
34 {
35   if (options::sygusEvalUnfold())
36   {
37     d_eval_unfold = qe->getTermDatabaseSygus()->getEvalUnfold();
38   }
39 }
40 
initialize(Node n,const std::vector<Node> & candidates,std::vector<Node> & lemmas)41 bool Cegis::initialize(Node n,
42                        const std::vector<Node>& candidates,
43                        std::vector<Node>& lemmas)
44 {
45   d_base_body = n;
46   if (d_base_body.getKind() == NOT && d_base_body[0].getKind() == FORALL)
47   {
48     for (const Node& v : d_base_body[0][0])
49     {
50       d_base_vars.push_back(v);
51     }
52     d_base_body = d_base_body[0][1];
53   }
54 
55   // assign the cegis sampler if applicable
56   if (options::cegisSample() != CEGIS_SAMPLE_NONE)
57   {
58     Trace("cegis-sample") << "Initialize sampler for " << d_base_body << "..."
59                           << std::endl;
60     TypeNode bt = d_base_body.getType();
61     d_cegis_sampler.initialize(bt, d_base_vars, options::sygusSamples());
62   }
63   return processInitialize(n, candidates, lemmas);
64 }
65 
processInitialize(Node n,const std::vector<Node> & candidates,std::vector<Node> & lemmas)66 bool Cegis::processInitialize(Node n,
67                               const std::vector<Node>& candidates,
68                               std::vector<Node>& lemmas)
69 {
70   Trace("cegis") << "Initialize cegis..." << std::endl;
71   unsigned csize = candidates.size();
72   // The role of enumerators is to be either the single solution or part of
73   // a solution involving multiple enumerators.
74   EnumeratorRole erole =
75       csize == 1 ? ROLE_ENUM_SINGLE_SOLUTION : ROLE_ENUM_MULTI_SOLUTION;
76   // initialize an enumerator for each candidate
77   for (unsigned i = 0; i < csize; i++)
78   {
79     Trace("cegis") << "...register enumerator " << candidates[i];
80     bool do_repair_const = false;
81     if (options::sygusRepairConst())
82     {
83       TypeNode ctn = candidates[i].getType();
84       d_tds->registerSygusType(ctn);
85       if (d_tds->hasSubtermSymbolicCons(ctn))
86       {
87         do_repair_const = true;
88         // remember that we are doing grammar-based repair
89         d_using_gr_repair = true;
90         Trace("cegis") << " (using repair)";
91       }
92     }
93     Trace("cegis") << std::endl;
94     d_tds->registerEnumerator(
95         candidates[i], candidates[i], d_parent, erole, do_repair_const);
96   }
97   return true;
98 }
99 
getTermList(const std::vector<Node> & candidates,std::vector<Node> & enums)100 void Cegis::getTermList(const std::vector<Node>& candidates,
101                         std::vector<Node>& enums)
102 {
103   enums.insert(enums.end(), candidates.begin(), candidates.end());
104 }
105 
addEvalLemmas(const std::vector<Node> & candidates,const std::vector<Node> & candidate_values,std::vector<Node> & lems)106 bool Cegis::addEvalLemmas(const std::vector<Node>& candidates,
107                           const std::vector<Node>& candidate_values,
108                           std::vector<Node>& lems)
109 {
110   // First, decide if this call will apply "conjecture-specific refinement".
111   // In other words, in some settings, the following method will identify and
112   // block a class of solutions {candidates -> S} that generalizes the current
113   // one (given by {candidates -> candidate_values}), such that for each
114   // candidate_values' in S, we have that {candidates -> candidate_values'} is
115   // also not a solution for the given conjecture. We may not
116   // apply this form of refinement if any (relevant) enumerator in candidates is
117   // "actively generated" (see TermDbSygs::isPassiveEnumerator), since its
118   // model values are themselves interpreted as classes of solutions.
119   bool doGen = true;
120   for (const Node& v : candidates)
121   {
122     // if it is relevant to refinement
123     if (d_refinement_lemma_vars.find(v) != d_refinement_lemma_vars.end())
124     {
125       if (!d_tds->isPassiveEnumerator(v))
126       {
127         doGen = false;
128         break;
129       }
130     }
131   }
132   NodeManager* nm = NodeManager::currentNM();
133   bool addedEvalLemmas = false;
134   if (options::sygusRefEval())
135   {
136     Trace("cegqi-engine") << "  *** Do refinement lemma evaluation"
137                           << (doGen ? " with conjecture-specific refinement"
138                                     : "")
139                           << "..." << std::endl;
140     // see if any refinement lemma is refuted by evaluation
141     std::vector<Node> cre_lems;
142     bool ret =
143         getRefinementEvalLemmas(candidates, candidate_values, cre_lems, doGen);
144     if (ret && !doGen)
145     {
146       Trace("cegqi-engine") << "...(actively enumerated) candidate failed "
147                                "refinement lemma evaluation."
148                             << std::endl;
149       return true;
150     }
151     if (!cre_lems.empty())
152     {
153       lems.insert(lems.end(), cre_lems.begin(), cre_lems.end());
154       addedEvalLemmas = true;
155       if (Trace.isOn("cegqi-lemma"))
156       {
157         for (const Node& lem : cre_lems)
158         {
159           Trace("cegqi-lemma")
160               << "Cegqi::Lemma : ref evaluation : " << lem << std::endl;
161         }
162       }
163       /* we could, but do not return here. experimentally, it is better to
164          add the lemmas below as well, in parallel. */
165     }
166   }
167   // we only do evaluation unfolding for passive enumerators
168   if (doGen && d_eval_unfold != nullptr)
169   {
170     Trace("cegqi-engine") << "  *** Do evaluation unfolding..." << std::endl;
171     std::vector<Node> eager_terms, eager_vals, eager_exps;
172     for (unsigned i = 0, size = candidates.size(); i < size; ++i)
173     {
174       Trace("cegqi-debug") << "  register " << candidates[i] << " -> "
175                            << candidate_values[i] << std::endl;
176       d_eval_unfold->registerModelValue(candidates[i],
177                                         candidate_values[i],
178                                         eager_terms,
179                                         eager_vals,
180                                         eager_exps);
181     }
182     Trace("cegqi-debug") << "...produced " << eager_terms.size()
183                          << " evaluation unfold lemmas.\n";
184     for (unsigned i = 0, size = eager_terms.size(); i < size; ++i)
185     {
186       Node lem = nm->mkNode(
187           OR, eager_exps[i].negate(), eager_terms[i].eqNode(eager_vals[i]));
188       lems.push_back(lem);
189       addedEvalLemmas = true;
190       Trace("cegqi-lemma") << "Cegqi::Lemma : evaluation unfold : " << lem
191                            << std::endl;
192     }
193   }
194   return addedEvalLemmas;
195 }
196 
constructCandidates(const std::vector<Node> & enums,const std::vector<Node> & enum_values,const std::vector<Node> & candidates,std::vector<Node> & candidate_values,std::vector<Node> & lems)197 bool Cegis::constructCandidates(const std::vector<Node>& enums,
198                                 const std::vector<Node>& enum_values,
199                                 const std::vector<Node>& candidates,
200                                 std::vector<Node>& candidate_values,
201                                 std::vector<Node>& lems)
202 {
203   if (Trace.isOn("cegis"))
204   {
205     Trace("cegis") << "  Enumerators :\n";
206     for (unsigned i = 0, size = enums.size(); i < size; ++i)
207     {
208       Trace("cegis") << "    " << enums[i] << " -> ";
209       TermDbSygus::toStreamSygus("cegis", enum_values[i]);
210       Trace("cegis") << "\n";
211     }
212   }
213   // if we are using grammar-based repair
214   if (d_using_gr_repair)
215   {
216     SygusRepairConst* src = d_parent->getRepairConst();
217     Assert(src != nullptr);
218     // check if any enum_values have symbolic terms that must be repaired
219     bool mustRepair = false;
220     for (const Node& c : enum_values)
221     {
222       if (SygusRepairConst::mustRepair(c))
223       {
224         mustRepair = true;
225         break;
226       }
227     }
228     Trace("cegis") << "...must repair is: " << mustRepair << std::endl;
229     // if the solution contains a subterm that must be repaired
230     if (mustRepair)
231     {
232       std::vector<Node> fail_cvs = enum_values;
233       Assert(candidates.size() == fail_cvs.size());
234       if (src->repairSolution(candidates, fail_cvs, candidate_values))
235       {
236         return true;
237       }
238       // repair solution didn't work, exclude this solution
239       std::vector<Node> exp;
240       for (unsigned i = 0, size = enums.size(); i < size; i++)
241       {
242         d_tds->getExplain()->getExplanationForEquality(
243             enums[i], enum_values[i], exp);
244       }
245       Assert(!exp.empty());
246       Node expn =
247           exp.size() == 1 ? exp[0] : NodeManager::currentNM()->mkNode(AND, exp);
248       lems.push_back(expn.negate());
249       return false;
250     }
251   }
252 
253   // evaluate on refinement lemmas
254   bool addedEvalLemmas = addEvalLemmas(enums, enum_values, lems);
255 
256   // try to construct candidates
257   if (!processConstructCandidates(enums,
258                                   enum_values,
259                                   candidates,
260                                   candidate_values,
261                                   !addedEvalLemmas,
262                                   lems))
263   {
264     return false;
265   }
266 
267   if (options::cegisSample() != CEGIS_SAMPLE_NONE && lems.empty())
268   {
269     // if we didn't add a lemma, trying sampling to add a refinement lemma
270     // that immediately refutes the candidate we just constructed
271     if (sampleAddRefinementLemma(enums, enum_values, lems))
272     {
273       // restart (should be guaranteed to add evaluation lemmas on this call)
274       return constructCandidates(
275           enums, enum_values, candidates, candidate_values, lems);
276     }
277   }
278   return true;
279 }
280 
processConstructCandidates(const std::vector<Node> & enums,const std::vector<Node> & enum_values,const std::vector<Node> & candidates,std::vector<Node> & candidate_values,bool satisfiedRl,std::vector<Node> & lems)281 bool Cegis::processConstructCandidates(const std::vector<Node>& enums,
282                                        const std::vector<Node>& enum_values,
283                                        const std::vector<Node>& candidates,
284                                        std::vector<Node>& candidate_values,
285                                        bool satisfiedRl,
286                                        std::vector<Node>& lems)
287 {
288   if (satisfiedRl)
289   {
290     candidate_values.insert(
291         candidate_values.end(), enum_values.begin(), enum_values.end());
292     return true;
293   }
294   return false;
295 }
296 
addRefinementLemma(Node lem)297 void Cegis::addRefinementLemma(Node lem)
298 {
299   d_refinement_lemmas.push_back(lem);
300   // apply existing substitution
301   Node slem = lem;
302   if (!d_rl_eval_hds.empty())
303   {
304     slem = lem.substitute(d_rl_eval_hds.begin(),
305                           d_rl_eval_hds.end(),
306                           d_rl_vals.begin(),
307                           d_rl_vals.end());
308   }
309   // rewrite with extended rewriter
310   slem = d_tds->getExtRewriter()->extendedRewrite(slem);
311   // collect all variables in slem
312   expr::getSymbols(slem, d_refinement_lemma_vars);
313   std::vector<Node> waiting;
314   waiting.push_back(lem);
315   unsigned wcounter = 0;
316   // while we are not done adding lemmas
317   while (wcounter < waiting.size())
318   {
319     // add the conjunct, possibly propagating
320     addRefinementLemmaConjunct(wcounter, waiting);
321     wcounter++;
322   }
323 }
324 
// Process the refinement-lemma conjunct waiting[wcounter].
//
// The conjunct is rewritten and then handled as follows:
//  - the constant true: dropped;
//  - an AND: its children are appended to waiting, to be processed later by
//    the caller's worklist loop;
//  - a "unit" conjunct, i.e. an equality between a constant and an
//    evaluation point, or a (possibly negated) evaluation point: recorded as
//    the substitution term -> val (d_rl_eval_hds / d_rl_vals). The
//    substitution is applied to the conjuncts still waiting after this one
//    and to every stored non-unit conjunct. Non-unit conjuncts that change
//    are removed from d_refinement_lemma_conj and re-queued in waiting;
//  - anything else is stored in d_refinement_lemma_conj. This presumably
//    includes the constant false, assuming it is not an evaluation point.
void Cegis::addRefinementLemmaConjunct(unsigned wcounter,
                                       std::vector<Node>& waiting)
{
  Node lem = waiting[wcounter];
  lem = Rewriter::rewrite(lem);
  // apply substitution and rewrite if applicable
  if (lem.isConst())
  {
    if (!lem.getConst<bool>())
    {
      // conjecture is infeasible
      // (no early return: false continues to the generic handling below)
    }
    else
    {
      // a trivially true conjunct carries no information
      return;
    }
  }
  // break into conjunctions
  if (lem.getKind() == AND)
  {
    for (const Node& lc : lem)
    {
      waiting.push_back(lc);
    }
    return;
  }
  // does this correspond to a substitution?
  NodeManager* nm = NodeManager::currentNM();
  TNode term;
  TNode val;
  if (lem.getKind() == EQUAL)
  {
    // (= c e) or (= e c) where c is a constant and e an evaluation point
    for (unsigned i = 0; i < 2; i++)
    {
      if (lem[i].isConst() && d_tds->isEvaluationPoint(lem[1 - i]))
      {
        term = lem[1 - i];
        val = lem[i];
        break;
      }
    }
  }
  else
  {
    term = lem.getKind() == NOT ? lem[0] : lem;
    // predicate case: the conjunct is a (negated) evaluation point
    if (d_tds->isEvaluationPoint(term))
    {
      val = nm->mkConst(lem.getKind() != NOT);
    }
  }
  if (!val.isNull())
  {
    if (d_refinement_lemma_unit.find(lem) != d_refinement_lemma_unit.end())
    {
      // already added
      return;
    }
    Trace("cegis-rl") << "* cegis-rl: propagate: " << term << " -> " << val
                      << std::endl;
    d_rl_eval_hds.push_back(term);
    d_rl_vals.push_back(val);
    d_refinement_lemma_unit.insert(lem);
    // apply to waiting lemmas beyond this one
    for (unsigned i = wcounter + 1, size = waiting.size(); i < size; i++)
    {
      waiting[i] = waiting[i].substitute(term, val);
    }
    // apply to all existing refinement lemmas
    // (changed ones are collected first and erased after the loop, so the
    // set is not modified while it is being iterated)
    std::vector<Node> to_rem;
    for (const Node& rl : d_refinement_lemma_conj)
    {
      Node srl = rl.substitute(term, val);
      if (srl != rl)
      {
        Trace("cegis-rl") << "* cegis-rl: replace: " << rl << " -> " << srl
                          << std::endl;
        waiting.push_back(srl);
        to_rem.push_back(rl);
      }
    }
    for (const Node& tr : to_rem)
    {
      d_refinement_lemma_conj.erase(tr);
    }
  }
  else
  {
    if (Trace.isOn("cegis-rl"))
    {
      if (d_refinement_lemma_conj.find(lem) == d_refinement_lemma_conj.end())
      {
        Trace("cegis-rl") << "cegis-rl: add: " << lem << std::endl;
      }
    }
    d_refinement_lemma_conj.insert(lem);
  }
}
423 
registerRefinementLemma(const std::vector<Node> & vars,Node lem,std::vector<Node> & lems)424 void Cegis::registerRefinementLemma(const std::vector<Node>& vars,
425                                     Node lem,
426                                     std::vector<Node>& lems)
427 {
428   addRefinementLemma(lem);
429   // Make the refinement lemma and add it to lems.
430   // This lemma is guarded by the parent's guard, which has the semantics
431   // "this conjecture has a solution", hence this lemma states:
432   // if the parent conjecture has a solution, it satisfies the specification
433   // for the given concrete point.
434   Node rlem =
435       NodeManager::currentNM()->mkNode(OR, d_parent->getGuard().negate(), lem);
436   lems.push_back(rlem);
437 }
438 
usingRepairConst()439 bool Cegis::usingRepairConst() { return true; }
getRefinementEvalLemmas(const std::vector<Node> & vs,const std::vector<Node> & ms,std::vector<Node> & lems,bool doGen)440 bool Cegis::getRefinementEvalLemmas(const std::vector<Node>& vs,
441                                     const std::vector<Node>& ms,
442                                     std::vector<Node>& lems,
443                                     bool doGen)
444 {
445   Trace("sygus-cref-eval") << "Cref eval : conjecture has "
446                            << d_refinement_lemma_unit.size() << " unit and "
447                            << d_refinement_lemma_conj.size()
448                            << " non-unit refinement lemma conjunctions."
449                            << std::endl;
450   Assert(vs.size() == ms.size());
451 
452   NodeManager* nm = NodeManager::currentNM();
453 
454   Node nfalse = nm->mkConst(false);
455   Node neg_guard = d_parent->getGuard().negate();
456   bool ret = false;
457   for (unsigned r = 0; r < 2; r++)
458   {
459     std::unordered_set<Node, NodeHashFunction>& rlemmas =
460         r == 0 ? d_refinement_lemma_unit : d_refinement_lemma_conj;
461     for (const Node& lem : rlemmas)
462     {
463       Assert(!lem.isNull());
464       std::map<Node, Node> visited;
465       std::map<Node, std::vector<Node> > exp;
466       EvalSygusInvarianceTest vsit;
467       Trace("sygus-cref-eval") << "Check refinement lemma conjunct " << lem
468                                << " against current model." << std::endl;
469       Trace("sygus-cref-eval2") << "Check refinement lemma conjunct " << lem
470                                 << " against current model." << std::endl;
471       Node cre_lem;
472       Node lemcs = lem.substitute(vs.begin(), vs.end(), ms.begin(), ms.end());
473       Trace("sygus-cref-eval2")
474           << "...under substitution it is : " << lemcs << std::endl;
475       Node lemcsu = vsit.doEvaluateWithUnfolding(d_tds, lemcs);
476       Trace("sygus-cref-eval2")
477           << "...after unfolding is : " << lemcsu << std::endl;
478       if (lemcsu.isConst() && !lemcsu.getConst<bool>())
479       {
480         if (!doGen)
481         {
482           // we are not generating the lemmas, instead just return
483           return true;
484         }
485         ret = true;
486         std::vector<Node> msu;
487         std::vector<Node> mexp;
488         msu.insert(msu.end(), ms.begin(), ms.end());
489         std::map<TypeNode, int> var_count;
490         for (unsigned k = 0; k < vs.size(); k++)
491         {
492           vsit.setUpdatedTerm(msu[k]);
493           msu[k] = vs[k];
494           // substitute for everything except this
495           Node sconj =
496               lem.substitute(vs.begin(), vs.end(), msu.begin(), msu.end());
497           vsit.init(sconj, vs[k], nfalse);
498           // get minimal explanation for this
499           Node ut = vsit.getUpdatedTerm();
500           Trace("sygus-cref-eval2-debug")
501               << "  compute min explain of : " << vs[k] << " = " << ut
502               << std::endl;
503           d_tds->getExplain()->getExplanationFor(
504               vs[k], ut, mexp, vsit, var_count, false);
505           Trace("sygus-cref-eval2-debug") << "exp now: " << mexp << std::endl;
506           msu[k] = vsit.getUpdatedTerm();
507           Trace("sygus-cref-eval2-debug")
508               << "updated term : " << msu[k] << std::endl;
509         }
510         if (!mexp.empty())
511         {
512           Node en = mexp.size() == 1 ? mexp[0] : nm->mkNode(kind::AND, mexp);
513           cre_lem = nm->mkNode(kind::OR, en.negate(), neg_guard);
514         }
515         else
516         {
517           cre_lem = neg_guard;
518         }
519         if (std::find(lems.begin(), lems.end(), cre_lem) == lems.end())
520         {
521           Trace("sygus-cref-eval") << "...produced lemma : " << cre_lem
522                                    << std::endl;
523           lems.push_back(cre_lem);
524         }
525       }
526     }
527     if (!lems.empty())
528     {
529       break;
530     }
531   }
532   return ret;
533 }
534 
sampleAddRefinementLemma(const std::vector<Node> & candidates,const std::vector<Node> & vals,std::vector<Node> & lems)535 bool Cegis::sampleAddRefinementLemma(const std::vector<Node>& candidates,
536                                      const std::vector<Node>& vals,
537                                      std::vector<Node>& lems)
538 {
539   Trace("cegqi-engine") << "  *** Do sample add refinement..." << std::endl;
540   if (Trace.isOn("cegis-sample"))
541   {
542     Trace("cegis-sample") << "Check sampling for candidate solution"
543                           << std::endl;
544     for (unsigned i = 0, size = vals.size(); i < size; i++)
545     {
546       Trace("cegis-sample") << "  " << candidates[i] << " -> " << vals[i]
547                             << std::endl;
548     }
549   }
550   Assert(vals.size() == candidates.size());
551   Node sbody = d_base_body.substitute(
552       candidates.begin(), candidates.end(), vals.begin(), vals.end());
553   Trace("cegis-sample-debug2") << "Sample " << sbody << std::endl;
554   // do eager unfolding
555   std::map<Node, Node> visited_n;
556   sbody = d_qe->getTermDatabaseSygus()->getEagerUnfold(sbody, visited_n);
557   Trace("cegis-sample") << "Sample (after unfolding): " << sbody << std::endl;
558 
559   NodeManager* nm = NodeManager::currentNM();
560   for (unsigned i = 0, size = d_cegis_sampler.getNumSamplePoints(); i < size;
561        i++)
562   {
563     if (d_cegis_sample_refine.find(i) == d_cegis_sample_refine.end())
564     {
565       Node ev = d_cegis_sampler.evaluate(sbody, i);
566       Trace("cegis-sample-debug") << "...evaluate point #" << i << " to " << ev
567                                   << std::endl;
568       Assert(ev.isConst());
569       Assert(ev.getType().isBoolean());
570       if (!ev.getConst<bool>())
571       {
572         Trace("cegis-sample-debug") << "...false for point #" << i << std::endl;
573         // mark this as a CEGIS point (no longer sampled)
574         d_cegis_sample_refine.insert(i);
575         std::vector<Node> pt;
576         d_cegis_sampler.getSamplePoint(i, pt);
577         Assert(d_base_vars.size() == pt.size());
578         Node rlem = d_base_body.substitute(
579             d_base_vars.begin(), d_base_vars.end(), pt.begin(), pt.end());
580         rlem = Rewriter::rewrite(rlem);
581         if (std::find(
582                 d_refinement_lemmas.begin(), d_refinement_lemmas.end(), rlem)
583             == d_refinement_lemmas.end())
584         {
585           if (Trace.isOn("cegis-sample"))
586           {
587             Trace("cegis-sample") << "   false for point #" << i << " : ";
588             for (const Node& cn : pt)
589             {
590               Trace("cegis-sample") << cn << " ";
591             }
592             Trace("cegis-sample") << std::endl;
593           }
594           Trace("cegqi-engine") << "  *** Refine by sampling" << std::endl;
595           addRefinementLemma(rlem);
596           // if trust, we are not interested in sending out refinement lemmas
597           if (options::cegisSample() != CEGIS_SAMPLE_TRUST)
598           {
599             Node lem = nm->mkNode(OR, d_parent->getGuard().negate(), rlem);
600             lems.push_back(lem);
601           }
602           return true;
603         }
604         else
605         {
606           Trace("cegis-sample-debug") << "...duplicate." << std::endl;
607         }
608       }
609     }
610   }
611   return false;
612 }
613 
614 } /* CVC4::theory::quantifiers namespace */
615 } /* CVC4::theory namespace */
616 } /* CVC4 namespace */
617