1 /*********************                                                        */
2 /*! \file inst_match_trie.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Andrew Reynolds, Morgan Deters, Tim King
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** \brief Implementation of inst match class
13  **/
14 
15 #include "theory/quantifiers/inst_match_trie.h"
16 
17 #include "theory/quantifiers/instantiate.h"
18 #include "theory/quantifiers/quant_util.h"
19 #include "theory/quantifiers/term_database.h"
20 #include "theory/quantifiers_engine.h"
21 
22 using namespace CVC4::context;
23 
24 namespace CVC4 {
25 namespace theory {
26 namespace inst {
27 
addInstMatch(QuantifiersEngine * qe,Node f,std::vector<Node> & m,bool modEq,ImtIndexOrder * imtio,bool onlyExist,unsigned index)28 bool InstMatchTrie::addInstMatch(QuantifiersEngine* qe,
29                                  Node f,
30                                  std::vector<Node>& m,
31                                  bool modEq,
32                                  ImtIndexOrder* imtio,
33                                  bool onlyExist,
34                                  unsigned index)
35 {
36   if (index == f[0].getNumChildren()
37       || (imtio && index == imtio->d_order.size()))
38   {
39     return false;
40   }
41   unsigned i_index = imtio ? imtio->d_order[index] : index;
42   Node n = m[i_index];
43   std::map<Node, InstMatchTrie>::iterator it = d_data.find(n);
44   if (it != d_data.end())
45   {
46     bool ret =
47         it->second.addInstMatch(qe, f, m, modEq, imtio, onlyExist, index + 1);
48     if (!onlyExist || !ret)
49     {
50       return ret;
51     }
52   }
53   if (modEq)
54   {
55     // check modulo equality if any other instantiation match exists
56     if (!n.isNull() && qe->getEqualityQuery()->getEngine()->hasTerm(n))
57     {
58       eq::EqClassIterator eqc(
59           qe->getEqualityQuery()->getEngine()->getRepresentative(n),
60           qe->getEqualityQuery()->getEngine());
61       while (!eqc.isFinished())
62       {
63         Node en = (*eqc);
64         if (en != n)
65         {
66           std::map<Node, InstMatchTrie>::iterator itc = d_data.find(en);
67           if (itc != d_data.end())
68           {
69             if (itc->second.addInstMatch(
70                     qe, f, m, modEq, imtio, true, index + 1))
71             {
72               return false;
73             }
74           }
75         }
76         ++eqc;
77       }
78     }
79   }
80   if (!onlyExist)
81   {
82     d_data[n].addInstMatch(qe, f, m, modEq, imtio, false, index + 1);
83   }
84   return true;
85 }
86 
removeInstMatch(Node q,std::vector<Node> & m,ImtIndexOrder * imtio,unsigned index)87 bool InstMatchTrie::removeInstMatch(Node q,
88                                     std::vector<Node>& m,
89                                     ImtIndexOrder* imtio,
90                                     unsigned index)
91 {
92   Assert(index < q[0].getNumChildren());
93   Assert(!imtio || index < imtio->d_order.size());
94   unsigned i_index = imtio ? imtio->d_order[index] : index;
95   Node n = m[i_index];
96   std::map<Node, InstMatchTrie>::iterator it = d_data.find(n);
97   if (it != d_data.end())
98   {
99     if ((index + 1) == q[0].getNumChildren()
100         || (imtio && (index + 1) == imtio->d_order.size()))
101     {
102       d_data.erase(n);
103       return true;
104     }
105     return it->second.removeInstMatch(q, m, imtio, index + 1);
106   }
107   return false;
108 }
109 
recordInstLemma(Node q,std::vector<Node> & m,Node lem,ImtIndexOrder * imtio,unsigned index)110 bool InstMatchTrie::recordInstLemma(Node q,
111                                     std::vector<Node>& m,
112                                     Node lem,
113                                     ImtIndexOrder* imtio,
114                                     unsigned index)
115 {
116   if (index == q[0].getNumChildren()
117       || (imtio && index == imtio->d_order.size()))
118   {
119     setInstLemma(lem);
120     return true;
121   }
122   unsigned i_index = imtio ? imtio->d_order[index] : index;
123   std::map<Node, InstMatchTrie>::iterator it = d_data.find(m[i_index]);
124   if (it != d_data.end())
125   {
126     return it->second.recordInstLemma(q, m, lem, imtio, index + 1);
127   }
128   return false;
129 }
130 
print(std::ostream & out,Node q,std::vector<TNode> & terms,bool & firstTime,bool useActive,std::vector<Node> & active) const131 void InstMatchTrie::print(std::ostream& out,
132                           Node q,
133                           std::vector<TNode>& terms,
134                           bool& firstTime,
135                           bool useActive,
136                           std::vector<Node>& active) const
137 {
138   if (terms.size() == q[0].getNumChildren())
139   {
140     bool print;
141     if (useActive)
142     {
143       if (hasInstLemma())
144       {
145         Node lem = getInstLemma();
146         print = std::find(active.begin(), active.end(), lem) != active.end();
147       }
148       else
149       {
150         print = false;
151       }
152     }
153     else
154     {
155       print = true;
156     }
157     if (print)
158     {
159       if (firstTime)
160       {
161         out << "(instantiation " << q << std::endl;
162         firstTime = false;
163       }
164       out << "  ( ";
165       for (unsigned i = 0, size = terms.size(); i < size; i++)
166       {
167         if (i > 0)
168         {
169           out << ", ";
170         }
171         out << terms[i];
172       }
173       out << " )" << std::endl;
174     }
175   }
176   else
177   {
178     for (const std::pair<const Node, InstMatchTrie>& d : d_data)
179     {
180       terms.push_back(d.first);
181       d.second.print(out, q, terms, firstTime, useActive, active);
182       terms.pop_back();
183     }
184   }
185 }
186 
getInstantiations(std::vector<Node> & insts,Node q,std::vector<Node> & terms,QuantifiersEngine * qe,bool useActive,std::vector<Node> & active) const187 void InstMatchTrie::getInstantiations(std::vector<Node>& insts,
188                                       Node q,
189                                       std::vector<Node>& terms,
190                                       QuantifiersEngine* qe,
191                                       bool useActive,
192                                       std::vector<Node>& active) const
193 {
194   if (terms.size() == q[0].getNumChildren())
195   {
196     if (useActive)
197     {
198       if (hasInstLemma())
199       {
200         Node lem = getInstLemma();
201         if (std::find(active.begin(), active.end(), lem) != active.end())
202         {
203           insts.push_back(lem);
204         }
205       }
206     }
207     else
208     {
209       if (hasInstLemma())
210       {
211         insts.push_back(getInstLemma());
212       }
213       else
214       {
215         insts.push_back(qe->getInstantiate()->getInstantiation(q, terms, true));
216       }
217     }
218   }
219   else
220   {
221     for (const std::pair<const Node, InstMatchTrie>& d : d_data)
222     {
223       terms.push_back(d.first);
224       d.second.getInstantiations(insts, q, terms, qe, useActive, active);
225       terms.pop_back();
226     }
227   }
228 }
229 
getExplanationForInstLemmas(Node q,std::vector<Node> & terms,const std::vector<Node> & lems,std::map<Node,Node> & quant,std::map<Node,std::vector<Node>> & tvec) const230 void InstMatchTrie::getExplanationForInstLemmas(
231     Node q,
232     std::vector<Node>& terms,
233     const std::vector<Node>& lems,
234     std::map<Node, Node>& quant,
235     std::map<Node, std::vector<Node> >& tvec) const
236 {
237   if (terms.size() == q[0].getNumChildren())
238   {
239     if (hasInstLemma())
240     {
241       Node lem = getInstLemma();
242       if (std::find(lems.begin(), lems.end(), lem) != lems.end())
243       {
244         quant[lem] = q;
245         tvec[lem].clear();
246         tvec[lem].insert(tvec[lem].end(), terms.begin(), terms.end());
247       }
248     }
249   }
250   else
251   {
252     for (const std::pair<const Node, InstMatchTrie>& d : d_data)
253     {
254       terms.push_back(d.first);
255       d.second.getExplanationForInstLemmas(q, terms, lems, quant, tvec);
256       terms.pop_back();
257     }
258   }
259 }
260 
~CDInstMatchTrie()261 CDInstMatchTrie::~CDInstMatchTrie()
262 {
263   for (std::pair<const Node, CDInstMatchTrie*>& d : d_data)
264   {
265     CDInstMatchTrie* current = d.second;
266     delete current;
267   }
268   d_data.clear();
269 }
270 
addInstMatch(QuantifiersEngine * qe,Node f,std::vector<Node> & m,context::Context * c,bool modEq,unsigned index,bool onlyExist)271 bool CDInstMatchTrie::addInstMatch(QuantifiersEngine* qe,
272                                    Node f,
273                                    std::vector<Node>& m,
274                                    context::Context* c,
275                                    bool modEq,
276                                    unsigned index,
277                                    bool onlyExist)
278 {
279   bool reset = false;
280   if (!d_valid.get())
281   {
282     if (onlyExist)
283     {
284       return true;
285     }
286     else
287     {
288       d_valid.set(true);
289       reset = true;
290     }
291   }
292   if (index == f[0].getNumChildren())
293   {
294     return reset;
295   }
296   Node n = m[index];
297   std::map<Node, CDInstMatchTrie*>::iterator it = d_data.find(n);
298   if (it != d_data.end())
299   {
300     bool ret =
301         it->second->addInstMatch(qe, f, m, c, modEq, index + 1, onlyExist);
302     if (!onlyExist || !ret)
303     {
304       return reset || ret;
305     }
306   }
307   if (modEq)
308   {
309     // check modulo equality if any other instantiation match exists
310     if (!n.isNull() && qe->getEqualityQuery()->getEngine()->hasTerm(n))
311     {
312       eq::EqClassIterator eqc(
313           qe->getEqualityQuery()->getEngine()->getRepresentative(n),
314           qe->getEqualityQuery()->getEngine());
315       while (!eqc.isFinished())
316       {
317         Node en = (*eqc);
318         if (en != n)
319         {
320           std::map<Node, CDInstMatchTrie*>::iterator itc = d_data.find(en);
321           if (itc != d_data.end())
322           {
323             if (itc->second->addInstMatch(qe, f, m, c, modEq, index + 1, true))
324             {
325               return false;
326             }
327           }
328         }
329         ++eqc;
330       }
331     }
332   }
333 
334   if (!onlyExist)
335   {
336     // std::map< Node, CDInstMatchTrie* >::iterator it = d_data.find( n );
337     CDInstMatchTrie* imt = new CDInstMatchTrie(c);
338     Assert(d_data.find(n) == d_data.end());
339     d_data[n] = imt;
340     imt->addInstMatch(qe, f, m, c, modEq, index + 1, false);
341   }
342   return true;
343 }
344 
removeInstMatch(Node q,std::vector<Node> & m,unsigned index)345 bool CDInstMatchTrie::removeInstMatch(Node q,
346                                       std::vector<Node>& m,
347                                       unsigned index)
348 {
349   if (index == q[0].getNumChildren())
350   {
351     if (d_valid.get())
352     {
353       d_valid.set(false);
354       return true;
355     }
356     return false;
357   }
358   std::map<Node, CDInstMatchTrie*>::iterator it = d_data.find(m[index]);
359   if (it != d_data.end())
360   {
361     return it->second->removeInstMatch(q, m, index + 1);
362   }
363   return false;
364 }
365 
recordInstLemma(Node q,std::vector<Node> & m,Node lem,unsigned index)366 bool CDInstMatchTrie::recordInstLemma(Node q,
367                                       std::vector<Node>& m,
368                                       Node lem,
369                                       unsigned index)
370 {
371   if (index == q[0].getNumChildren())
372   {
373     if (d_valid.get())
374     {
375       setInstLemma(lem);
376       return true;
377     }
378     return false;
379   }
380   std::map<Node, CDInstMatchTrie*>::iterator it = d_data.find(m[index]);
381   if (it != d_data.end())
382   {
383     return it->second->recordInstLemma(q, m, lem, index + 1);
384   }
385   return false;
386 }
387 
print(std::ostream & out,Node q,std::vector<TNode> & terms,bool & firstTime,bool useActive,std::vector<Node> & active) const388 void CDInstMatchTrie::print(std::ostream& out,
389                             Node q,
390                             std::vector<TNode>& terms,
391                             bool& firstTime,
392                             bool useActive,
393                             std::vector<Node>& active) const
394 {
395   if (d_valid.get())
396   {
397     if (terms.size() == q[0].getNumChildren())
398     {
399       bool print;
400       if (useActive)
401       {
402         if (hasInstLemma())
403         {
404           Node lem = getInstLemma();
405           print = std::find(active.begin(), active.end(), lem) != active.end();
406         }
407         else
408         {
409           print = false;
410         }
411       }
412       else
413       {
414         print = true;
415       }
416       if (print)
417       {
418         if (firstTime)
419         {
420           out << "(instantiation " << q << std::endl;
421           firstTime = false;
422         }
423         out << "  ( ";
424         for (unsigned i = 0; i < terms.size(); i++)
425         {
426           if (i > 0) out << " ";
427           out << terms[i];
428         }
429         out << " )" << std::endl;
430       }
431     }
432     else
433     {
434       for (const std::pair<const Node, CDInstMatchTrie*>& d : d_data)
435       {
436         terms.push_back(d.first);
437         d.second->print(out, q, terms, firstTime, useActive, active);
438         terms.pop_back();
439       }
440     }
441   }
442 }
443 
getInstantiations(std::vector<Node> & insts,Node q,std::vector<Node> & terms,QuantifiersEngine * qe,bool useActive,std::vector<Node> & active) const444 void CDInstMatchTrie::getInstantiations(std::vector<Node>& insts,
445                                         Node q,
446                                         std::vector<Node>& terms,
447                                         QuantifiersEngine* qe,
448                                         bool useActive,
449                                         std::vector<Node>& active) const
450 {
451   if (d_valid.get())
452   {
453     if (terms.size() == q[0].getNumChildren())
454     {
455       if (useActive)
456       {
457         if (hasInstLemma())
458         {
459           Node lem = getInstLemma();
460           if (std::find(active.begin(), active.end(), lem) != active.end())
461           {
462             insts.push_back(lem);
463           }
464         }
465       }
466       else
467       {
468         if (hasInstLemma())
469         {
470           insts.push_back(getInstLemma());
471         }
472         else
473         {
474           insts.push_back(
475               qe->getInstantiate()->getInstantiation(q, terms, true));
476         }
477       }
478     }
479     else
480     {
481       for (const std::pair<const Node, CDInstMatchTrie*>& d : d_data)
482       {
483         terms.push_back(d.first);
484         d.second->getInstantiations(insts, q, terms, qe, useActive, active);
485         terms.pop_back();
486       }
487     }
488   }
489 }
490 
getExplanationForInstLemmas(Node q,std::vector<Node> & terms,const std::vector<Node> & lems,std::map<Node,Node> & quant,std::map<Node,std::vector<Node>> & tvec) const491 void CDInstMatchTrie::getExplanationForInstLemmas(
492     Node q,
493     std::vector<Node>& terms,
494     const std::vector<Node>& lems,
495     std::map<Node, Node>& quant,
496     std::map<Node, std::vector<Node> >& tvec) const
497 {
498   if (d_valid.get())
499   {
500     if (terms.size() == q[0].getNumChildren())
501     {
502       if (hasInstLemma())
503       {
504         Node lem;
505         if (std::find(lems.begin(), lems.end(), lem) != lems.end())
506         {
507           quant[lem] = q;
508           tvec[lem].clear();
509           tvec[lem].insert(tvec[lem].end(), terms.begin(), terms.end());
510         }
511       }
512     }
513     else
514     {
515       for (const std::pair<const Node, CDInstMatchTrie*>& d : d_data)
516       {
517         terms.push_back(d.first);
518         d.second->getExplanationForInstLemmas(q, terms, lems, quant, tvec);
519         terms.pop_back();
520       }
521     }
522   }
523 }
524 
525 } /* CVC4::theory::inst namespace */
526 } /* CVC4::theory namespace */
527 } /* CVC4 namespace */
528