1 /*++
2 Copyright (c) 2007 Microsoft Corporation
3
4 Module Name:
5
6 ast_pattern_match.cpp
7
8 Abstract:
9
10 Search for opportune pattern matching utilities.
11
12 Author:
13
14 Nikolaj Bjorner (nbjorner) 2007-04-10
15 Leonardo (leonardo)
16
17 Notes:
18
19 instead of the brute force enumeration of permutations
20 we can add an instruction 'gate' which copies the ast
21 into a register and creates another register with the same
22 term. Matching against a 'gate' is a noop, apart from clearing
23 the ast in the register. Then on backtracking we know how many
24 terms were matched from the permutation. It does not make sense
25 to enumerate all combinations of terms that were not considered, so
26 skip these.
27
28 Also, compilation should re-order terms to fail fast.
29
30 --*/
31
32 #include "ast/ast.h"
33 #include "ast/pattern/expr_pattern_match.h"
34 #include "ast/for_each_ast.h"
35 #include "ast/ast_ll_pp.h"
36 #include "ast/ast_pp.h"
37 #include "cmd_context/cmd_context.h"
38 #include "parsers/smt2/smt2parser.h"
39
expr_pattern_match(ast_manager & manager)40 expr_pattern_match::expr_pattern_match(ast_manager & manager):
41 m_manager(manager), m_precompiled(manager) {
42 }
43
~expr_pattern_match()44 expr_pattern_match::~expr_pattern_match() {
45 }
46
47 bool
match_quantifier(quantifier * qf,app_ref_vector & patterns,unsigned & weight)48 expr_pattern_match::match_quantifier(quantifier* qf, app_ref_vector& patterns, unsigned& weight) {
49 if (m_regs.empty()) {
50 // HACK: the code crashes if database is empty.
51 return false;
52 }
53 m_regs[0] = qf->get_expr();
54 for (unsigned i = 0; i < m_precompiled.size(); ++i) {
55 if (match_quantifier(i, qf, patterns, weight))
56 return true;
57 }
58 return false;
59 }
60
61 bool
match_quantifier(unsigned i,quantifier * qf,app_ref_vector & patterns,unsigned & weight)62 expr_pattern_match::match_quantifier(unsigned i, quantifier* qf, app_ref_vector& patterns, unsigned& weight) {
63 quantifier* qf2 = m_precompiled[i].get();
64 if (qf2->get_kind() != qf->get_kind() || is_lambda(qf)) {
65 return false;
66 }
67 if (qf2->get_num_decls() != qf->get_num_decls()) {
68 return false;
69 }
70 subst s;
71 if (match(qf->get_expr(), m_first_instrs[i], s)) {
72 for (unsigned j = 0; j < qf2->get_num_patterns(); ++j) {
73 app* p = static_cast<app*>(qf2->get_pattern(j));
74 expr_ref p_result(m_manager);
75 instantiate(p, qf->get_num_decls(), s, p_result);
76 patterns.push_back(to_app(p_result.get()));
77 }
78 weight = qf2->get_weight();
79 return true;
80 }
81 return false;
82 }
83
match_quantifier_index(quantifier * qf,app_ref_vector & patterns,unsigned & index)84 bool expr_pattern_match::match_quantifier_index(quantifier* qf, app_ref_vector& patterns, unsigned& index) {
85 if (m_regs.empty()) return false;
86 m_regs[0] = qf->get_expr();
87
88 for (unsigned i = 0; i < m_precompiled.size(); ++i) {
89 unsigned weight = 0;
90 if (match_quantifier(i, qf, patterns, weight)) {
91 index = i;
92 return true;
93 }
94 }
95 return false;
96 }
97
98
99 void
instantiate(expr * a,unsigned num_bound,subst & s,expr_ref & result)100 expr_pattern_match::instantiate(expr* a, unsigned num_bound, subst& s, expr_ref& result) {
101 bound b;
102 for (unsigned i = 0; i < num_bound; ++i) {
103 b.insert(m_bound_dom[i], m_bound_rng[i]);
104 }
105 TRACE("expr_pattern_match", tout << mk_pp(a, m_manager) << " " << num_bound << "\n";);
106 inst_proc proc(m_manager, s, b, m_regs);
107 for_each_ast(proc, a);
108 expr* v = nullptr;
109 proc.m_memoize.find(a, v);
110 SASSERT(v);
111 result = v;
112 }
113
114
115
116 void
compile(expr * q)117 expr_pattern_match::compile(expr* q)
118 {
119 SASSERT(q->get_kind() == AST_QUANTIFIER);
120 quantifier* qf = to_quantifier(q);
121 unsigned ip = m_instrs.size();
122 m_first_instrs.push_back(ip);
123 m_precompiled.push_back(qf);
124
125 instr instr(BACKTRACK);
126 unsigned_vector regs;
127 ptr_vector<expr> pats;
128 unsigned max_reg = 1;
129 subst s;
130 pats.push_back(qf->get_expr());
131 regs.push_back(0);
132 unsigned num_bound = 0;
133 obj_map<var, unsigned> bound;
134
135 while (!pats.empty()) {
136
137 unsigned reg = regs.back();
138 expr* pat = pats.back();
139 regs.pop_back();
140 pats.pop_back();
141
142 instr.m_pat = pat;
143 instr.m_next = m_instrs.size()+1;
144 instr.m_reg = reg;
145 instr.m_offset = max_reg;
146
147 switch(pat->get_kind()) {
148 case AST_VAR: {
149 var* b = to_var(pat);
150 if (bound.find(b, instr.m_num_bound)) {
151 instr.m_kind = CHECK_BOUND;
152 }
153 else {
154 instr.m_kind = SET_BOUND;
155 instr.m_num_bound = num_bound;
156 bound.insert(b, num_bound);
157 ++num_bound;
158 }
159 break;
160 }
161 case AST_APP: {
162 unsigned r = 0;
163 app* app = to_app(pat);
164 func_decl* d = app->get_decl();
165
166 for (unsigned i = 0; i < app->get_num_args(); ++i) {
167 regs.push_back(max_reg);
168 pats.push_back(app->get_arg(i));
169 ++max_reg;
170 }
171
172 if (is_var(d)) {
173 if (s.find(d, r)) {
174 instr.m_kind = CHECK_VAR;
175 instr.m_other_reg = r;
176 }
177 else {
178 instr.m_kind = SET_VAR;
179 s.insert(d, reg);
180 }
181 }
182 else {
183 if (d->is_associative() && d->is_commutative()) {
184 instr.m_kind = BIND_AC;
185 }
186 else if (d->is_commutative()) {
187 SASSERT(app->get_num_args() == 2);
188 instr.m_kind = BIND_C;
189 }
190 else {
191 instr.m_kind = BIND;
192 }
193 }
194 break;
195 }
196 default:
197 instr.m_kind = CHECK_TERM;
198 break;
199 }
200 m_instrs.push_back(instr);
201 }
202
203 if (m_regs.size() <= max_reg) {
204 m_regs.resize(max_reg+1);
205 }
206 if (m_bound_dom.size() <= num_bound) {
207 m_bound_dom.resize(num_bound+1);
208 m_bound_rng.resize(num_bound+1);
209 }
210
211 instr.m_kind = YIELD;
212 m_instrs.push_back(instr);
213 }
214
215
216 bool
match(expr * a,unsigned init,subst & s)217 expr_pattern_match::match(expr* a, unsigned init, subst& s)
218 {
219 svector<instr> bstack;
220 instr pc = m_instrs[init];
221
222 while (true) {
223 bool ok = false;
224 switch(pc.m_kind) {
225 case YIELD:
226 // substitution s contains registers with matching declarations.
227 return true;
228 case CHECK_TERM:
229 ok = (pc.m_pat == m_regs[pc.m_reg]);
230 break;
231 case SET_VAR:
232 case CHECK_VAR: {
233 app* app1 = to_app(pc.m_pat);
234 a = m_regs[pc.m_reg];
235 if (a->get_kind() != AST_APP) {
236 break;
237 }
238 app* app2 = to_app(a);
239 if (app1->get_num_args() != app2->get_num_args()) {
240 break;
241 }
242 if (pc.m_kind == CHECK_VAR &&
243 to_app(m_regs[pc.m_reg])->get_decl() !=
244 to_app(m_regs[pc.m_other_reg])->get_decl()) {
245 break;
246 }
247 for (unsigned i = 0; i < app2->get_num_args(); ++i) {
248 m_regs[pc.m_offset + i] = app2->get_arg(i);
249 }
250 if (pc.m_kind == SET_VAR) {
251 s.insert(app1->get_decl(), pc.m_reg);
252 }
253 ok = true;
254 break;
255 }
256 case SET_BOUND: {
257 a = m_regs[pc.m_reg];
258 if (a->get_kind() != AST_VAR) {
259 break;
260 }
261 ok = true;
262 var* var_a = to_var(a);
263 var* var_p = to_var(pc.m_pat);
264 // check that the mapping of bound variables remains a bijection.
265 for (unsigned i = 0; ok && i < pc.m_num_bound; ++i) {
266 ok = (a != m_bound_rng[i]);
267 }
268 if (!ok) {
269 break;
270 }
271 m_bound_dom[pc.m_num_bound] = var_p;
272 m_bound_rng[pc.m_num_bound] = var_a;
273 break;
274 }
275 case CHECK_BOUND:
276 TRACE("expr_pattern_match", tout << "check bound " << pc.m_num_bound << " " << pc.m_reg << "\n";);
277 ok = m_bound_rng[pc.m_num_bound] == m_regs[pc.m_reg];
278 break;
279 case BIND:
280 case BIND_AC:
281 case BIND_C: {
282 TRACE("expr_pattern_match", display(tout, pc);
283 tout << mk_pp(m_regs[pc.m_reg],m_manager) << "\n";);
284 app* app1 = to_app(pc.m_pat);
285 a = m_regs[pc.m_reg];
286 if (a->get_kind() != AST_APP) {
287 break;
288 }
289 app* app2 = to_app(a);
290 if (app1->get_num_args() != app2->get_num_args()) {
291 break;
292 }
293 if (!match_decl(app1->get_decl(), app2->get_decl())) {
294 break;
295 }
296 switch(pc.m_kind) {
297 case BIND:
298 for (unsigned i = 0; i < app2->get_num_args(); ++i) {
299 m_regs[pc.m_offset + i] = app2->get_arg(i);
300 }
301 ok = true;
302 break; // process the next instruction.
303 case BIND_AC:
304 // push CHOOSE_AC on the backtracking stack.
305 bstack.push_back(instr(CHOOSE_AC, pc.m_offset, pc.m_next, app2, 1));
306 break;
307 case BIND_C:
308 // push CHOOSE_C on the backtracking stack.
309 ok = true;
310 m_regs[pc.m_offset] = app2->get_arg(0);
311 m_regs[pc.m_offset+1] = app2->get_arg(1);
312 bstack.push_back(instr(CHOOSE_C, pc.m_offset, pc.m_next, app2, 2));
313 break;
314 default:
315 break;
316 }
317 break;
318 }
319 case CHOOSE_C:
320 ok = true;
321 SASSERT (pc.m_count == 2);
322 m_regs[pc.m_offset+1] = pc.m_app->get_arg(0);
323 m_regs[pc.m_offset] = pc.m_app->get_arg(1);
324 break;
325 case CHOOSE_AC: {
326 ok = true;
327 app* app2 = pc.m_app;
328 for (unsigned i = 0; i < app2->get_num_args(); ++i) {
329 m_regs[pc.m_offset + i] = app2->get_arg(i);
330 }
331 // generate the k'th permutation.
332 unsigned k = pc.m_count;
333 unsigned fac = 1;
334 unsigned num_args = pc.m_app->get_num_args();
335 for (unsigned j = 2; j <= num_args; ++j) {
336 fac *= (j-1);
337 SASSERT(((k /fac) % j) + 1 <= j);
338 std::swap(m_regs[pc.m_offset + j - 1], m_regs[pc.m_offset + j - ((k / fac) % j) - 1]);
339 }
340 if (k < fac*num_args) {
341 bstack.push_back(instr(CHOOSE_AC, pc.m_offset, pc.m_next, app2, k+1));
342 }
343 break;
344 }
345 case BACKTRACK:
346 if (bstack.empty()) {
347 return false;
348 }
349 pc = bstack.back();
350 bstack.pop_back();
351 continue; // with the loop.
352 }
353
354 if (ok) {
355 pc = m_instrs[pc.m_next];
356 }
357 else {
358 TRACE("expr_pattern_match", tout << "backtrack\n";);
359 pc = m_instrs[0];
360 }
361 }
362 }
363
364
365 bool
match_decl(func_decl const * pat,func_decl const * d) const366 expr_pattern_match::match_decl(func_decl const * pat, func_decl const * d) const {
367 if (pat == d) {
368 return true;
369 }
370 if (pat->get_arity() != d->get_arity()) {
371 return false;
372 }
373 // match families
374 if (pat->get_family_id() == null_family_id) {
375 return false;
376 }
377 if (d->get_family_id() != pat->get_family_id()) {
378 return false;
379 }
380 if (d->get_decl_kind() != pat->get_decl_kind()) {
381 return false;
382 }
383 if (d->get_num_parameters() != pat->get_num_parameters()) {
384 return false;
385 }
386 for (unsigned i = 0; i < d->get_num_parameters(); ++i) {
387 if (!(d->get_parameter(i) == pat->get_parameter(i))) {
388 return false;
389 }
390 }
391 return true;
392 }
393
394 bool
is_var(func_decl * d)395 expr_pattern_match::is_var(func_decl* d) {
396 const char* s = d->get_name().bare_str();
397 return s && *s == '?';
398 }
399
400 void
initialize(char const * spec_string)401 expr_pattern_match::initialize(char const * spec_string) {
402 if (!m_instrs.empty()) {
403 return;
404 }
405 m_instrs.push_back(instr(BACKTRACK));
406
407 std::istringstream is(spec_string);
408 cmd_context ctx(true, &m_manager);
409 bool ps = ctx.print_success_enabled();
410 ctx.set_print_success(false);
411 VERIFY(parse_smt2_commands(ctx, is));
412 ctx.set_print_success(ps);
413
414 for (expr * e : ctx.assertions()) {
415 compile(e);
416 }
417 }
418
initialize(quantifier * q)419 unsigned expr_pattern_match::initialize(quantifier* q) {
420 if (m_instrs.empty()) {
421 m_instrs.push_back(instr(BACKTRACK));
422 }
423 compile(q);
424 return m_precompiled.size() - 1;
425 }
426
427
display(std::ostream & out) const428 void expr_pattern_match::display(std::ostream& out) const {
429 for (unsigned i = 0; i < m_instrs.size(); ++i) {
430 display(out, m_instrs[i]);
431 }
432 }
433
434 void
display(std::ostream & out,instr const & pc) const435 expr_pattern_match::display(std::ostream& out, instr const& pc) const {
436 switch(pc.m_kind) {
437 case BACKTRACK:
438 out << "backtrack\n";
439 break;
440 case BIND:
441 out << "bind ";
442 out << mk_pp(pc.m_pat, m_manager) << "\n";
443 out << "next: " << pc.m_next << "\n";
444 out << "offset: " << pc.m_offset << "\n";
445 out << "reg: " << pc.m_reg << "\n";
446 break;
447 case BIND_AC:
448 out << "bind_ac ";
449 out << mk_pp(pc.m_pat, m_manager) << "\n";
450 out << "next: " << pc.m_next << "\n";
451 out << "offset: " << pc.m_offset << "\n";
452 out << "reg: " << pc.m_reg << "\n";
453 break;
454 case BIND_C:
455 out << "bind_c ";
456 out << mk_pp(pc.m_pat, m_manager) << "\n";
457 out << "next: " << pc.m_next << "\n";
458 out << "offset: " << pc.m_offset << "\n";
459 out << "reg: " << pc.m_reg << "\n";
460 break;
461 case CHOOSE_AC:
462 out << "choose_ac\n";
463 out << "next: " << pc.m_next << "\n";
464 out << "count: " << pc.m_count << "\n";
465 break;
466 case CHOOSE_C:
467 out << "choose_c\n";
468 out << "next: " << pc.m_next << "\n";
469 //out << "reg: " << pc.m_reg << "\n";
470 break;
471 case CHECK_VAR:
472 out << "check_var ";
473 out << mk_pp(pc.m_pat, m_manager) << "\n";
474 out << "next: " << pc.m_next << "\n";
475 out << "reg: " << pc.m_reg << "\n";
476 out << "other_reg: " << pc.m_other_reg << "\n";
477 break;
478 case CHECK_TERM:
479 out << "check ";
480 out << mk_pp(pc.m_pat, m_manager) << "\n";
481 out << "next: " << pc.m_next << "\n";
482 out << "reg: " << pc.m_reg << "\n";
483 break;
484 case YIELD:
485 out << "yield\n";
486 break;
487 case SET_VAR:
488 out << "set_var ";
489 out << mk_pp(pc.m_pat, m_manager) << "\n";
490 out << "next: " << pc.m_next << "\n";
491 break;
492 default:
493 break;
494 } }
495
496
497 // TBD: fix type overloading.
498 // TBD: bound number of permutations.
499 // TBD: forward pruning checks.
500