1 /*********************                                                        */
2 /*! \file candidate_rewrite_filter.cpp
3  ** \verbatim
4  ** Top contributors (to current version):
5  **   Andrew Reynolds
6  ** This file is part of the CVC4 project.
7  ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8  ** in the top-level source directory) and their institutional affiliations.
9  ** All rights reserved.  See the file COPYING in the top-level source
10  ** directory for licensing information.\endverbatim
11  **
12  ** \brief Implements techniques for candidate rewrite rule filtering, which
13  ** filters the output of --sygus-rr-synth. The classes in this file implement
14  ** filtering based on congruence, variable ordering, and matching.
15  **/
16 
17 #include "theory/quantifiers/candidate_rewrite_filter.h"
18 
19 #include "options/base_options.h"
20 #include "options/quantifiers_options.h"
21 #include "printer/printer.h"
22 
23 using namespace CVC4::kind;
24 
25 namespace CVC4 {
26 namespace theory {
27 namespace quantifiers {
28 
getMatches(Node n,NotifyMatch * ntm)29 bool MatchTrie::getMatches(Node n, NotifyMatch* ntm)
30 {
31   std::vector<Node> vars;
32   std::vector<Node> subs;
33   std::map<Node, Node> smap;
34 
35   std::vector<std::vector<Node> > visit;
36   std::vector<MatchTrie*> visit_trie;
37   std::vector<int> visit_var_index;
38   std::vector<bool> visit_bound_var;
39 
40   visit.push_back(std::vector<Node>{n});
41   visit_trie.push_back(this);
42   visit_var_index.push_back(-1);
43   visit_bound_var.push_back(false);
44   while (!visit.empty())
45   {
46     std::vector<Node> cvisit = visit.back();
47     MatchTrie* curr = visit_trie.back();
48     if (cvisit.empty())
49     {
50       Assert(n
51              == curr->d_data.substitute(
52                     vars.begin(), vars.end(), subs.begin(), subs.end()));
53       Trace("crf-match-debug") << "notify : " << curr->d_data << std::endl;
54       if (!ntm->notify(n, curr->d_data, vars, subs))
55       {
56         return false;
57       }
58       visit.pop_back();
59       visit_trie.pop_back();
60       visit_var_index.pop_back();
61       visit_bound_var.pop_back();
62     }
63     else
64     {
65       Node cn = cvisit.back();
66       Trace("crf-match-debug") << "traverse : " << cn << " at depth "
67                                << visit.size() << std::endl;
68       unsigned index = visit.size() - 1;
69       int vindex = visit_var_index[index];
70       if (vindex == -1)
71       {
72         if (!cn.isVar())
73         {
74           Node op = cn.hasOperator() ? cn.getOperator() : cn;
75           unsigned nchild = cn.hasOperator() ? cn.getNumChildren() : 0;
76           std::map<unsigned, MatchTrie>::iterator itu =
77               curr->d_children[op].find(nchild);
78           if (itu != curr->d_children[op].end())
79           {
80             // recurse on the operator or self
81             cvisit.pop_back();
82             if (cn.hasOperator())
83             {
84               for (const Node& cnc : cn)
85               {
86                 cvisit.push_back(cnc);
87               }
88             }
89             Trace("crf-match-debug") << "recurse op : " << op << std::endl;
90             visit.push_back(cvisit);
91             visit_trie.push_back(&itu->second);
92             visit_var_index.push_back(-1);
93             visit_bound_var.push_back(false);
94           }
95         }
96         visit_var_index[index]++;
97       }
98       else
99       {
100         // clean up if we previously bound a variable
101         if (visit_bound_var[index])
102         {
103           Assert(!vars.empty());
104           smap.erase(vars.back());
105           vars.pop_back();
106           subs.pop_back();
107           visit_bound_var[index] = false;
108         }
109 
110         if (vindex == static_cast<int>(curr->d_vars.size()))
111         {
112           Trace("crf-match-debug")
113               << "finished checking " << curr->d_vars.size()
114               << " variables at depth " << visit.size() << std::endl;
115           // finished
116           visit.pop_back();
117           visit_trie.pop_back();
118           visit_var_index.pop_back();
119           visit_bound_var.pop_back();
120         }
121         else
122         {
123           Trace("crf-match-debug") << "check variable #" << vindex
124                                    << " at depth " << visit.size() << std::endl;
125           Assert(vindex < static_cast<int>(curr->d_vars.size()));
126           // recurse on variable?
127           Node var = curr->d_vars[vindex];
128           bool recurse = true;
129           // check if it is already bound
130           std::map<Node, Node>::iterator its = smap.find(var);
131           if (its != smap.end())
132           {
133             if (its->second != cn)
134             {
135               recurse = false;
136             }
137           }
138           else
139           {
140             vars.push_back(var);
141             subs.push_back(cn);
142             smap[var] = cn;
143             visit_bound_var[index] = true;
144           }
145           if (recurse)
146           {
147             Trace("crf-match-debug") << "recurse var : " << var << std::endl;
148             cvisit.pop_back();
149             visit.push_back(cvisit);
150             visit_trie.push_back(&curr->d_children[var][0]);
151             visit_var_index.push_back(-1);
152             visit_bound_var.push_back(false);
153           }
154           visit_var_index[index]++;
155         }
156       }
157     }
158   }
159   return true;
160 }
161 
addTerm(Node n)162 void MatchTrie::addTerm(Node n)
163 {
164   Assert(!n.isNull());
165   std::vector<Node> visit;
166   visit.push_back(n);
167   MatchTrie* curr = this;
168   while (!visit.empty())
169   {
170     Node cn = visit.back();
171     visit.pop_back();
172     if (cn.hasOperator())
173     {
174       curr = &(curr->d_children[cn.getOperator()][cn.getNumChildren()]);
175       for (const Node& cnc : cn)
176       {
177         visit.push_back(cnc);
178       }
179     }
180     else
181     {
182       if (cn.isVar()
183           && std::find(curr->d_vars.begin(), curr->d_vars.end(), cn)
184                  == curr->d_vars.end())
185       {
186         curr->d_vars.push_back(cn);
187       }
188       curr = &(curr->d_children[cn][0]);
189     }
190   }
191   curr->d_data = n;
192 }
193 
clear()194 void MatchTrie::clear()
195 {
196   d_children.clear();
197   d_vars.clear();
198   d_data = Node::null();
199 }
200 
// The number of d_drewrite objects allocated so far. It is appended to the
// name of each DynamicRewriter created by CandidateRewriteFilter::initialize
// ("_dyn_rewriter_<k>"), to avoid name conflicts between instances.
static unsigned drewrite_counter = 0;
203 
CandidateRewriteFilter()204 CandidateRewriteFilter::CandidateRewriteFilter()
205     : d_ss(nullptr),
206       d_tds(nullptr),
207       d_use_sygus_type(false),
208       d_drewrite(nullptr),
209       d_ssenm(*this)
210 {
211 }
212 
initialize(SygusSampler * ss,TermDbSygus * tds,bool useSygusType)213 void CandidateRewriteFilter::initialize(SygusSampler* ss,
214                                         TermDbSygus* tds,
215                                         bool useSygusType)
216 {
217   d_ss = ss;
218   d_use_sygus_type = useSygusType;
219   d_tds = tds;
220   // initialize members of this class
221   d_match_trie.clear();
222   d_pairs.clear();
223   // (re)initialize the dynamic rewriter
224   std::stringstream ssn;
225   ssn << "_dyn_rewriter_" << drewrite_counter;
226   drewrite_counter++;
227   d_drewrite = std::unique_ptr<DynamicRewriter>(
228       new DynamicRewriter(ssn.str(), &d_fake_context));
229 }
230 
filterPair(Node n,Node eq_n)231 bool CandidateRewriteFilter::filterPair(Node n, Node eq_n)
232 {
233   Node bn = n;
234   Node beq_n = eq_n;
235   if (d_use_sygus_type)
236   {
237     bn = d_tds->sygusToBuiltin(n);
238     beq_n = d_tds->sygusToBuiltin(eq_n);
239   }
240   Trace("cr-filter") << "crewriteFilter : " << bn << "..." << beq_n
241                      << std::endl;
242   // whether we will keep this pair
243   bool keep = true;
244 
245   // ----- check redundancy based on variables
246   if (options::sygusRewSynthFilterOrder()
247       || options::sygusRewSynthFilterNonLinear())
248   {
249     bool nor = d_ss->checkVariables(bn,
250                                     options::sygusRewSynthFilterOrder(),
251                                     options::sygusRewSynthFilterNonLinear());
252     bool eqor = d_ss->checkVariables(beq_n,
253                                      options::sygusRewSynthFilterOrder(),
254                                      options::sygusRewSynthFilterNonLinear());
255     Trace("cr-filter-debug")
256         << "Variables ok? : " << nor << " " << eqor << std::endl;
257     if (eqor || nor)
258     {
259       // if only one is ordered, then we require that the ordered one's
260       // variables cannot be a strict subset of the variables of the other.
261       if (!eqor)
262       {
263         if (d_ss->containsFreeVariables(beq_n, bn, true))
264         {
265           keep = false;
266         }
267         else
268         {
269           // if the previous value stored was unordered, but n is
270           // ordered, we prefer n. Thus, we force its addition to the
271           // sampler database.
272           d_ss->registerTerm(n, true);
273         }
274       }
275       else if (!nor)
276       {
277         keep = !d_ss->containsFreeVariables(bn, beq_n, true);
278       }
279     }
280     else
281     {
282       keep = false;
283     }
284     if (!keep)
285     {
286       Trace("cr-filter") << "...redundant (unordered)" << std::endl;
287     }
288   }
289 
290   // ----- check rewriting redundancy
291   if (keep && options::sygusRewSynthFilterCong())
292   {
293     // When using sygus types, this filtering applies to the builtin versions
294     // of n and eq_n. This means that we may filter out a rewrite rule for one
295     // sygus type based on another, e.g. we won't report x=x+0 for both A and B
296     // in:
297     //   A -> x | 0 | A+A
298     //   B -> x | 0 | B+B
299     Trace("cr-filter-debug") << "Check equal rewrite pair..." << std::endl;
300     if (d_drewrite->areEqual(bn, beq_n))
301     {
302       // must be unique according to the dynamic rewriter
303       Trace("cr-filter") << "...redundant (rewritable)" << std::endl;
304       keep = false;
305     }
306   }
307 
308   if (keep && options::sygusRewSynthFilterMatch())
309   {
310     // ----- check matchable
311     // check whether the pair is matchable with a previous one
312     d_curr_pair_rhs = beq_n;
313     Trace("crf-match") << "CRF check matches : " << bn << " [rhs = " << beq_n
314                        << "]..." << std::endl;
315     Node bni = d_drewrite->toInternal(bn);
316     if (!bni.isNull())
317     {
318       // as with congruence filtering, we cache based on the builtin type
319       TypeNode tn = bn.getType();
320       if (!d_match_trie[tn].getMatches(bni, &d_ssenm))
321       {
322         keep = false;
323         Trace("cr-filter") << "...redundant (matchable)" << std::endl;
324         // regardless, would help to remember it
325         registerRelevantPair(n, eq_n);
326       }
327     }
328     // if bni is null, it may involve non-first-class types that we cannot
329     // reason about
330   }
331 
332   if (keep)
333   {
334     return false;
335   }
336   if (Trace.isOn("sygus-rr-filter"))
337   {
338     Printer* p = Printer::getPrinter(options::outputLanguage());
339     std::stringstream ss;
340     ss << "(redundant-rewrite ";
341     p->toStreamSygus(ss, n);
342     ss << " ";
343     p->toStreamSygus(ss, eq_n);
344     ss << ")";
345     Trace("sygus-rr-filter") << ss.str() << std::endl;
346   }
347   return true;
348 }
349 
registerRelevantPair(Node n,Node eq_n)350 void CandidateRewriteFilter::registerRelevantPair(Node n, Node eq_n)
351 {
352   Node bn = n;
353   Node beq_n = eq_n;
354   if (d_use_sygus_type)
355   {
356     bn = d_tds->sygusToBuiltin(n);
357     beq_n = d_tds->sygusToBuiltin(eq_n);
358   }
359   // ----- check rewriting redundancy
360   if (options::sygusRewSynthFilterCong())
361   {
362     Trace("cr-filter-debug") << "Add rewrite pair..." << std::endl;
363     Assert(!d_drewrite->areEqual(bn, beq_n));
364     d_drewrite->addRewrite(bn, beq_n);
365   }
366   if (options::sygusRewSynthFilterMatch())
367   {
368     // cache based on the builtin type
369     TypeNode tn = bn.getType();
370     // add to match information
371     for (unsigned r = 0; r < 2; r++)
372     {
373       Node t = r == 0 ? bn : beq_n;
374       Node to = r == 0 ? beq_n : bn;
375       // insert in match trie if first time
376       if (d_pairs.find(t) == d_pairs.end())
377       {
378         Trace("crf-match") << "CRF add term : " << t << std::endl;
379         Node ti = d_drewrite->toInternal(t);
380         if (!ti.isNull())
381         {
382           d_match_trie[tn].addTerm(ti);
383         }
384       }
385       d_pairs[t].insert(to);
386     }
387   }
388 }
389 
notify(Node s,Node n,std::vector<Node> & vars,std::vector<Node> & subs)390 bool CandidateRewriteFilter::notify(Node s,
391                                     Node n,
392                                     std::vector<Node>& vars,
393                                     std::vector<Node>& subs)
394 {
395   Trace("crf-match-debug") << "Got : " << s << " matches " << n << std::endl;
396   Assert(!d_curr_pair_rhs.isNull());
397   // convert back to original forms
398   s = d_drewrite->toExternal(s);
399   Assert(!s.isNull());
400   n = d_drewrite->toExternal(n);
401   Assert(!n.isNull());
402   std::map<Node, std::unordered_set<Node, NodeHashFunction> >::iterator it =
403       d_pairs.find(n);
404   if (Trace.isOn("crf-match"))
405   {
406     Trace("crf-match") << "  " << s << " matches " << n
407                        << " under:" << std::endl;
408     for (unsigned i = 0, size = vars.size(); i < size; i++)
409     {
410       Trace("crf-match") << "    " << vars[i] << " -> " << subs[i] << std::endl;
411     }
412   }
413 #ifdef CVC4_ASSERTIONS
414   for (unsigned i = 0, size = vars.size(); i < size; i++)
415   {
416     // By using internal representation of terms, we ensure polymorphism is
417     // handled correctly.
418     Assert(vars[i].getType().isComparableTo(subs[i].getType()));
419   }
420 #endif
421   // must convert the inferred substitution to original form
422   std::vector<Node> esubs;
423   for (const Node& s : subs)
424   {
425     esubs.push_back(d_drewrite->toExternal(s));
426   }
427   Assert(it != d_pairs.end());
428   for (const Node& nr : it->second)
429   {
430     Node nrs =
431         nr.substitute(vars.begin(), vars.end(), esubs.begin(), esubs.end());
432     bool areEqual = (nrs == d_curr_pair_rhs);
433     if (!areEqual && options::sygusRewSynthFilterCong())
434     {
435       // if dynamic rewriter is available, consult it
436       areEqual = d_drewrite->areEqual(nrs, d_curr_pair_rhs);
437     }
438     if (areEqual)
439     {
440       Trace("crf-match") << "*** Match, current pair: " << std::endl;
441       Trace("crf-match") << "  (" << s << ", " << d_curr_pair_rhs << ")"
442                          << std::endl;
443       Trace("crf-match") << "is an instance of previous pair:" << std::endl;
444       Trace("crf-match") << "  (" << n << ", " << nr << ")" << std::endl;
445       return false;
446     }
447   }
448   return true;
449 }
450 
451 }  // namespace quantifiers
452 }  // namespace theory
453 }  // namespace CVC4
454