1 /*++
2 Copyright (c) 2007 Microsoft Corporation
3 
4 Module Name:
5 
6     ast_pattern_match.cpp
7 
8 Abstract:
9 
10     Search for opportune pattern matching utilities.
11 
12 Author:
13 
14     Nikolaj Bjorner (nbjorner) 2007-04-10
15     Leonardo (leonardo)
16 
17 Notes:
18 
19     instead of the brute force enumeration of permutations
20     we can add an instruction 'gate' which copies the ast
21     into a register and creates another register with the same
22     term. Matching against a 'gate' is a noop, apart from clearing
23     the ast in the register. Then on backtracking we know how many
24     terms were matched from the permutation. It does not make sense
25     to enumerate all combinations of terms that were not considered, so
26     skip these.
27 
28     Also, compilation should re-order terms to fail fast.
29 
30 --*/
31 
32 #include "ast/ast.h"
33 #include "ast/pattern/expr_pattern_match.h"
34 #include "ast/for_each_ast.h"
35 #include "ast/ast_ll_pp.h"
36 #include "ast/ast_pp.h"
37 #include "cmd_context/cmd_context.h"
38 #include "parsers/smt2/smt2parser.h"
39 
expr_pattern_match(ast_manager & manager)40 expr_pattern_match::expr_pattern_match(ast_manager & manager):
41     m_manager(manager), m_precompiled(manager) {
42 }
43 
~expr_pattern_match()44 expr_pattern_match::~expr_pattern_match() {
45 }
46 
47 bool
match_quantifier(quantifier * qf,app_ref_vector & patterns,unsigned & weight)48 expr_pattern_match::match_quantifier(quantifier* qf, app_ref_vector& patterns, unsigned& weight) {
49     if (m_regs.empty()) {
50         // HACK: the code crashes if database is empty.
51         return false;
52     }
53     m_regs[0] = qf->get_expr();
54     for (unsigned i = 0; i < m_precompiled.size(); ++i) {
55         if (match_quantifier(i, qf, patterns, weight))
56             return true;
57     }
58     return false;
59 }
60 
61 bool
match_quantifier(unsigned i,quantifier * qf,app_ref_vector & patterns,unsigned & weight)62 expr_pattern_match::match_quantifier(unsigned i, quantifier* qf, app_ref_vector& patterns, unsigned& weight) {
63     quantifier* qf2 = m_precompiled[i].get();
64     if (qf2->get_kind() != qf->get_kind() || is_lambda(qf)) {
65         return false;
66     }
67     if (qf2->get_num_decls() != qf->get_num_decls()) {
68         return false;
69     }
70     subst s;
71     if (match(qf->get_expr(), m_first_instrs[i], s)) {
72         for (unsigned j = 0; j < qf2->get_num_patterns(); ++j) {
73             app* p = static_cast<app*>(qf2->get_pattern(j));
74             expr_ref p_result(m_manager);
75             instantiate(p, qf->get_num_decls(), s, p_result);
76             patterns.push_back(to_app(p_result.get()));
77         }
78         weight = qf2->get_weight();
79         return true;
80     }
81     return false;
82 }
83 
match_quantifier_index(quantifier * qf,app_ref_vector & patterns,unsigned & index)84 bool expr_pattern_match::match_quantifier_index(quantifier* qf, app_ref_vector& patterns, unsigned& index) {
85     if (m_regs.empty()) return false;
86     m_regs[0] = qf->get_expr();
87 
88     for (unsigned i = 0; i < m_precompiled.size(); ++i) {
89         unsigned weight = 0;
90         if (match_quantifier(i, qf, patterns, weight)) {
91             index = i;
92             return true;
93         }
94     }
95     return false;
96 }
97 
98 
99 void
instantiate(expr * a,unsigned num_bound,subst & s,expr_ref & result)100 expr_pattern_match::instantiate(expr* a, unsigned num_bound, subst& s, expr_ref& result) {
101     bound b;
102     for (unsigned i = 0; i < num_bound; ++i) {
103         b.insert(m_bound_dom[i], m_bound_rng[i]);
104     }
105     TRACE("expr_pattern_match", tout << mk_pp(a, m_manager) << " " << num_bound << "\n";);
106     inst_proc proc(m_manager, s, b, m_regs);
107     for_each_ast(proc, a);
108     expr* v = nullptr;
109     proc.m_memoize.find(a, v);
110     SASSERT(v);
111     result = v;
112 }
113 
114 
115 
116 void
compile(expr * q)117 expr_pattern_match::compile(expr* q)
118 {
119     SASSERT(q->get_kind() == AST_QUANTIFIER);
120     quantifier* qf = to_quantifier(q);
121     unsigned ip = m_instrs.size();
122     m_first_instrs.push_back(ip);
123     m_precompiled.push_back(qf);
124 
125     instr instr(BACKTRACK);
126     unsigned_vector regs;
127     ptr_vector<expr> pats;
128     unsigned max_reg = 1;
129     subst s;
130     pats.push_back(qf->get_expr());
131     regs.push_back(0);
132     unsigned num_bound = 0;
133     obj_map<var, unsigned> bound;
134 
135     while (!pats.empty()) {
136 
137         unsigned reg = regs.back();
138         expr* pat     = pats.back();
139         regs.pop_back();
140         pats.pop_back();
141 
142         instr.m_pat  = pat;
143         instr.m_next = m_instrs.size()+1;
144         instr.m_reg  = reg;
145         instr.m_offset = max_reg;
146 
147         switch(pat->get_kind()) {
148         case AST_VAR: {
149             var* b = to_var(pat);
150             if (bound.find(b, instr.m_num_bound)) {
151                 instr.m_kind = CHECK_BOUND;
152             }
153             else {
154                 instr.m_kind = SET_BOUND;
155                 instr.m_num_bound = num_bound;
156                 bound.insert(b, num_bound);
157                 ++num_bound;
158             }
159             break;
160         }
161         case AST_APP: {
162             unsigned r = 0;
163             app* app = to_app(pat);
164             func_decl* d  = app->get_decl();
165 
166             for (unsigned i = 0; i < app->get_num_args(); ++i) {
167                 regs.push_back(max_reg);
168                 pats.push_back(app->get_arg(i));
169                 ++max_reg;
170             }
171 
172             if (is_var(d)) {
173                 if (s.find(d, r)) {
174                     instr.m_kind = CHECK_VAR;
175                     instr.m_other_reg = r;
176                 }
177                 else {
178                     instr.m_kind = SET_VAR;
179                     s.insert(d, reg);
180                 }
181             }
182             else {
183                 if (d->is_associative() && d->is_commutative()) {
184                     instr.m_kind = BIND_AC;
185                 }
186                 else if (d->is_commutative()) {
187                     SASSERT(app->get_num_args() == 2);
188                     instr.m_kind = BIND_C;
189                 }
190                 else {
191                     instr.m_kind = BIND;
192                 }
193             }
194             break;
195         }
196         default:
197             instr.m_kind = CHECK_TERM;
198             break;
199         }
200         m_instrs.push_back(instr);
201     }
202 
203     if (m_regs.size() <= max_reg) {
204         m_regs.resize(max_reg+1);
205     }
206     if (m_bound_dom.size() <= num_bound) {
207         m_bound_dom.resize(num_bound+1);
208         m_bound_rng.resize(num_bound+1);
209     }
210 
211     instr.m_kind = YIELD;
212     m_instrs.push_back(instr);
213 }
214 
215 
216 bool
match(expr * a,unsigned init,subst & s)217 expr_pattern_match::match(expr* a, unsigned init, subst& s)
218 {
219     svector<instr> bstack;
220     instr pc = m_instrs[init];
221 
222     while (true) {
223         bool ok = false;
224         switch(pc.m_kind) {
225         case YIELD:
226             // substitution s contains registers with matching declarations.
227             return true;
228         case CHECK_TERM:
229             ok = (pc.m_pat == m_regs[pc.m_reg]);
230             break;
231         case SET_VAR:
232         case CHECK_VAR: {
233             app* app1 = to_app(pc.m_pat);
234             a   = m_regs[pc.m_reg];
235             if (a->get_kind() != AST_APP) {
236                 break;
237             }
238             app* app2 = to_app(a);
239             if (app1->get_num_args() != app2->get_num_args()) {
240                 break;
241             }
242             if (pc.m_kind == CHECK_VAR &&
243                 to_app(m_regs[pc.m_reg])->get_decl() !=
244                 to_app(m_regs[pc.m_other_reg])->get_decl()) {
245                 break;
246             }
247             for (unsigned i = 0; i < app2->get_num_args(); ++i) {
248                 m_regs[pc.m_offset + i] = app2->get_arg(i);
249             }
250             if (pc.m_kind == SET_VAR) {
251                 s.insert(app1->get_decl(), pc.m_reg);
252             }
253             ok = true;
254             break;
255         }
256         case SET_BOUND: {
257             a = m_regs[pc.m_reg];
258             if (a->get_kind() != AST_VAR) {
259                 break;
260             }
261             ok = true;
262             var* var_a = to_var(a);
263             var* var_p = to_var(pc.m_pat);
264             // check that the mapping of bound variables remains a bijection.
265             for (unsigned i = 0; ok && i < pc.m_num_bound; ++i) {
266                 ok = (a != m_bound_rng[i]);
267             }
268             if (!ok) {
269                 break;
270             }
271             m_bound_dom[pc.m_num_bound] = var_p;
272             m_bound_rng[pc.m_num_bound] = var_a;
273             break;
274         }
275         case CHECK_BOUND:
276             TRACE("expr_pattern_match", tout << "check bound " << pc.m_num_bound << " " << pc.m_reg << "\n";);
277             ok = m_bound_rng[pc.m_num_bound] == m_regs[pc.m_reg];
278             break;
279         case BIND:
280         case BIND_AC:
281         case BIND_C: {
282             TRACE("expr_pattern_match", display(tout, pc);
283                   tout << mk_pp(m_regs[pc.m_reg],m_manager) << "\n";);
284             app* app1 = to_app(pc.m_pat);
285             a   = m_regs[pc.m_reg];
286             if (a->get_kind() != AST_APP) {
287                 break;
288             }
289             app* app2 = to_app(a);
290             if (app1->get_num_args() != app2->get_num_args()) {
291                 break;
292             }
293             if (!match_decl(app1->get_decl(), app2->get_decl())) {
294                 break;
295             }
296             switch(pc.m_kind) {
297             case BIND:
298                 for (unsigned i = 0; i < app2->get_num_args(); ++i) {
299                     m_regs[pc.m_offset + i] = app2->get_arg(i);
300                 }
301                 ok = true;
302                 break; // process the next instruction.
303             case BIND_AC:
304                 // push CHOOSE_AC on the backtracking stack.
305                 bstack.push_back(instr(CHOOSE_AC, pc.m_offset, pc.m_next, app2, 1));
306                 break;
307             case BIND_C:
308                 // push CHOOSE_C on the backtracking stack.
309                 ok = true;
310                 m_regs[pc.m_offset]   = app2->get_arg(0);
311                 m_regs[pc.m_offset+1] = app2->get_arg(1);
312                 bstack.push_back(instr(CHOOSE_C, pc.m_offset, pc.m_next, app2, 2));
313                 break;
314             default:
315                 break;
316             }
317             break;
318         }
319         case CHOOSE_C:
320             ok = true;
321             SASSERT (pc.m_count == 2);
322             m_regs[pc.m_offset+1] = pc.m_app->get_arg(0);
323             m_regs[pc.m_offset]   = pc.m_app->get_arg(1);
324             break;
325         case CHOOSE_AC: {
326             ok = true;
327             app* app2 = pc.m_app;
328             for (unsigned i = 0; i < app2->get_num_args(); ++i) {
329                 m_regs[pc.m_offset + i] = app2->get_arg(i);
330             }
331             // generate the k'th permutation.
332             unsigned k = pc.m_count;
333             unsigned fac = 1;
334             unsigned num_args = pc.m_app->get_num_args();
335             for (unsigned j = 2; j <= num_args; ++j) {
336                 fac *= (j-1);
337                 SASSERT(((k /fac) % j) + 1 <= j);
338                 std::swap(m_regs[pc.m_offset + j - 1], m_regs[pc.m_offset + j - ((k / fac) % j) - 1]);
339             }
340             if (k < fac*num_args) {
341                 bstack.push_back(instr(CHOOSE_AC, pc.m_offset, pc.m_next, app2, k+1));
342             }
343             break;
344         }
345         case BACKTRACK:
346             if (bstack.empty()) {
347                 return false;
348             }
349             pc = bstack.back();
350             bstack.pop_back();
351             continue; // with the loop.
352         }
353 
354         if (ok) {
355             pc = m_instrs[pc.m_next];
356         }
357         else {
358             TRACE("expr_pattern_match", tout << "backtrack\n";);
359             pc = m_instrs[0];
360         }
361     }
362 }
363 
364 
365 bool
match_decl(func_decl const * pat,func_decl const * d) const366 expr_pattern_match::match_decl(func_decl const * pat, func_decl const * d) const {
367     if (pat == d) {
368         return true;
369     }
370     if (pat->get_arity() != d->get_arity()) {
371         return false;
372     }
373     // match families
374     if (pat->get_family_id() == null_family_id) {
375         return false;
376     }
377     if (d->get_family_id() != pat->get_family_id()) {
378         return false;
379     }
380     if (d->get_decl_kind() != pat->get_decl_kind()) {
381         return false;
382     }
383     if (d->get_num_parameters() != pat->get_num_parameters()) {
384         return false;
385     }
386     for (unsigned i = 0; i < d->get_num_parameters(); ++i) {
387         if (!(d->get_parameter(i) == pat->get_parameter(i))) {
388             return false;
389         }
390     }
391     return true;
392 }
393 
394 bool
is_var(func_decl * d)395 expr_pattern_match::is_var(func_decl* d) {
396     const char* s = d->get_name().bare_str();
397     return s && *s == '?';
398 }
399 
400 void
initialize(char const * spec_string)401 expr_pattern_match::initialize(char const * spec_string) {
402     if (!m_instrs.empty()) {
403         return;
404     }
405     m_instrs.push_back(instr(BACKTRACK));
406 
407     std::istringstream is(spec_string);
408     cmd_context      ctx(true, &m_manager);
409     bool ps = ctx.print_success_enabled();
410     ctx.set_print_success(false);
411     VERIFY(parse_smt2_commands(ctx, is));
412     ctx.set_print_success(ps);
413 
414     for (expr * e : ctx.assertions()) {
415         compile(e);
416     }
417 }
418 
initialize(quantifier * q)419 unsigned expr_pattern_match::initialize(quantifier* q) {
420     if (m_instrs.empty()) {
421         m_instrs.push_back(instr(BACKTRACK));
422     }
423     compile(q);
424     return m_precompiled.size() - 1;
425 }
426 
427 
display(std::ostream & out) const428 void expr_pattern_match::display(std::ostream& out) const {
429     for (unsigned i = 0; i < m_instrs.size(); ++i) {
430         display(out, m_instrs[i]);
431     }
432 }
433 
434 void
display(std::ostream & out,instr const & pc) const435 expr_pattern_match::display(std::ostream& out, instr const& pc) const {
436     switch(pc.m_kind) {
437     case BACKTRACK:
438         out << "backtrack\n";
439         break;
440     case BIND:
441         out << "bind       ";
442         out << mk_pp(pc.m_pat, m_manager) << "\n";
443         out << "next:      " << pc.m_next << "\n";
444         out << "offset:    " << pc.m_offset << "\n";
445         out << "reg:       " << pc.m_reg << "\n";
446         break;
447     case BIND_AC:
448         out << "bind_ac    ";
449         out << mk_pp(pc.m_pat, m_manager) << "\n";
450         out << "next:      " << pc.m_next << "\n";
451         out << "offset:    " << pc.m_offset << "\n";
452         out << "reg:       " << pc.m_reg << "\n";
453         break;
454     case BIND_C:
455         out << "bind_c     ";
456         out << mk_pp(pc.m_pat, m_manager) << "\n";
457         out << "next:      " << pc.m_next << "\n";
458         out << "offset:    " << pc.m_offset << "\n";
459         out << "reg:       " << pc.m_reg << "\n";
460         break;
461     case CHOOSE_AC:
462         out << "choose_ac\n";
463         out << "next:      " << pc.m_next  << "\n";
464         out << "count:     " << pc.m_count << "\n";
465         break;
466     case CHOOSE_C:
467         out << "choose_c\n";
468         out << "next:      " << pc.m_next << "\n";
469         //out << "reg:       " << pc.m_reg << "\n";
470         break;
471     case CHECK_VAR:
472         out << "check_var  ";
473         out << mk_pp(pc.m_pat, m_manager) << "\n";
474         out << "next:      " << pc.m_next << "\n";
475         out << "reg:       " << pc.m_reg << "\n";
476         out << "other_reg: " << pc.m_other_reg << "\n";
477         break;
478     case CHECK_TERM:
479         out << "check      ";
480         out << mk_pp(pc.m_pat, m_manager) << "\n";
481         out << "next:      " << pc.m_next << "\n";
482         out << "reg:       " << pc.m_reg << "\n";
483         break;
484     case YIELD:
485         out << "yield\n";
486         break;
487     case SET_VAR:
488         out << "set_var    ";
489         out << mk_pp(pc.m_pat, m_manager) << "\n";
490         out << "next:      " << pc.m_next << "\n";
491         break;
492     default:
493         break;
494     } }
495 
496 
497 // TBD: fix type overloading.
498 // TBD: bound number of permutations.
499 // TBD: forward pruning checks.
500