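/*++
Copyright (c) 2006 Microsoft Corporation

Module Name:

    smt_model_generator.cpp

Abstract:

    Model generator for the SMT context: builds a proto_model from the
    Boolean assignment, the values of the enode roots, the function
    interpretations, the theory models and the registered macros.

Author:

    Leonardo de Moura (leonardo) 2008-10-29.

Revision History:

--*/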
19 
20 #include "util/ref_util.h"
21 #include "ast/for_each_expr.h"
22 #include "ast/ast_pp.h"
23 #include "ast/ast_smt2_pp.h"
24 #include "ast/array_decl_plugin.h"
25 #include "smt/smt_context.h"
26 #include "smt/smt_model_generator.h"
27 #include "smt/proto_model/proto_model.h"
28 #include "model/model_v2_pp.h"
29 
30 namespace smt {
31 
get_dependencies(buffer<model_value_dependency> & result)32     void fresh_value_proc::get_dependencies(buffer<model_value_dependency>& result) {
33         result.push_back(model_value_dependency(m_value));
34     }
35 
operator <<(std::ostream & out,model_value_dependency const & src)36     std::ostream& operator<<(std::ostream& out, model_value_dependency const& src) {
37         if (src.is_fresh_value()) return out << "fresh!" << src.get_value()->get_idx();
38         else return out << "#" << src.get_enode()->get_owner_id();
39     }
40 
model_generator(ast_manager & m)41     model_generator::model_generator(ast_manager & m):
42         m(m),
43         m_context(nullptr),
44         m_fresh_idx(1),
45         m_asts(m),
46         m_model(nullptr) {
47     }
48 
~model_generator()49     model_generator::~model_generator() {
50         dec_ref_collection_values(m, m_hidden_ufs);
51     }
52 
reset()53     void model_generator::reset() {
54         m_extra_fresh_values.reset();
55         m_fresh_idx = 1;
56         m_root2value.reset();
57         m_asts.reset();
58         m_model = nullptr;
59     }
60 
init_model()61     void model_generator::init_model() {
62         SASSERT(!m_model);
63         // PARAM-TODO smt_params ---> params_ref
64         m_model = alloc(proto_model, m); // , m_context->get_fparams());
65         for (theory* th : m_context->theories()) {
66             TRACE("model_generator_bug", tout << "init_model for theory: " << th->get_name() << "\n";);
67             th->init_model(*this);
68         }
69     }
70 
71     /**
72        \brief Create the boolean assignment.
73     */
mk_bool_model()74     void model_generator::mk_bool_model() {
75         unsigned sz = m_context->get_num_b_internalized();
76         for (unsigned i = 0; i < sz; i++) {
77             expr * p = m_context->get_b_internalized(i);
78             if (is_uninterp_const(p) && m_context->is_relevant(p)) {
79                 SASSERT(m.is_bool(p));
80                 func_decl * d = to_app(p)->get_decl();
81                 lbool val     = m_context->get_assignment(p);
82                 expr * v      = val == l_true ? m.mk_true() : m.mk_false();
83                 m_model->register_decl(d, v);
84             }
85         }
86     }
87 
88     /**
89        \brief Create the mapping root2proc: enode-root -> model_value_proc, and roots.
90        Store the new model_value_proc at procs.
91     */
mk_value_procs(obj_map<enode,model_value_proc * > & root2proc,ptr_vector<enode> & roots,ptr_vector<model_value_proc> & procs)92     void model_generator::mk_value_procs(obj_map<enode, model_value_proc *> & root2proc, ptr_vector<enode> & roots,
93                                          ptr_vector<model_value_proc> & procs) {
94         for (enode * r : m_context->enodes()) {
95             if (r == r->get_root() && (m_context->is_relevant(r) || m.is_value(r->get_expr()))) {
96                 roots.push_back(r);
97                 sort * s      = m.get_sort(r->get_owner());
98                 model_value_proc * proc = nullptr;
99                 if (m.is_bool(s)) {
100                     CTRACE("model", m_context->get_assignment(r) == l_undef,
101                            tout << mk_pp(r->get_owner(), m) << "\n";);
102                     SASSERT(m_context->get_assignment(r) != l_undef);
103                     if (m_context->get_assignment(r) == l_true)
104                         proc = alloc(expr_wrapper_proc, m.mk_true());
105                     else
106                         proc = alloc(expr_wrapper_proc, m.mk_false());
107                 }
108                 else if (m.is_value(r->get_expr()))
109                     proc = alloc(expr_wrapper_proc, r->get_expr());
110                 else {
111                     family_id fid = s->get_family_id();
112                     theory * th   = m_context->get_theory(fid);
113                     if (th && th->build_models()) {
114                         if (r->get_th_var(th->get_id()) != null_theory_var) {
115                             proc = th->mk_value(r, *this);
116                             SASSERT(proc);
117                         }
118                         else {
119                             TRACE("model", tout << "creating fresh value for #" << r->get_owner_id() << "\n";);
120                             proc = alloc(fresh_value_proc, mk_extra_fresh_value(m.get_sort(r->get_owner())));
121                         }
122                     }
123                     else {
124                         proc = mk_model_value(r);
125                         SASSERT(proc);
126                     }
127                 }
128                 SASSERT(proc);
129                 procs.push_back(proc);
130                 root2proc.insert(r, proc);
131             }
132         }
133     }
134 
mk_model_value(enode * r)135     model_value_proc* model_generator::mk_model_value(enode* r) {
136         SASSERT(r == r->get_root());
137         expr * n = r->get_owner();
138         if (!m.is_model_value(n)) {
139             sort * s = m.get_sort(r->get_owner());
140             n = m_model->get_fresh_value(s);
141             CTRACE("model", n == 0,
142                    tout << mk_pp(r->get_owner(), m) << "\nsort:\n" << mk_pp(s, m) << "\n";
143                    tout << "is_finite: " << m_model->is_finite(s) << "\n";);
144         }
145         return alloc(expr_wrapper_proc, to_app(n));
146     }
147 
// Colors used by the depth-first topological sort of the sources.
// White: not reached yet; Grey: reached, children pending; Black: finished.
// Scoped constants are used instead of preprocessor macros so that the
// names do not leak into everything that follows in the translation unit.
static constexpr int White = 0;
static constexpr int Grey  = 1;
static constexpr int Black = 2;
151 
get_color(source2color const & colors,source const & s)152     static int get_color(source2color const & colors, source const & s) {
153         int color;
154         if (colors.find(s, color))
155             return color;
156         return White;
157     }
158 
set_color(source2color & colors,source const & s,int c)159     static void set_color(source2color & colors, source const & s, int c) {
160         colors.insert(s, c);
161     }
162 
visit_child(source const & s,source2color & colors,svector<source> & todo,bool & visited)163     static void visit_child(source const & s, source2color & colors, svector<source> & todo, bool & visited) {
164         if (get_color(colors, s) == White) {
165             todo.push_back(s);
166             visited = false;
167         }
168     }
169 
visit_children(source const & src,ptr_vector<enode> const & roots,obj_map<enode,model_value_proc * > const & root2proc,source2color & colors,obj_hashtable<sort> & already_traversed,svector<source> & todo)170     bool model_generator::visit_children(source const & src,
171                                          ptr_vector<enode> const & roots,
172                                          obj_map<enode, model_value_proc *> const & root2proc,
173                                          source2color & colors,
174                                          obj_hashtable<sort> & already_traversed,
175                                          svector<source> & todo) {
176 
177         if (src.is_fresh_value()) {
178             // there is an implicit dependency between a fresh value stub of
179             // sort S and the root enodes of sort S that are not associated with fresh values.
180             //
181             sort * s = src.get_value()->get_sort();
182             if (already_traversed.contains(s))
183                 return true;
184             bool visited = true;
185             for (enode * r : roots) {
186                 if (m.get_sort(r->get_owner()) != s)
187                     continue;
188                 SASSERT(r == r->get_root());
189                 if (root2proc[r]->is_fresh())
190                     continue; // r is associated with a fresh value...
191                 TRACE("mg_top_sort", tout << "fresh!" << src.get_value()->get_idx() << " -> #" << r->get_owner_id() << " " << mk_pp(m.get_sort(r->get_owner()), m) << "\n";);
192                 visit_child(source(r), colors, todo, visited);
193                 TRACE("mg_top_sort", tout << "visited: " << visited << ", todo.size(): " << todo.size() << "\n";);
194             }
195             already_traversed.insert(s);
196             return visited;
197         }
198 
199         SASSERT(!src.is_fresh_value());
200 
201         enode * n = src.get_enode();
202         SASSERT(n == n->get_root());
203         bool visited = true;
204         model_value_proc * proc = root2proc[n];
205         buffer<model_value_dependency> dependencies;
206         proc->get_dependencies(dependencies);
207         for (model_value_dependency const& dep : dependencies) {
208             visit_child(dep, colors, todo, visited);
209         }
210         TRACE("mg_top_sort",
211               tout << "src: " << src << " ";
212               tout << mk_pp(n->get_owner(), m) << "\n";
213               for (model_value_dependency const& dep : dependencies) {
214                   tout << "#" << n->get_owner_id() << " -> " << dep << " already visited: " << visited << "\n";
215               }
216         );
217         return visited;
218     }
219 
process_source(source const & src,ptr_vector<enode> const & roots,obj_map<enode,model_value_proc * > const & root2proc,source2color & colors,obj_hashtable<sort> & already_traversed,svector<source> & todo,svector<source> & sorted_sources)220     void model_generator::process_source(source const & src,
221                                          ptr_vector<enode> const & roots,
222                                          obj_map<enode, model_value_proc *> const & root2proc,
223                                          source2color & colors,
224                                          obj_hashtable<sort> & already_traversed,
225                                          svector<source> & todo,
226                                          svector<source> & sorted_sources) {
227         TRACE("mg_top_sort", tout << "process source, is_fresh: " << src.is_fresh_value() << " ";
228               tout << src << ", todo.size(): " << todo.size() << "\n";);
229         int color     = get_color(colors, src);
230         SASSERT(color != Grey);
231         if (color == Black)
232             return;
233         SASSERT(color == White);
234         todo.push_back(src);
235         while (!todo.empty()) {
236             source curr = todo.back();
237             TRACE("mg_top_sort", tout << "current source, is_fresh: " << curr.is_fresh_value() << " ";
238                   tout << curr << ", todo.size(): " << todo.size() << "\n";);
239             switch (get_color(colors, curr)) {
240             case White:
241                 set_color(colors, curr, Grey);
242                 visit_children(curr, roots, root2proc, colors, already_traversed, todo);
243                 break;
244             case Grey:
245                 // SASSERT(visit_children(curr, roots, root2proc, colors, already_traversed, todo));
246                 set_color(colors, curr, Black);
247                 TRACE("mg_top_sort", tout << "append " << curr << "\n";);
248                 sorted_sources.push_back(curr);
249                 break;
250             case Black:
251                 todo.pop_back();
252                 break;
253             default:
254                 UNREACHABLE();
255             }
256         }
257         TRACE("mg_top_sort", tout << "END process_source, todo.size(): " << todo.size() << "\n";);
258     }
259 
260     /**
261        \brief Topological sort of 'sources'. Store result in sorted_sources.
262     */
top_sort_sources(ptr_vector<enode> const & roots,obj_map<enode,model_value_proc * > const & root2proc,svector<source> & sorted_sources)263     void model_generator::top_sort_sources(ptr_vector<enode> const & roots,
264                                            obj_map<enode, model_value_proc *> const & root2proc,
265                                            svector<source> & sorted_sources) {
266 
267         svector<source>     todo;
268         source2color        colors;
269         // The following 'set' of sorts is used to avoid traversing roots looking for enodes of sort S.
270         // That is, a sort S is in already_traversed, if all enodes of sort S in roots were already traversed.
271         obj_hashtable<sort> already_traversed;
272 
273         // topological sort
274 
275         // traverse all extra fresh values...
276         for (extra_fresh_value * f : m_extra_fresh_values) {
277             process_source(source(f), roots, root2proc, colors, already_traversed, todo, sorted_sources);
278         }
279 
280         // traverse all enodes that are associated with fresh values...
281         for (enode* r : roots) {
282             if (root2proc[r]->is_fresh()) {
283                 process_source(source(r), roots, root2proc, colors, already_traversed, todo, sorted_sources);
284             }
285         }
286 
287         for (enode * r : roots) {
288             process_source(source(r), roots, root2proc, colors, already_traversed, todo, sorted_sources);
289         }
290     }
291 
mk_values()292     void model_generator::mk_values() {
293         obj_map<enode, model_value_proc *> root2proc;
294         ptr_vector<enode> roots;
295         ptr_vector<model_value_proc> procs;
296         svector<source> sources;
297         buffer<model_value_dependency> dependencies;
298         expr_ref_vector dependency_values(m);
299         mk_value_procs(root2proc, roots, procs);
300         top_sort_sources(roots, root2proc, sources);
301         TRACE("sorted_sources",
302               for (source const& curr : sources) {
303                   if (curr.is_fresh_value()) {
304                       tout << curr << " " << mk_pp(curr.get_value()->get_sort(), m) << "\n";
305                   }
306                   else {
307                       enode * n = curr.get_enode();
308                       SASSERT(n->get_root() == n);
309                       tout << mk_pp(n->get_owner(), m) << "\n";
310                       sort * s = m.get_sort(n->get_owner());
311                       tout << curr << " " << mk_pp(s, m);
312                       tout << " is_fresh: " << root2proc[n]->is_fresh() << "\n";
313                   }
314               }
315               m_context->display(tout);
316               );
317 
318         scoped_reset _scoped_reset(*this, procs);
319 
320         for (source const& curr : sources) {
321             if (curr.is_fresh_value()) {
322                 sort * s = curr.get_value()->get_sort();
323                 TRACE("model_fresh_bug", tout << curr << " : " << mk_pp(s, m) << " " << curr.get_value()->get_value() << "\n";);
324                 expr * val = m_model->get_fresh_value(s);
325                 TRACE("model_fresh_bug", tout << curr << " := #" << (val == nullptr ? UINT_MAX : val->get_id()) << "\n";);
326                 m_asts.push_back(val);
327                 curr.get_value()->set_value(val);
328             }
329             else {
330                 enode * n = curr.get_enode();
331                 SASSERT(n->get_root() == n);
332                 TRACE("mg_top_sort", tout << curr << "\n";);
333                 dependencies.reset();
334                 dependency_values.reset();
335                 model_value_proc * proc = root2proc[n];
336                 SASSERT(proc);
337                 proc->get_dependencies(dependencies);
338                 for (model_value_dependency const& d : dependencies) {
339                     if (d.is_fresh_value()) {
340                         CTRACE("mg_top_sort", !d.get_value()->get_value(),
341                                tout << "#" << n->get_owner_id() << " " << mk_pp(n->get_owner(), m) << " -> " << d << "\n";);
342                         SASSERT(d.get_value()->get_value());
343                         dependency_values.push_back(d.get_value()->get_value());
344                     }
345                     else {
346                         enode * child = d.get_enode();
347                         TRACE("mg_top_sort", tout << "#" << n->get_owner_id() << " (" << mk_pp(n->get_owner(), m) << "): "
348                               << mk_pp(child->get_owner(), m) << " " << mk_pp(child->get_root()->get_owner(), m) << "\n";);
349                         child = child->get_root();
350                         dependency_values.push_back(m_root2value[child]);
351                     }
352                 }
353                 app * val = proc->mk_value(*this, dependency_values);
354                 register_value(val);
355                 m_asts.push_back(val);
356                 m_root2value.insert(n, val);
357             }
358         }
359         // send model
360         for (enode * n : m_context->enodes()) {
361             if (is_uninterp_const(n->get_owner()) && m_context->is_relevant(n)) {
362                 func_decl * d = n->get_owner()->get_decl();
363                 TRACE("mg_top_sort", tout << d->get_name() << " " << (m_hidden_ufs.contains(d)?"hidden":"visible") << "\n";);
364                 if (m_hidden_ufs.contains(d)) continue;
365                 expr * val    = get_value(n);
366                 m_model->register_decl(d, val);
367             }
368         }
369     }
370 
scoped_reset(model_generator & mg,ptr_vector<model_value_proc> & procs)371     model_generator::scoped_reset::scoped_reset(model_generator& mg, ptr_vector<model_value_proc>& procs):
372         mg(mg), procs(procs) {}
373 
~scoped_reset()374     model_generator::scoped_reset::~scoped_reset() {
375         std::for_each(procs.begin(), procs.end(), delete_proc<model_value_proc>());
376         std::for_each(mg.m_extra_fresh_values.begin(), mg.m_extra_fresh_values.end(), delete_proc<extra_fresh_value>());
377         mg.m_extra_fresh_values.reset();
378     }
379 
get_value(enode * n) const380     app * model_generator::get_value(enode * n) const {
381         return m_root2value[n->get_root()];
382     }
383 
384     /**
385        \brief Return true if the interpretation of the function should be included in the model.
386     */
include_func_interp(func_decl * f) const387     bool model_generator::include_func_interp(func_decl * f) const {
388         family_id fid = f->get_family_id();
389         if (fid == null_family_id) return !m_hidden_ufs.contains(f);
390         if (fid == m.get_basic_family_id()) return false;
391         theory * th = m_context->get_theory(fid);
392         if (!th) return true;
393         return th->include_func_interp(f);
394     }
395 
396     /**
397        \brief Create (partial) interpretation of function symbols.
398        The "else" is missing.
399     */
mk_func_interps()400     void model_generator::mk_func_interps() {
401         unsigned sz = m_context->get_num_e_internalized();
402         for (unsigned i = 0; i < sz; i++) {
403             expr * t  = m_context->get_e_internalized(i);
404             if (!m_context->is_relevant(t))
405                 continue;
406             enode * n         = m_context->get_enode(t);
407             unsigned num_args = n->get_num_args();
408             func_decl * f     = n->get_decl();
409             if (num_args == 0 && include_func_interp(f)) {
410                 m_model->register_decl(f, get_value(n));
411             }
412             else if (num_args > 0 && n->get_cg() == n && include_func_interp(f)) {
413                 ptr_buffer<expr> args;
414                 expr * result = get_value(n);
415                 SASSERT(result);
416                 for (unsigned j = 0; j < num_args; j++) {
417                     app * arg = get_value(n->get_arg(j));
418                     SASSERT(arg);
419                     args.push_back(arg);
420                 }
421                 func_interp * fi = m_model->get_func_interp(f);
422                 if (fi == nullptr) {
423                     fi = alloc(func_interp, m, f->get_arity());
424                     m_model->register_decl(f, fi);
425                 }
426                 SASSERT(m_model->has_interpretation(f));
427                 SASSERT(m_model->get_func_interp(f) == fi);
428                 // The entry must be new because n->get_cg() == n
429                 TRACE("model",
430                       tout << "insert new entry for:\n" << mk_ismt2_pp(n->get_owner(), m) << "\nargs: ";
431                       for (unsigned i = 0; i < num_args; i++) {
432                           tout << "#" << n->get_arg(i)->get_owner_id() << " ";
433                       }
434                       tout << "\n";
435                       for (expr* arg : args) {
436                           tout << mk_pp(arg, m) << " ";
437                       }
438                       tout << "\n";
439                       tout << "value: #" << n->get_owner_id() << "\n" << mk_ismt2_pp(result, m) << "\n";);
440                 if (fi->get_entry(args.c_ptr()) == nullptr)
441                     fi->insert_new_entry(args.c_ptr(), result);
442             }
443         }
444     }
445 
mk_extra_fresh_value(sort * s)446     extra_fresh_value * model_generator::mk_extra_fresh_value(sort * s) {
447         extra_fresh_value * r = alloc(extra_fresh_value, s, m_fresh_idx);
448         m_fresh_idx++;
449         m_extra_fresh_values.push_back(r);
450         return r;
451     }
452 
get_some_value(sort * s)453     expr * model_generator::get_some_value(sort * s) {
454         SASSERT(m_model);
455         return m_model->get_some_value(s);
456     }
457 
register_value(expr * val)458     void model_generator::register_value(expr * val) {
459         SASSERT(m_model);
460         m_model->register_value(val);
461     }
462 
finalize_theory_models()463     void model_generator::finalize_theory_models() {
464         for (theory* th : m_context->theories())
465             th->finalize_model(*this);
466     }
467 
register_existing_model_values()468     void model_generator::register_existing_model_values() {
469         for (enode * r : m_context->enodes()) {
470             if (r == r->get_root() && m_context->is_relevant(r)) {
471                 expr * n = r->get_owner();
472                 if (m.is_model_value(n)) {
473                     register_value(n);
474                 }
475             }
476         }
477     }
478 
register_factory(value_factory * f)479     void model_generator::register_factory(value_factory * f) {
480         m_model->register_factory(f);
481     }
482 
register_macros()483     void model_generator::register_macros() {
484         unsigned num = m_context->get_num_macros();
485         TRACE("model", tout << "num. macros: " << num << "\n";);
486         expr_ref v(m);
487         for (unsigned i = 0; i < num; i++) {
488             func_decl * f    = m_context->get_macro_interpretation(i, v);
489             func_interp * fi = alloc(func_interp, m, f->get_arity());
490             fi->set_else(v);
491             TRACE("model", tout << f->get_name() << "\n" << mk_pp(v, m) << "\n";);
492             m_model->register_decl(f, fi);
493         }
494     }
495 
mk_model()496     proto_model * model_generator::mk_model() {
497         SASSERT(!m_model);
498         TRACE("model_verbose", m_context->display(tout););
499         init_model();
500         register_existing_model_values();
501         mk_bool_model();
502         mk_values();
503         mk_func_interps();
504         finalize_theory_models();
505         register_macros();
506         TRACE("model", model_v2_pp(tout, *m_model, true););
507         return m_model.get();
508     }
509 
510 };
511