1 #include "../../common/isa.h"
2 #include <cassert>
3 #include <cstddef>
4 #include "location.hh"
5 #include <gmpxx.h>
6 #include <memory>
7 #include <rumur/Boolean.h>
8 #include <rumur/Decl.h>
9 #include <rumur/Expr.h>
10 #include <rumur/Model.h>
11 #include <rumur/Node.h>
12 #include <rumur/Number.h>
13 #include <rumur/Ptr.h>
14 #include <rumur/Rule.h>
15 #include <rumur/Stmt.h>
16 #include <rumur/Symtab.h>
17 #include <rumur/TypeExpr.h>
18 #include <rumur/resolve-symbols.h>
19 #include <rumur/traverse.h>
20 #include <rumur/validate.h>
21 #include <string>
22 #include <utility>
23 
24 using namespace rumur;
25 
26 namespace {
27 
28 class Resolver : public Traversal {
29 
30 private:
31   Symtab symtab;
32 
33 public:
Resolver()34   Resolver() {
35 
36     // Open a global scope
37     symtab.open_scope();
38 
39     // Teach the symbol table the built ins
40     auto td = Ptr<TypeDecl>::make("boolean", Boolean, location());
41     symtab.declare("boolean", td);
42     mpz_class index = 0;
43     for (const std::pair<std::string, location> &m : Boolean->members) {
44       symtab.declare(m.first,
45                      Ptr<ConstDecl>::make("boolean",
46                                           Ptr<Number>::make(index, location()),
47                                           Boolean, location()));
48       index++;
49     }
50   }
51 
visit_add(Add & n)52   void visit_add(Add &n) final { visit_bexpr(n); }
53 
visit_aliasdecl(AliasDecl & n)54   void visit_aliasdecl(AliasDecl &n) final {
55     dispatch(*n.value);
56     disambiguate(n.value);
57   }
58 
visit_aliasrule(AliasRule & n)59   void visit_aliasrule(AliasRule &n) final {
60     symtab.open_scope();
61     for (auto &a : n.aliases) {
62       dispatch(*a);
63       symtab.declare(a->name, a);
64     }
65     for (auto &r : n.rules)
66       dispatch(*r);
67     symtab.close_scope();
68   }
69 
visit_aliasstmt(AliasStmt & n)70   void visit_aliasstmt(AliasStmt &n) final {
71     symtab.open_scope();
72     for (auto &a : n.aliases) {
73       dispatch(*a);
74       symtab.declare(a->name, a);
75     }
76     for (auto &s : n.body)
77       dispatch(*s);
78     symtab.close_scope();
79   }
80 
visit_ambiguousamp(AmbiguousAmp & n)81   void visit_ambiguousamp(AmbiguousAmp &n) final { visit_bexpr(n); }
82 
visit_ambiguouspipe(AmbiguousPipe & n)83   void visit_ambiguouspipe(AmbiguousPipe &n) final { visit_bexpr(n); }
84 
visit_and(And & n)85   void visit_and(And &n) final { visit_bexpr(n); }
86 
visit_assignment(Assignment & n)87   void visit_assignment(Assignment &n) final {
88     dispatch(*n.lhs);
89     dispatch(*n.rhs);
90     disambiguate(n.lhs);
91     disambiguate(n.rhs);
92   }
93 
visit_band(Band & n)94   void visit_band(Band &n) final { visit_bexpr(n); }
95 
visit_bnot(Bnot & n)96   void visit_bnot(Bnot &n) final { visit_uexpr(n); }
97 
visit_bor(Bor & n)98   void visit_bor(Bor &n) final { visit_bexpr(n); }
99 
visit_clear(Clear & n)100   void visit_clear(Clear &n) final {
101     dispatch(*n.rhs);
102     disambiguate(n.rhs);
103   }
104 
visit_constdecl(ConstDecl & n)105   void visit_constdecl(ConstDecl &n) final {
106     dispatch(*n.value);
107     disambiguate(n.value);
108   }
109 
visit_div(Div & n)110   void visit_div(Div &n) final { visit_bexpr(n); }
111 
visit_element(Element & n)112   void visit_element(Element &n) final {
113     dispatch(*n.array);
114     dispatch(*n.index);
115     disambiguate(n.array);
116     disambiguate(n.index);
117   }
118 
visit_enum(Enum & n)119   void visit_enum(Enum &n) final {
120     auto e = Ptr<Enum>::make(n);
121 
122     // register all the enum members so they can be referenced later
123     mpz_class index = 0;
124     size_t id = e->unique_id + 1;
125     for (const std::pair<std::string, location> &m : n.members) {
126       auto cd = Ptr<ConstDecl>::make(
127           m.first, Ptr<Number>::make(index, m.second), e, m.second);
128       // assign this member a unique id so that referrers can use it if need be
129       assert(id < e->unique_id_limit &&
130              "number of enum members exceeds what was expected");
131       cd->unique_id = id;
132       symtab.declare(m.first, cd);
133       index++;
134       id++;
135     }
136   }
137 
visit_eq(Eq & n)138   void visit_eq(Eq &n) final { visit_bexpr(n); }
139 
visit_exists(Exists & n)140   void visit_exists(Exists &n) final {
141     symtab.open_scope();
142     dispatch(n.quantifier);
143     dispatch(*n.expr);
144     symtab.close_scope();
145     disambiguate(n.expr);
146   }
147 
visit_exprid(ExprID & n)148   void visit_exprid(ExprID &n) final {
149     if (n.value == nullptr) {
150       // This reference is unresolved
151 
152       Ptr<ExprDecl> d = symtab.lookup<ExprDecl>(n.id, n.loc);
153       if (d == nullptr)
154         throw Error("unknown symbol \"" + n.id + "\"", n.loc);
155 
156       n.value = d;
157     }
158   }
159 
visit_field(Field & n)160   void visit_field(Field &n) final {
161     dispatch(*n.record);
162     disambiguate(n.record);
163   }
164 
visit_for(For & n)165   void visit_for(For &n) final {
166     symtab.open_scope();
167     dispatch(n.quantifier);
168     for (auto &s : n.body)
169       dispatch(*s);
170     symtab.close_scope();
171   }
172 
visit_forall(Forall & n)173   void visit_forall(Forall &n) final {
174     symtab.open_scope();
175     dispatch(n.quantifier);
176     dispatch(*n.expr);
177     symtab.close_scope();
178     disambiguate(n.expr);
179   }
180 
visit_function(Function & n)181   void visit_function(Function &n) final {
182     symtab.open_scope();
183     for (auto &p : n.parameters)
184       dispatch(*p);
185     if (n.return_type != nullptr)
186       dispatch(*n.return_type);
187     // register the function itself, even though its body has not yet been
188     // resolved, in order to allow contained function calls to resolve to the
189     // containing function, supporting recursion
190     symtab.declare(n.name, Ptr<Function>::make(n));
191     // only register the function parameters now, to avoid their names shadowing
192     // anything that needs to be resolved during symbol resolution of another
193     // parameter or the return type
194     for (auto &p : n.parameters)
195       symtab.declare(p->name, p);
196     for (auto &d : n.decls) {
197       dispatch(*d);
198       symtab.declare(d->name, d);
199     }
200     for (auto &s : n.body)
201       dispatch(*s);
202     symtab.close_scope();
203   }
204 
visit_functioncall(FunctionCall & n)205   void visit_functioncall(FunctionCall &n) final {
206     if (n.function == nullptr) {
207       // This reference is unresolved
208 
209       Ptr<Function> f = symtab.lookup<Function>(n.name, n.loc);
210       if (f == nullptr)
211         throw Error("unknown function call \"" + n.name + "\"", n.loc);
212 
213       n.function = f;
214     }
215     for (auto &a : n.arguments)
216       dispatch(*a);
217 
218     for (Ptr<Expr> &a : n.arguments)
219       disambiguate(a);
220   }
221 
visit_geq(Geq & n)222   void visit_geq(Geq &n) final { visit_bexpr(n); }
223 
visit_gt(Gt & n)224   void visit_gt(Gt &n) final { visit_bexpr(n); }
225 
visit_ifclause(IfClause & n)226   void visit_ifclause(IfClause &n) final {
227     if (n.condition != nullptr)
228       dispatch(*n.condition);
229     for (auto &s : n.body)
230       dispatch(*s);
231     if (n.condition != nullptr)
232       disambiguate(n.condition);
233   }
234 
visit_implication(Implication & n)235   void visit_implication(Implication &n) final { visit_bexpr(n); }
236 
visit_isundefined(IsUndefined & n)237   void visit_isundefined(IsUndefined &n) final { visit_uexpr(n); }
238 
visit_leq(Leq & n)239   void visit_leq(Leq &n) final { visit_bexpr(n); }
240 
visit_lsh(Lsh & n)241   void visit_lsh(Lsh &n) final { visit_bexpr(n); }
242 
visit_lt(Lt & n)243   void visit_lt(Lt &n) final { visit_bexpr(n); }
244 
visit_model(Model & n)245   void visit_model(Model &n) final {
246 
247     // running marker of offset in the global state data
248     mpz_class offset = 0;
249 
250     /* whether we have not yet hit any problems that make offset calculation
251      * impossible
252      */
253     bool ok = true;
254 
255     for (Ptr<Node> &c : n.children) {
256       dispatch(*c);
257 
258       /* if this was a variable declaration, we now know enough to determine its
259        * offset in the global state data
260        */
261       if (ok) {
262         if (auto v = dynamic_cast<VarDecl *>(c.get())) {
263 
264           /* If the declaration or one of its children does not validate, it is
265            * unsafe to call width().
266            */
267           try {
268             validate(*v);
269           } catch (Error &) {
270             /* Skip this and future offset calculations and assume our caller
271              * will eventually discover the underlying reason when they call
272              * n.validate().
273              */
274             ok = false;
275           }
276 
277           if (ok) {
278             v->offset = offset;
279             offset += v->type->width();
280           }
281         }
282       }
283 
284       if (auto d = dynamic_cast<Decl *>(c.get()))
285         symtab.declare(d->name, c);
286       if (auto f = dynamic_cast<Function *>(c.get()))
287         symtab.declare(f->name, c);
288     }
289   }
290 
visit_mod(Mod & n)291   void visit_mod(Mod &n) final { visit_bexpr(n); }
292 
visit_mul(Mul & n)293   void visit_mul(Mul &n) final { visit_bexpr(n); }
294 
visit_negative(Negative & n)295   void visit_negative(Negative &n) final { visit_uexpr(n); }
296 
visit_neq(Neq & n)297   void visit_neq(Neq &n) final { visit_bexpr(n); }
298 
visit_not(Not & n)299   void visit_not(Not &n) final { visit_uexpr(n); }
300 
visit_or(Or & n)301   void visit_or(Or &n) final { visit_bexpr(n); }
302 
visit_property(Property & n)303   void visit_property(Property &n) final {
304     dispatch(*n.expr);
305     disambiguate(n.expr);
306   }
307 
visit_put(Put & n)308   void visit_put(Put &n) final {
309     if (n.expr != nullptr) {
310       dispatch(*n.expr);
311       disambiguate(n.expr);
312     }
313   }
314 
visit_quantifier(Quantifier & n)315   void visit_quantifier(Quantifier &n) final {
316     if (n.type != nullptr) {
317       // wrap symbol resolution within the type in a dummy scope to suppress any
318       // declarations (primarily enum members) as these will be duplicated in
319       // when we descend into decl below
320       symtab.open_scope();
321       dispatch(*n.type);
322       symtab.close_scope();
323     }
324     if (n.from != nullptr)
325       dispatch(*n.from);
326     if (n.to != nullptr)
327       dispatch(*n.to);
328     if (n.step != nullptr)
329       dispatch(*n.step);
330 
331     if (n.from != nullptr)
332       disambiguate(n.from);
333     if (n.to != nullptr)
334       disambiguate(n.to);
335     if (n.step != nullptr)
336       disambiguate(n.step);
337 
338     // if the bounds for this iteration are now known to be constant, we can
339     // narrow its VarDecl
340     if (n.from != nullptr && n.from->constant() && n.to != nullptr &&
341         n.to->constant()) {
342       auto r = dynamic_cast<Range *>(n.decl->type.get());
343       assert(r != nullptr && "non-range type used for inferred loop decl");
344       // the range may have been given as either an up count or down count
345       if (n.from->constant_fold() <= n.to->constant_fold()) {
346         r->min = n.from;
347         r->max = n.to;
348       } else {
349         r->min = n.to;
350         r->max = n.from;
351       }
352     }
353 
354     dispatch(*n.decl);
355 
356     symtab.declare(n.name, n.decl);
357   }
358 
visit_range(Range & n)359   void visit_range(Range &n) final {
360     dispatch(*n.min);
361     disambiguate(n.min);
362     dispatch(*n.max);
363     disambiguate(n.max);
364   }
365 
visit_return(Return & n)366   void visit_return(Return &n) final {
367     if (n.expr != nullptr) {
368       dispatch(*n.expr);
369       disambiguate(n.expr);
370     }
371   }
372 
visit_rsh(Rsh & n)373   void visit_rsh(Rsh &n) final { visit_bexpr(n); }
374 
visit_ruleset(Ruleset & n)375   void visit_ruleset(Ruleset &n) final {
376     symtab.open_scope();
377     for (Quantifier &q : n.quantifiers)
378       dispatch(q);
379     for (auto &r : n.rules)
380       dispatch(*r);
381     symtab.close_scope();
382   }
383 
visit_scalarset(Scalarset & n)384   void visit_scalarset(Scalarset &n) final {
385     dispatch(*n.bound);
386     disambiguate(n.bound);
387   }
388 
visit_simplerule(SimpleRule & n)389   void visit_simplerule(SimpleRule &n) final {
390     symtab.open_scope();
391     for (Quantifier &q : n.quantifiers)
392       dispatch(q);
393     if (n.guard != nullptr)
394       dispatch(*n.guard);
395     for (auto &d : n.decls) {
396       dispatch(*d);
397       symtab.declare(d->name, d);
398     }
399     for (auto &s : n.body)
400       dispatch(*s);
401     symtab.close_scope();
402     if (n.guard != nullptr)
403       disambiguate(n.guard);
404   }
405 
visit_startstate(StartState & n)406   void visit_startstate(StartState &n) final {
407     symtab.open_scope();
408     for (Quantifier &q : n.quantifiers)
409       dispatch(q);
410     for (auto &d : n.decls) {
411       dispatch(*d);
412       symtab.declare(d->name, d);
413     }
414     for (auto &s : n.body)
415       dispatch(*s);
416     symtab.close_scope();
417   }
418 
visit_sub(Sub & n)419   void visit_sub(Sub &n) final { visit_bexpr(n); }
420 
visit_switch(Switch & n)421   void visit_switch(Switch &n) final {
422     dispatch(*n.expr);
423     for (SwitchCase &c : n.cases)
424       dispatch(c);
425     disambiguate(n.expr);
426   }
427 
visit_switchcase(SwitchCase & n)428   void visit_switchcase(SwitchCase &n) final {
429     for (auto &m : n.matches)
430       dispatch(*m);
431     for (auto &s : n.body)
432       dispatch(*s);
433     for (Ptr<Expr> &m : n.matches)
434       disambiguate(m);
435   }
436 
visit_ternary(Ternary & n)437   void visit_ternary(Ternary &n) {
438     dispatch(*n.cond);
439     dispatch(*n.lhs);
440     dispatch(*n.rhs);
441     disambiguate(n.cond);
442     disambiguate(n.lhs);
443     disambiguate(n.rhs);
444   }
445 
visit_typeexprid(TypeExprID & n)446   void visit_typeexprid(TypeExprID &n) final {
447     if (n.referent == nullptr) {
448       // This reference is unresolved
449 
450       Ptr<TypeDecl> t = symtab.lookup<TypeDecl>(n.name, n.loc);
451       if (t == nullptr)
452         throw Error("unknown type symbol \"" + n.name + "\"", n.loc);
453 
454       n.referent = t;
455     }
456   }
457 
visit_undefine(Undefine & n)458   void visit_undefine(Undefine &n) final {
459     dispatch(*n.rhs);
460     disambiguate(n.rhs);
461   }
462 
visit_while(While & n)463   void visit_while(While &n) final {
464     dispatch(*n.condition);
465     for (auto &s : n.body)
466       dispatch(*s);
467     disambiguate(n.condition);
468   }
469 
visit_xor(Xor & n)470   void visit_xor(Xor &n) final { visit_bexpr(n); }
471 
472   virtual ~Resolver() = default;
473 
474 private:
visit_bexpr(BinaryExpr & n)475   void visit_bexpr(BinaryExpr &n) {
476     dispatch(*n.lhs);
477     dispatch(*n.rhs);
478     disambiguate(n.lhs);
479     disambiguate(n.rhs);
480   }
481 
visit_uexpr(UnaryExpr & n)482   void visit_uexpr(UnaryExpr &n) {
483     dispatch(*n.rhs);
484     disambiguate(n.rhs);
485   }
486 
487   // detect whether this is an ambiguous node and, if so, resolve it into its
488   // more precise AST node type
disambiguate(Ptr<Expr> & e)489   void disambiguate(Ptr<Expr> &e) {
490 
491     if (auto a = dynamic_cast<const AmbiguousAmp *>(e.get())) {
492 
493       // try to get the type of the left hand side
494       Ptr<TypeExpr> t;
495       try {
496         t = a->lhs->type();
497       } catch (Error &) {
498         // We failed because the left operand is somehow invalid. Silently
499         // ignore this, assuming it will be rediscovered during AST validation.
500         return;
501       }
502 
503       // Form an unambiguous replacement node based on the type of the left
504       // operand. Note that the types of the left and right operands may be
505       // incompatible. However, this will cause an error during AST validation
506       // so we do not need to worry about that here.
507       Ptr<Expr> replacement;
508       if (isa<Range>(t)) {
509         replacement = Ptr<Band>::make(a->lhs, a->rhs, a->loc);
510       } else {
511         replacement = Ptr<And>::make(a->lhs, a->rhs, a->loc);
512       }
513 
514       // also preserve the identifier which has already been set
515       replacement->unique_id = a->unique_id;
516 
517       // replace the ambiguous node
518       e = replacement;
519 
520       return;
521     }
522 
523     if (auto o = dynamic_cast<const AmbiguousPipe *>(e.get())) {
524 
525       // try to get the type of the left hand side
526       Ptr<TypeExpr> t;
527       try {
528         t = o->lhs->type();
529       } catch (Error &) {
530         // We failed because the left operand is somehow invalid. Silently
531         // ignore this, assuming it will be rediscovered during AST validation.
532         return;
533       }
534 
535       // Form an unambiguous replacement node based on the type of the left
536       // operand. Note that the types of the left and right operands may be
537       // incompatible. However, this will cause an error during AST validation
538       // so we do not need to worry about that here.
539       Ptr<Expr> replacement;
540       if (isa<Range>(t)) {
541         replacement = Ptr<Bor>::make(o->lhs, o->rhs, o->loc);
542       } else {
543         replacement = Ptr<Or>::make(o->lhs, o->rhs, o->loc);
544       }
545 
546       // also preserve the identifier which has already been set
547       replacement->unique_id = o->unique_id;
548 
549       // replace the ambiguous node
550       e = replacement;
551 
552       return;
553     }
554   }
555 };
556 
557 } // namespace
558 
resolve_symbols(Model & m)559 void rumur::resolve_symbols(Model &m) {
560   Resolver r;
561   r.dispatch(m);
562 }
563