1 /********************* */
2 /*! \file ext_theory.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Tim King, Morgan Deters
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Extended theory interface.
13 **
14 ** This implements a generic module, used by theory solvers, for performing
15 ** "context-dependent simplification", as described in Reynolds et al
16 ** "Designing Theory Solvers with Extensions", FroCoS 2017.
17 **/
18
19 #include "theory/ext_theory.h"
20
21 #include "base/cvc4_assert.h"
22 #include "smt/smt_statistics_registry.h"
23 #include "theory/quantifiers_engine.h"
24 #include "theory/substitutions.h"
25
26 using namespace std;
27
28 namespace CVC4 {
29 namespace theory {
30
ExtTheory(Theory * p,bool cacheEnabled)31 ExtTheory::ExtTheory(Theory* p, bool cacheEnabled)
32 : d_parent(p),
33 d_ext_func_terms(p->getSatContext()),
34 d_ci_inactive(p->getUserContext()),
35 d_has_extf(p->getSatContext()),
36 d_lemmas(p->getUserContext()),
37 d_pp_lemmas(p->getUserContext()),
38 d_cacheEnabled(cacheEnabled)
39 {
40 d_true = NodeManager::currentNM()->mkConst(true);
41 }
42
43 // Gets all leaf terms in n.
collectVars(Node n)44 std::vector<Node> ExtTheory::collectVars(Node n)
45 {
46 std::vector<Node> vars;
47 std::set<Node> visited;
48 std::vector<Node> worklist;
49 worklist.push_back(n);
50 while (!worklist.empty())
51 {
52 Node current = worklist.back();
53 worklist.pop_back();
54 if (current.isConst() || visited.count(current) > 0)
55 {
56 continue;
57 }
58 visited.insert(current);
59 // Treat terms not belonging to this theory as leaf
60 // note : chould include terms not belonging to this theory
61 // (commented below)
62 if (current.getNumChildren() > 0)
63 {
64 //&& Theory::theoryOf(n)==d_parent->getId() ){
65 worklist.insert(worklist.end(), current.begin(), current.end());
66 }
67 else
68 {
69 vars.push_back(current);
70 }
71 }
72 return vars;
73 }
74
getSubstitutedTerm(int effort,Node term,std::vector<Node> & exp,bool useCache)75 Node ExtTheory::getSubstitutedTerm(int effort,
76 Node term,
77 std::vector<Node>& exp,
78 bool useCache)
79 {
80 if (useCache)
81 {
82 Assert(d_gst_cache[effort].find(term) != d_gst_cache[effort].end());
83 exp.insert(exp.end(),
84 d_gst_cache[effort][term].d_exp.begin(),
85 d_gst_cache[effort][term].d_exp.end());
86 return d_gst_cache[effort][term].d_sterm;
87 }
88
89 std::vector<Node> terms;
90 terms.push_back(term);
91 std::vector<Node> sterms;
92 std::vector<std::vector<Node> > exps;
93 getSubstitutedTerms(effort, terms, sterms, exps, useCache);
94 Assert(sterms.size() == 1);
95 Assert(exps.size() == 1);
96 exp.insert(exp.end(), exps[0].begin(), exps[0].end());
97 return sterms[0];
98 }
99
100 // do inferences
getSubstitutedTerms(int effort,const std::vector<Node> & terms,std::vector<Node> & sterms,std::vector<std::vector<Node>> & exp,bool useCache)101 void ExtTheory::getSubstitutedTerms(int effort,
102 const std::vector<Node>& terms,
103 std::vector<Node>& sterms,
104 std::vector<std::vector<Node> >& exp,
105 bool useCache)
106 {
107 if (useCache)
108 {
109 for (const Node& n : terms)
110 {
111 Assert(d_gst_cache[effort].find(n) != d_gst_cache[effort].end());
112 sterms.push_back(d_gst_cache[effort][n].d_sterm);
113 exp.push_back(std::vector<Node>());
114 exp[0].insert(exp[0].end(),
115 d_gst_cache[effort][n].d_exp.begin(),
116 d_gst_cache[effort][n].d_exp.end());
117 }
118 }
119 else
120 {
121 Trace("extt-debug") << "getSubstitutedTerms for " << terms.size() << " / "
122 << d_ext_func_terms.size() << " extended functions."
123 << std::endl;
124 if (!terms.empty())
125 {
126 // all variables we need to find a substitution for
127 std::vector<Node> vars;
128 std::vector<Node> sub;
129 std::map<Node, std::vector<Node> > expc;
130 for (const Node& n : terms)
131 {
132 // do substitution, rewrite
133 std::map<Node, ExtfInfo>::iterator iti = d_extf_info.find(n);
134 Assert(iti != d_extf_info.end());
135 for (const Node& v : iti->second.d_vars)
136 {
137 if (std::find(vars.begin(), vars.end(), v) == vars.end())
138 {
139 vars.push_back(v);
140 }
141 }
142 }
143 bool useSubs = d_parent->getCurrentSubstitution(effort, vars, sub, expc);
144 // get the current substitution for all variables
145 Assert(!useSubs || vars.size() == sub.size());
146 for (const Node& n : terms)
147 {
148 Node ns = n;
149 std::vector<Node> expn;
150 if (useSubs)
151 {
152 // do substitution
153 ns = n.substitute(vars.begin(), vars.end(), sub.begin(), sub.end());
154 if (ns != n)
155 {
156 // build explanation: explanation vars = sub for each vars in FV(n)
157 std::map<Node, ExtfInfo>::iterator iti = d_extf_info.find(n);
158 Assert(iti != d_extf_info.end());
159 for (const Node& v : iti->second.d_vars)
160 {
161 std::map<Node, std::vector<Node> >::iterator itx = expc.find(v);
162 if (itx != expc.end())
163 {
164 for (const Node& e : itx->second)
165 {
166 if (std::find(expn.begin(), expn.end(), e) == expn.end())
167 {
168 expn.push_back(e);
169 }
170 }
171 }
172 }
173 }
174 Trace("extt-debug")
175 << " have " << n << " == " << ns << ", exp size=" << expn.size()
176 << "." << std::endl;
177 }
178 // add to vector
179 sterms.push_back(ns);
180 exp.push_back(expn);
181 // add to cache
182 if (d_cacheEnabled)
183 {
184 d_gst_cache[effort][n].d_sterm = ns;
185 d_gst_cache[effort][n].d_exp.clear();
186 d_gst_cache[effort][n].d_exp.insert(
187 d_gst_cache[effort][n].d_exp.end(), expn.begin(), expn.end());
188 }
189 }
190 }
191 }
192 }
193
doInferencesInternal(int effort,const std::vector<Node> & terms,std::vector<Node> & nred,bool batch,bool isRed)194 bool ExtTheory::doInferencesInternal(int effort,
195 const std::vector<Node>& terms,
196 std::vector<Node>& nred,
197 bool batch,
198 bool isRed)
199 {
200 if (batch)
201 {
202 bool addedLemma = false;
203 if (isRed)
204 {
205 for (const Node& n : terms)
206 {
207 Node nr;
208 // note: could do reduction with substitution here
209 int ret = d_parent->getReduction(effort, n, nr);
210 if (ret == 0)
211 {
212 nred.push_back(n);
213 }
214 else
215 {
216 if (!nr.isNull() && n != nr)
217 {
218 Node lem = NodeManager::currentNM()->mkNode(kind::EQUAL, n, nr);
219 if (sendLemma(lem, true))
220 {
221 Trace("extt-lemma")
222 << "ExtTheory : reduction lemma : " << lem << std::endl;
223 addedLemma = true;
224 }
225 }
226 markReduced(n, ret < 0);
227 }
228 }
229 }
230 else
231 {
232 std::vector<Node> sterms;
233 std::vector<std::vector<Node> > exp;
234 getSubstitutedTerms(effort, terms, sterms, exp);
235 std::map<Node, unsigned> sterm_index;
236 NodeManager* nm = NodeManager::currentNM();
237 for (unsigned i = 0, size = terms.size(); i < size; i++)
238 {
239 bool processed = false;
240 if (sterms[i] != terms[i])
241 {
242 Node sr = Rewriter::rewrite(sterms[i]);
243 // ask the theory if this term is reduced, e.g. is it constant or it
244 // is a non-extf term.
245 if (d_parent->isExtfReduced(effort, sr, terms[i], exp[i]))
246 {
247 processed = true;
248 markReduced(terms[i]);
249 // We have exp[i] => terms[i] = sr, convert this to a clause.
250 // This ensures the proof infrastructure can process this as a
251 // normal theory lemma.
252 Node eq = terms[i].eqNode(sr);
253 Node lem = eq;
254 if (!exp[i].empty())
255 {
256 std::vector<Node> eei;
257 for (const Node& e : exp[i])
258 {
259 eei.push_back(e.negate());
260 }
261 eei.push_back(eq);
262 lem = nm->mkNode(kind::OR, eei);
263 }
264
265 Trace("extt-debug") << "ExtTheory::doInferences : infer : " << eq
266 << " by " << exp[i] << std::endl;
267 Trace("extt-debug") << "...send lemma " << lem << std::endl;
268 if (sendLemma(lem))
269 {
270 Trace("extt-lemma")
271 << "ExtTheory : substitution + rewrite lemma : " << lem
272 << std::endl;
273 addedLemma = true;
274 }
275 }
276 else
277 {
278 // check if we have already reduced this
279 std::map<Node, unsigned>::iterator itsi = sterm_index.find(sr);
280 if (itsi == sterm_index.end())
281 {
282 sterm_index[sr] = i;
283 }
284 else
285 {
286 // unsigned j = itsi->second;
287 // note : can add (non-reducing) lemma :
288 // exp[j] ^ exp[i] => sterms[i] = sterms[j]
289 }
290
291 Trace("extt-nred") << "Non-reduced term : " << sr << std::endl;
292 }
293 }
294 else
295 {
296 Trace("extt-nred") << "Non-reduced term : " << sterms[i] << std::endl;
297 }
298 if (!processed)
299 {
300 nred.push_back(terms[i]);
301 }
302 }
303 }
304 return addedLemma;
305 }
306 // non-batch
307 std::vector<Node> nnred;
308 if (terms.empty())
309 {
310 for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
311 it != d_ext_func_terms.end();
312 ++it)
313 {
314 if ((*it).second && !isContextIndependentInactive((*it).first))
315 {
316 std::vector<Node> nterms;
317 nterms.push_back((*it).first);
318 if (doInferencesInternal(effort, nterms, nnred, true, isRed))
319 {
320 return true;
321 }
322 }
323 }
324 }
325 else
326 {
327 for (const Node& n : terms)
328 {
329 std::vector<Node> nterms;
330 nterms.push_back(n);
331 if (doInferencesInternal(effort, nterms, nnred, true, isRed))
332 {
333 return true;
334 }
335 }
336 }
337 return false;
338 }
339
sendLemma(Node lem,bool preprocess)340 bool ExtTheory::sendLemma(Node lem, bool preprocess)
341 {
342 if (preprocess)
343 {
344 if (d_pp_lemmas.find(lem) == d_pp_lemmas.end())
345 {
346 d_pp_lemmas.insert(lem);
347 d_parent->getOutputChannel().lemma(lem, false, true);
348 return true;
349 }
350 }
351 else
352 {
353 if (d_lemmas.find(lem) == d_lemmas.end())
354 {
355 d_lemmas.insert(lem);
356 d_parent->getOutputChannel().lemma(lem);
357 return true;
358 }
359 }
360 return false;
361 }
362
doInferences(int effort,const std::vector<Node> & terms,std::vector<Node> & nred,bool batch)363 bool ExtTheory::doInferences(int effort,
364 const std::vector<Node>& terms,
365 std::vector<Node>& nred,
366 bool batch)
367 {
368 if (!terms.empty())
369 {
370 return doInferencesInternal(effort, terms, nred, batch, false);
371 }
372 return false;
373 }
374
doInferences(int effort,std::vector<Node> & nred,bool batch)375 bool ExtTheory::doInferences(int effort, std::vector<Node>& nred, bool batch)
376 {
377 std::vector<Node> terms = getActive();
378 return doInferencesInternal(effort, terms, nred, batch, false);
379 }
380
doReductions(int effort,const std::vector<Node> & terms,std::vector<Node> & nred,bool batch)381 bool ExtTheory::doReductions(int effort,
382 const std::vector<Node>& terms,
383 std::vector<Node>& nred,
384 bool batch)
385 {
386 if (!terms.empty())
387 {
388 return doInferencesInternal(effort, terms, nred, batch, true);
389 }
390 return false;
391 }
392
doReductions(int effort,std::vector<Node> & nred,bool batch)393 bool ExtTheory::doReductions(int effort, std::vector<Node>& nred, bool batch)
394 {
395 const std::vector<Node> terms = getActive();
396 return doInferencesInternal(effort, terms, nred, batch, true);
397 }
398
399 // Register term.
registerTerm(Node n)400 void ExtTheory::registerTerm(Node n)
401 {
402 if (d_extf_kind.find(n.getKind()) != d_extf_kind.end())
403 {
404 if (d_ext_func_terms.find(n) == d_ext_func_terms.end())
405 {
406 Trace("extt-debug") << "Found extended function : " << n << " in "
407 << d_parent->getId() << std::endl;
408 d_ext_func_terms[n] = true;
409 d_has_extf = n;
410 d_extf_info[n].d_vars = collectVars(n);
411 }
412 }
413 }
414
registerTermRec(Node n)415 void ExtTheory::registerTermRec(Node n)
416 {
417 std::unordered_set<TNode, TNodeHashFunction> visited;
418 std::vector<TNode> visit;
419 TNode cur;
420 visit.push_back(n);
421 do
422 {
423 cur = visit.back();
424 visit.pop_back();
425 if (visited.find(cur) == visited.end())
426 {
427 visited.insert(cur);
428 registerTerm(cur);
429 for (const Node& cc : cur)
430 {
431 visit.push_back(cc);
432 }
433 }
434 } while (!visit.empty());
435 }
436
437 // mark reduced
markReduced(Node n,bool contextDepend)438 void ExtTheory::markReduced(Node n, bool contextDepend)
439 {
440 Trace("extt-debug") << "Mark reduced " << n << std::endl;
441 registerTerm(n);
442 Assert(d_ext_func_terms.find(n) != d_ext_func_terms.end());
443 d_ext_func_terms[n] = false;
444 if (!contextDepend)
445 {
446 d_ci_inactive.insert(n);
447 }
448
449 // update has_extf
450 if (d_has_extf.get() == n)
451 {
452 for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
453 it != d_ext_func_terms.end();
454 ++it)
455 {
456 // if not already reduced
457 if ((*it).second && !isContextIndependentInactive((*it).first))
458 {
459 d_has_extf = (*it).first;
460 }
461 }
462 }
463 }
464
465 // mark congruent
markCongruent(Node a,Node b)466 void ExtTheory::markCongruent(Node a, Node b)
467 {
468 Trace("extt-debug") << "Mark congruent : " << a << " " << b << std::endl;
469 registerTerm(a);
470 registerTerm(b);
471 NodeBoolMap::const_iterator it = d_ext_func_terms.find(b);
472 if (it != d_ext_func_terms.end())
473 {
474 if (d_ext_func_terms.find(a) != d_ext_func_terms.end())
475 {
476 d_ext_func_terms[a] = d_ext_func_terms[a] && (*it).second;
477 }
478 else
479 {
480 Assert(false);
481 }
482 d_ext_func_terms[b] = false;
483 }
484 else
485 {
486 Assert(false);
487 }
488 }
489
isContextIndependentInactive(Node n) const490 bool ExtTheory::isContextIndependentInactive(Node n) const
491 {
492 return d_ci_inactive.find(n) != d_ci_inactive.end();
493 }
494
getTerms(std::vector<Node> & terms)495 void ExtTheory::getTerms(std::vector<Node>& terms)
496 {
497 for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
498 it != d_ext_func_terms.end();
499 ++it)
500 {
501 terms.push_back((*it).first);
502 }
503 }
504
hasActiveTerm() const505 bool ExtTheory::hasActiveTerm() const { return !d_has_extf.get().isNull(); }
506
507 // is active
isActive(Node n) const508 bool ExtTheory::isActive(Node n) const
509 {
510 NodeBoolMap::const_iterator it = d_ext_func_terms.find(n);
511 if (it != d_ext_func_terms.end())
512 {
513 return (*it).second && !isContextIndependentInactive(n);
514 }
515 return false;
516 }
517
518 // get active
getActive() const519 std::vector<Node> ExtTheory::getActive() const
520 {
521 std::vector<Node> active;
522 for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
523 it != d_ext_func_terms.end();
524 ++it)
525 {
526 // if not already reduced
527 if ((*it).second && !isContextIndependentInactive((*it).first))
528 {
529 active.push_back((*it).first);
530 }
531 }
532 return active;
533 }
534
getActive(Kind k) const535 std::vector<Node> ExtTheory::getActive(Kind k) const
536 {
537 std::vector<Node> active;
538 for (NodeBoolMap::iterator it = d_ext_func_terms.begin();
539 it != d_ext_func_terms.end();
540 ++it)
541 {
542 // if not already reduced
543 if ((*it).first.getKind() == k && (*it).second
544 && !isContextIndependentInactive((*it).first))
545 {
546 active.push_back((*it).first);
547 }
548 }
549 return active;
550 }
551
clearCache()552 void ExtTheory::clearCache() { d_gst_cache.clear(); }
553
554 } /* CVC4::theory namespace */
555 } /* CVC4 namespace */
556