1 /*********************                                                        */
2 /*! \file ext_theory.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Andrew Reynolds, Tim King, Morgan Deters
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** \brief Extended theory interface.
13  **
14  ** This implements a generic module, used by theory solvers, for performing
15  ** "context-dependent simplification", as described in Reynolds et al
16  ** "Designing Theory Solvers with Extensions", FroCoS 2017.
17  **/
18 
19 #include "theory/ext_theory.h"
20 
21 #include "base/cvc4_assert.h"
22 #include "smt/smt_statistics_registry.h"
23 #include "theory/quantifiers_engine.h"
24 #include "theory/substitutions.h"
25 
26 using namespace std;
27 
28 namespace CVC4 {
29 namespace theory {
30 
ExtTheory(Theory * p,bool cacheEnabled)31 ExtTheory::ExtTheory(Theory* p, bool cacheEnabled)
32     : d_parent(p),
33       d_ext_func_terms(p->getSatContext()),
34       d_ci_inactive(p->getUserContext()),
35       d_has_extf(p->getSatContext()),
36       d_lemmas(p->getUserContext()),
37       d_pp_lemmas(p->getUserContext()),
38       d_cacheEnabled(cacheEnabled)
39 {
40   d_true = NodeManager::currentNM()->mkConst(true);
41 }
42 
43 // Gets all leaf terms in n.
collectVars(Node n)44 std::vector<Node> ExtTheory::collectVars(Node n)
45 {
46   std::vector<Node> vars;
47   std::set<Node> visited;
48   std::vector<Node> worklist;
49   worklist.push_back(n);
50   while (!worklist.empty())
51   {
52     Node current = worklist.back();
53     worklist.pop_back();
54     if (current.isConst() || visited.count(current) > 0)
55     {
56       continue;
57     }
58     visited.insert(current);
59     // Treat terms not belonging to this theory as leaf
60     // note : chould include terms not belonging to this theory
61     // (commented below)
62     if (current.getNumChildren() > 0)
63     {
64       //&& Theory::theoryOf(n)==d_parent->getId() ){
65       worklist.insert(worklist.end(), current.begin(), current.end());
66     }
67     else
68     {
69       vars.push_back(current);
70     }
71   }
72   return vars;
73 }
74 
getSubstitutedTerm(int effort,Node term,std::vector<Node> & exp,bool useCache)75 Node ExtTheory::getSubstitutedTerm(int effort,
76                                    Node term,
77                                    std::vector<Node>& exp,
78                                    bool useCache)
79 {
80   if (useCache)
81   {
82     Assert(d_gst_cache[effort].find(term) != d_gst_cache[effort].end());
83     exp.insert(exp.end(),
84                d_gst_cache[effort][term].d_exp.begin(),
85                d_gst_cache[effort][term].d_exp.end());
86     return d_gst_cache[effort][term].d_sterm;
87   }
88 
89   std::vector<Node> terms;
90   terms.push_back(term);
91   std::vector<Node> sterms;
92   std::vector<std::vector<Node> > exps;
93   getSubstitutedTerms(effort, terms, sterms, exps, useCache);
94   Assert(sterms.size() == 1);
95   Assert(exps.size() == 1);
96   exp.insert(exp.end(), exps[0].begin(), exps[0].end());
97   return sterms[0];
98 }
99 
100 // do inferences
getSubstitutedTerms(int effort,const std::vector<Node> & terms,std::vector<Node> & sterms,std::vector<std::vector<Node>> & exp,bool useCache)101 void ExtTheory::getSubstitutedTerms(int effort,
102                                     const std::vector<Node>& terms,
103                                     std::vector<Node>& sterms,
104                                     std::vector<std::vector<Node> >& exp,
105                                     bool useCache)
106 {
107   if (useCache)
108   {
109     for (const Node& n : terms)
110     {
111       Assert(d_gst_cache[effort].find(n) != d_gst_cache[effort].end());
112       sterms.push_back(d_gst_cache[effort][n].d_sterm);
113       exp.push_back(std::vector<Node>());
114       exp[0].insert(exp[0].end(),
115                     d_gst_cache[effort][n].d_exp.begin(),
116                     d_gst_cache[effort][n].d_exp.end());
117     }
118   }
119   else
120   {
121     Trace("extt-debug") << "getSubstitutedTerms for " << terms.size() << " / "
122                         << d_ext_func_terms.size() << " extended functions."
123                         << std::endl;
124     if (!terms.empty())
125     {
126       // all variables we need to find a substitution for
127       std::vector<Node> vars;
128       std::vector<Node> sub;
129       std::map<Node, std::vector<Node> > expc;
130       for (const Node& n : terms)
131       {
132         // do substitution, rewrite
133         std::map<Node, ExtfInfo>::iterator iti = d_extf_info.find(n);
134         Assert(iti != d_extf_info.end());
135         for (const Node& v : iti->second.d_vars)
136         {
137           if (std::find(vars.begin(), vars.end(), v) == vars.end())
138           {
139             vars.push_back(v);
140           }
141         }
142       }
143       bool useSubs = d_parent->getCurrentSubstitution(effort, vars, sub, expc);
144       // get the current substitution for all variables
145       Assert(!useSubs || vars.size() == sub.size());
146       for (const Node& n : terms)
147       {
148         Node ns = n;
149         std::vector<Node> expn;
150         if (useSubs)
151         {
152           // do substitution
153           ns = n.substitute(vars.begin(), vars.end(), sub.begin(), sub.end());
154           if (ns != n)
155           {
156             // build explanation: explanation vars = sub for each vars in FV(n)
157             std::map<Node, ExtfInfo>::iterator iti = d_extf_info.find(n);
158             Assert(iti != d_extf_info.end());
159             for (const Node& v : iti->second.d_vars)
160             {
161               std::map<Node, std::vector<Node> >::iterator itx = expc.find(v);
162               if (itx != expc.end())
163               {
164                 for (const Node& e : itx->second)
165                 {
166                   if (std::find(expn.begin(), expn.end(), e) == expn.end())
167                   {
168                     expn.push_back(e);
169                   }
170                 }
171               }
172             }
173           }
174           Trace("extt-debug")
175               << "  have " << n << " == " << ns << ", exp size=" << expn.size()
176               << "." << std::endl;
177         }
178         // add to vector
179         sterms.push_back(ns);
180         exp.push_back(expn);
181         // add to cache
182         if (d_cacheEnabled)
183         {
184           d_gst_cache[effort][n].d_sterm = ns;
185           d_gst_cache[effort][n].d_exp.clear();
186           d_gst_cache[effort][n].d_exp.insert(
187               d_gst_cache[effort][n].d_exp.end(), expn.begin(), expn.end());
188         }
189       }
190     }
191   }
192 }
193 
doInferencesInternal(int effort,const std::vector<Node> & terms,std::vector<Node> & nred,bool batch,bool isRed)194 bool ExtTheory::doInferencesInternal(int effort,
195                                      const std::vector<Node>& terms,
196                                      std::vector<Node>& nred,
197                                      bool batch,
198                                      bool isRed)
199 {
200   if (batch)
201   {
202     bool addedLemma = false;
203     if (isRed)
204     {
205       for (const Node& n : terms)
206       {
207         Node nr;
208         // note: could do reduction with substitution here
209         int ret = d_parent->getReduction(effort, n, nr);
210         if (ret == 0)
211         {
212           nred.push_back(n);
213         }
214         else
215         {
216           if (!nr.isNull() && n != nr)
217           {
218             Node lem = NodeManager::currentNM()->mkNode(kind::EQUAL, n, nr);
219             if (sendLemma(lem, true))
220             {
221               Trace("extt-lemma")
222                   << "ExtTheory : reduction lemma : " << lem << std::endl;
223               addedLemma = true;
224             }
225           }
226           markReduced(n, ret < 0);
227         }
228       }
229     }
230     else
231     {
232       std::vector<Node> sterms;
233       std::vector<std::vector<Node> > exp;
234       getSubstitutedTerms(effort, terms, sterms, exp);
235       std::map<Node, unsigned> sterm_index;
236       NodeManager* nm = NodeManager::currentNM();
237       for (unsigned i = 0, size = terms.size(); i < size; i++)
238       {
239         bool processed = false;
240         if (sterms[i] != terms[i])
241         {
242           Node sr = Rewriter::rewrite(sterms[i]);
243           // ask the theory if this term is reduced, e.g. is it constant or it
244           // is a non-extf term.
245           if (d_parent->isExtfReduced(effort, sr, terms[i], exp[i]))
246           {
247             processed = true;
248             markReduced(terms[i]);
249             // We have exp[i] => terms[i] = sr, convert this to a clause.
250             // This ensures the proof infrastructure can process this as a
251             // normal theory lemma.
252             Node eq = terms[i].eqNode(sr);
253             Node lem = eq;
254             if (!exp[i].empty())
255             {
256               std::vector<Node> eei;
257               for (const Node& e : exp[i])
258               {
259                 eei.push_back(e.negate());
260               }
261               eei.push_back(eq);
262               lem = nm->mkNode(kind::OR, eei);
263             }
264 
265             Trace("extt-debug") << "ExtTheory::doInferences : infer : " << eq
266                                 << " by " << exp[i] << std::endl;
267             Trace("extt-debug") << "...send lemma " << lem << std::endl;
268             if (sendLemma(lem))
269             {
270               Trace("extt-lemma")
271                   << "ExtTheory : substitution + rewrite lemma : " << lem
272                   << std::endl;
273               addedLemma = true;
274             }
275           }
276           else
277           {
278             // check if we have already reduced this
279             std::map<Node, unsigned>::iterator itsi = sterm_index.find(sr);
280             if (itsi == sterm_index.end())
281             {
282               sterm_index[sr] = i;
283             }
284             else
285             {
286               // unsigned j = itsi->second;
287               // note : can add (non-reducing) lemma :
288               //   exp[j] ^ exp[i] => sterms[i] = sterms[j]
289             }
290 
291             Trace("extt-nred") << "Non-reduced term : " << sr << std::endl;
292           }
293         }
294         else
295         {
296           Trace("extt-nred") << "Non-reduced term : " << sterms[i] << std::endl;
297         }
298         if (!processed)
299         {
300           nred.push_back(terms[i]);
301         }
302       }
303     }
304     return addedLemma;
305   }
306   // non-batch
307   std::vector<Node> nnred;
308   if (terms.empty())
309   {
310     for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
311          it != d_ext_func_terms.end();
312          ++it)
313     {
314       if ((*it).second && !isContextIndependentInactive((*it).first))
315       {
316         std::vector<Node> nterms;
317         nterms.push_back((*it).first);
318         if (doInferencesInternal(effort, nterms, nnred, true, isRed))
319         {
320           return true;
321         }
322       }
323     }
324   }
325   else
326   {
327     for (const Node& n : terms)
328     {
329       std::vector<Node> nterms;
330       nterms.push_back(n);
331       if (doInferencesInternal(effort, nterms, nnred, true, isRed))
332       {
333         return true;
334       }
335     }
336   }
337   return false;
338 }
339 
sendLemma(Node lem,bool preprocess)340 bool ExtTheory::sendLemma(Node lem, bool preprocess)
341 {
342   if (preprocess)
343   {
344     if (d_pp_lemmas.find(lem) == d_pp_lemmas.end())
345     {
346       d_pp_lemmas.insert(lem);
347       d_parent->getOutputChannel().lemma(lem, false, true);
348       return true;
349     }
350   }
351   else
352   {
353     if (d_lemmas.find(lem) == d_lemmas.end())
354     {
355       d_lemmas.insert(lem);
356       d_parent->getOutputChannel().lemma(lem);
357       return true;
358     }
359   }
360   return false;
361 }
362 
doInferences(int effort,const std::vector<Node> & terms,std::vector<Node> & nred,bool batch)363 bool ExtTheory::doInferences(int effort,
364                              const std::vector<Node>& terms,
365                              std::vector<Node>& nred,
366                              bool batch)
367 {
368   if (!terms.empty())
369   {
370     return doInferencesInternal(effort, terms, nred, batch, false);
371   }
372   return false;
373 }
374 
doInferences(int effort,std::vector<Node> & nred,bool batch)375 bool ExtTheory::doInferences(int effort, std::vector<Node>& nred, bool batch)
376 {
377   std::vector<Node> terms = getActive();
378   return doInferencesInternal(effort, terms, nred, batch, false);
379 }
380 
doReductions(int effort,const std::vector<Node> & terms,std::vector<Node> & nred,bool batch)381 bool ExtTheory::doReductions(int effort,
382                              const std::vector<Node>& terms,
383                              std::vector<Node>& nred,
384                              bool batch)
385 {
386   if (!terms.empty())
387   {
388     return doInferencesInternal(effort, terms, nred, batch, true);
389   }
390   return false;
391 }
392 
doReductions(int effort,std::vector<Node> & nred,bool batch)393 bool ExtTheory::doReductions(int effort, std::vector<Node>& nred, bool batch)
394 {
395   const std::vector<Node> terms = getActive();
396   return doInferencesInternal(effort, terms, nred, batch, true);
397 }
398 
399 // Register term.
registerTerm(Node n)400 void ExtTheory::registerTerm(Node n)
401 {
402   if (d_extf_kind.find(n.getKind()) != d_extf_kind.end())
403   {
404     if (d_ext_func_terms.find(n) == d_ext_func_terms.end())
405     {
406       Trace("extt-debug") << "Found extended function : " << n << " in "
407                           << d_parent->getId() << std::endl;
408       d_ext_func_terms[n] = true;
409       d_has_extf = n;
410       d_extf_info[n].d_vars = collectVars(n);
411     }
412   }
413 }
414 
registerTermRec(Node n)415 void ExtTheory::registerTermRec(Node n)
416 {
417   std::unordered_set<TNode, TNodeHashFunction> visited;
418   std::vector<TNode> visit;
419   TNode cur;
420   visit.push_back(n);
421   do
422   {
423     cur = visit.back();
424     visit.pop_back();
425     if (visited.find(cur) == visited.end())
426     {
427       visited.insert(cur);
428       registerTerm(cur);
429       for (const Node& cc : cur)
430       {
431         visit.push_back(cc);
432       }
433     }
434   } while (!visit.empty());
435 }
436 
437 // mark reduced
markReduced(Node n,bool contextDepend)438 void ExtTheory::markReduced(Node n, bool contextDepend)
439 {
440   Trace("extt-debug") << "Mark reduced " << n << std::endl;
441   registerTerm(n);
442   Assert(d_ext_func_terms.find(n) != d_ext_func_terms.end());
443   d_ext_func_terms[n] = false;
444   if (!contextDepend)
445   {
446     d_ci_inactive.insert(n);
447   }
448 
449   // update has_extf
450   if (d_has_extf.get() == n)
451   {
452     for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
453          it != d_ext_func_terms.end();
454          ++it)
455     {
456       // if not already reduced
457       if ((*it).second && !isContextIndependentInactive((*it).first))
458       {
459         d_has_extf = (*it).first;
460       }
461     }
462   }
463 }
464 
465 // mark congruent
markCongruent(Node a,Node b)466 void ExtTheory::markCongruent(Node a, Node b)
467 {
468   Trace("extt-debug") << "Mark congruent : " << a << " " << b << std::endl;
469   registerTerm(a);
470   registerTerm(b);
471   NodeBoolMap::const_iterator it = d_ext_func_terms.find(b);
472   if (it != d_ext_func_terms.end())
473   {
474     if (d_ext_func_terms.find(a) != d_ext_func_terms.end())
475     {
476       d_ext_func_terms[a] = d_ext_func_terms[a] && (*it).second;
477     }
478     else
479     {
480       Assert(false);
481     }
482     d_ext_func_terms[b] = false;
483   }
484   else
485   {
486     Assert(false);
487   }
488 }
489 
isContextIndependentInactive(Node n) const490 bool ExtTheory::isContextIndependentInactive(Node n) const
491 {
492   return d_ci_inactive.find(n) != d_ci_inactive.end();
493 }
494 
getTerms(std::vector<Node> & terms)495 void ExtTheory::getTerms(std::vector<Node>& terms)
496 {
497   for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
498        it != d_ext_func_terms.end();
499        ++it)
500   {
501     terms.push_back((*it).first);
502   }
503 }
504 
hasActiveTerm() const505 bool ExtTheory::hasActiveTerm() const { return !d_has_extf.get().isNull(); }
506 
507 // is active
isActive(Node n) const508 bool ExtTheory::isActive(Node n) const
509 {
510   NodeBoolMap::const_iterator it = d_ext_func_terms.find(n);
511   if (it != d_ext_func_terms.end())
512   {
513     return (*it).second && !isContextIndependentInactive(n);
514   }
515   return false;
516 }
517 
518 // get active
getActive() const519 std::vector<Node> ExtTheory::getActive() const
520 {
521   std::vector<Node> active;
522   for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
523        it != d_ext_func_terms.end();
524        ++it)
525   {
526     // if not already reduced
527     if ((*it).second && !isContextIndependentInactive((*it).first))
528     {
529       active.push_back((*it).first);
530     }
531   }
532   return active;
533 }
534 
getActive(Kind k) const535 std::vector<Node> ExtTheory::getActive(Kind k) const
536 {
537   std::vector<Node> active;
538   for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
539        it != d_ext_func_terms.end();
540        ++it)
541   {
542     // if not already reduced
543     if ((*it).first.getKind() == k && (*it).second
544         && !isContextIndependentInactive((*it).first))
545     {
546       active.push_back((*it).first);
547     }
548   }
549   return active;
550 }
551 
clearCache()552 void ExtTheory::clearCache() { d_gst_cache.clear(); }
553 
554 } /* CVC4::theory namespace */
555 } /* CVC4 namespace */
556