1 #include "../../common/isa.h"
2 #include <cassert>
3 #include <cstddef>
4 #include "location.hh"
5 #include <gmpxx.h>
6 #include <memory>
7 #include <rumur/Boolean.h>
8 #include <rumur/Decl.h>
9 #include <rumur/Expr.h>
10 #include <rumur/Model.h>
11 #include <rumur/Node.h>
12 #include <rumur/Number.h>
13 #include <rumur/Ptr.h>
14 #include <rumur/Rule.h>
15 #include <rumur/Stmt.h>
16 #include <rumur/Symtab.h>
17 #include <rumur/TypeExpr.h>
18 #include <rumur/resolve-symbols.h>
19 #include <rumur/traverse.h>
20 #include <rumur/validate.h>
21 #include <string>
22 #include <utility>
23
24 using namespace rumur;
25
26 namespace {
27
28 class Resolver : public Traversal {
29
30 private:
31 Symtab symtab;
32
33 public:
Resolver()34 Resolver() {
35
36 // Open a global scope
37 symtab.open_scope();
38
39 // Teach the symbol table the built ins
40 auto td = Ptr<TypeDecl>::make("boolean", Boolean, location());
41 symtab.declare("boolean", td);
42 mpz_class index = 0;
43 for (const std::pair<std::string, location> &m : Boolean->members) {
44 symtab.declare(m.first,
45 Ptr<ConstDecl>::make("boolean",
46 Ptr<Number>::make(index, location()),
47 Boolean, location()));
48 index++;
49 }
50 }
51
visit_add(Add & n)52 void visit_add(Add &n) final { visit_bexpr(n); }
53
visit_aliasdecl(AliasDecl & n)54 void visit_aliasdecl(AliasDecl &n) final {
55 dispatch(*n.value);
56 disambiguate(n.value);
57 }
58
visit_aliasrule(AliasRule & n)59 void visit_aliasrule(AliasRule &n) final {
60 symtab.open_scope();
61 for (auto &a : n.aliases) {
62 dispatch(*a);
63 symtab.declare(a->name, a);
64 }
65 for (auto &r : n.rules)
66 dispatch(*r);
67 symtab.close_scope();
68 }
69
visit_aliasstmt(AliasStmt & n)70 void visit_aliasstmt(AliasStmt &n) final {
71 symtab.open_scope();
72 for (auto &a : n.aliases) {
73 dispatch(*a);
74 symtab.declare(a->name, a);
75 }
76 for (auto &s : n.body)
77 dispatch(*s);
78 symtab.close_scope();
79 }
80
visit_ambiguousamp(AmbiguousAmp & n)81 void visit_ambiguousamp(AmbiguousAmp &n) final { visit_bexpr(n); }
82
visit_ambiguouspipe(AmbiguousPipe & n)83 void visit_ambiguouspipe(AmbiguousPipe &n) final { visit_bexpr(n); }
84
visit_and(And & n)85 void visit_and(And &n) final { visit_bexpr(n); }
86
visit_assignment(Assignment & n)87 void visit_assignment(Assignment &n) final {
88 dispatch(*n.lhs);
89 dispatch(*n.rhs);
90 disambiguate(n.lhs);
91 disambiguate(n.rhs);
92 }
93
visit_band(Band & n)94 void visit_band(Band &n) final { visit_bexpr(n); }
95
visit_bnot(Bnot & n)96 void visit_bnot(Bnot &n) final { visit_uexpr(n); }
97
visit_bor(Bor & n)98 void visit_bor(Bor &n) final { visit_bexpr(n); }
99
visit_clear(Clear & n)100 void visit_clear(Clear &n) final {
101 dispatch(*n.rhs);
102 disambiguate(n.rhs);
103 }
104
visit_constdecl(ConstDecl & n)105 void visit_constdecl(ConstDecl &n) final {
106 dispatch(*n.value);
107 disambiguate(n.value);
108 }
109
visit_div(Div & n)110 void visit_div(Div &n) final { visit_bexpr(n); }
111
visit_element(Element & n)112 void visit_element(Element &n) final {
113 dispatch(*n.array);
114 dispatch(*n.index);
115 disambiguate(n.array);
116 disambiguate(n.index);
117 }
118
visit_enum(Enum & n)119 void visit_enum(Enum &n) final {
120 auto e = Ptr<Enum>::make(n);
121
122 // register all the enum members so they can be referenced later
123 mpz_class index = 0;
124 size_t id = e->unique_id + 1;
125 for (const std::pair<std::string, location> &m : n.members) {
126 auto cd = Ptr<ConstDecl>::make(
127 m.first, Ptr<Number>::make(index, m.second), e, m.second);
128 // assign this member a unique id so that referrers can use it if need be
129 assert(id < e->unique_id_limit &&
130 "number of enum members exceeds what was expected");
131 cd->unique_id = id;
132 symtab.declare(m.first, cd);
133 index++;
134 id++;
135 }
136 }
137
visit_eq(Eq & n)138 void visit_eq(Eq &n) final { visit_bexpr(n); }
139
visit_exists(Exists & n)140 void visit_exists(Exists &n) final {
141 symtab.open_scope();
142 dispatch(n.quantifier);
143 dispatch(*n.expr);
144 symtab.close_scope();
145 disambiguate(n.expr);
146 }
147
visit_exprid(ExprID & n)148 void visit_exprid(ExprID &n) final {
149 if (n.value == nullptr) {
150 // This reference is unresolved
151
152 Ptr<ExprDecl> d = symtab.lookup<ExprDecl>(n.id, n.loc);
153 if (d == nullptr)
154 throw Error("unknown symbol \"" + n.id + "\"", n.loc);
155
156 n.value = d;
157 }
158 }
159
visit_field(Field & n)160 void visit_field(Field &n) final {
161 dispatch(*n.record);
162 disambiguate(n.record);
163 }
164
visit_for(For & n)165 void visit_for(For &n) final {
166 symtab.open_scope();
167 dispatch(n.quantifier);
168 for (auto &s : n.body)
169 dispatch(*s);
170 symtab.close_scope();
171 }
172
visit_forall(Forall & n)173 void visit_forall(Forall &n) final {
174 symtab.open_scope();
175 dispatch(n.quantifier);
176 dispatch(*n.expr);
177 symtab.close_scope();
178 disambiguate(n.expr);
179 }
180
visit_function(Function & n)181 void visit_function(Function &n) final {
182 symtab.open_scope();
183 for (auto &p : n.parameters)
184 dispatch(*p);
185 if (n.return_type != nullptr)
186 dispatch(*n.return_type);
187 // register the function itself, even though its body has not yet been
188 // resolved, in order to allow contained function calls to resolve to the
189 // containing function, supporting recursion
190 symtab.declare(n.name, Ptr<Function>::make(n));
191 // only register the function parameters now, to avoid their names shadowing
192 // anything that needs to be resolved during symbol resolution of another
193 // parameter or the return type
194 for (auto &p : n.parameters)
195 symtab.declare(p->name, p);
196 for (auto &d : n.decls) {
197 dispatch(*d);
198 symtab.declare(d->name, d);
199 }
200 for (auto &s : n.body)
201 dispatch(*s);
202 symtab.close_scope();
203 }
204
visit_functioncall(FunctionCall & n)205 void visit_functioncall(FunctionCall &n) final {
206 if (n.function == nullptr) {
207 // This reference is unresolved
208
209 Ptr<Function> f = symtab.lookup<Function>(n.name, n.loc);
210 if (f == nullptr)
211 throw Error("unknown function call \"" + n.name + "\"", n.loc);
212
213 n.function = f;
214 }
215 for (auto &a : n.arguments)
216 dispatch(*a);
217
218 for (Ptr<Expr> &a : n.arguments)
219 disambiguate(a);
220 }
221
visit_geq(Geq & n)222 void visit_geq(Geq &n) final { visit_bexpr(n); }
223
visit_gt(Gt & n)224 void visit_gt(Gt &n) final { visit_bexpr(n); }
225
visit_ifclause(IfClause & n)226 void visit_ifclause(IfClause &n) final {
227 if (n.condition != nullptr)
228 dispatch(*n.condition);
229 for (auto &s : n.body)
230 dispatch(*s);
231 if (n.condition != nullptr)
232 disambiguate(n.condition);
233 }
234
visit_implication(Implication & n)235 void visit_implication(Implication &n) final { visit_bexpr(n); }
236
visit_isundefined(IsUndefined & n)237 void visit_isundefined(IsUndefined &n) final { visit_uexpr(n); }
238
visit_leq(Leq & n)239 void visit_leq(Leq &n) final { visit_bexpr(n); }
240
visit_lsh(Lsh & n)241 void visit_lsh(Lsh &n) final { visit_bexpr(n); }
242
visit_lt(Lt & n)243 void visit_lt(Lt &n) final { visit_bexpr(n); }
244
visit_model(Model & n)245 void visit_model(Model &n) final {
246
247 // running marker of offset in the global state data
248 mpz_class offset = 0;
249
250 /* whether we have not yet hit any problems that make offset calculation
251 * impossible
252 */
253 bool ok = true;
254
255 for (Ptr<Node> &c : n.children) {
256 dispatch(*c);
257
258 /* if this was a variable declaration, we now know enough to determine its
259 * offset in the global state data
260 */
261 if (ok) {
262 if (auto v = dynamic_cast<VarDecl *>(c.get())) {
263
264 /* If the declaration or one of its children does not validate, it is
265 * unsafe to call width().
266 */
267 try {
268 validate(*v);
269 } catch (Error &) {
270 /* Skip this and future offset calculations and assume our caller
271 * will eventually discover the underlying reason when they call
272 * n.validate().
273 */
274 ok = false;
275 }
276
277 if (ok) {
278 v->offset = offset;
279 offset += v->type->width();
280 }
281 }
282 }
283
284 if (auto d = dynamic_cast<Decl *>(c.get()))
285 symtab.declare(d->name, c);
286 if (auto f = dynamic_cast<Function *>(c.get()))
287 symtab.declare(f->name, c);
288 }
289 }
290
visit_mod(Mod & n)291 void visit_mod(Mod &n) final { visit_bexpr(n); }
292
visit_mul(Mul & n)293 void visit_mul(Mul &n) final { visit_bexpr(n); }
294
visit_negative(Negative & n)295 void visit_negative(Negative &n) final { visit_uexpr(n); }
296
visit_neq(Neq & n)297 void visit_neq(Neq &n) final { visit_bexpr(n); }
298
visit_not(Not & n)299 void visit_not(Not &n) final { visit_uexpr(n); }
300
visit_or(Or & n)301 void visit_or(Or &n) final { visit_bexpr(n); }
302
visit_property(Property & n)303 void visit_property(Property &n) final {
304 dispatch(*n.expr);
305 disambiguate(n.expr);
306 }
307
visit_put(Put & n)308 void visit_put(Put &n) final {
309 if (n.expr != nullptr) {
310 dispatch(*n.expr);
311 disambiguate(n.expr);
312 }
313 }
314
visit_quantifier(Quantifier & n)315 void visit_quantifier(Quantifier &n) final {
316 if (n.type != nullptr) {
317 // wrap symbol resolution within the type in a dummy scope to suppress any
318 // declarations (primarily enum members) as these will be duplicated in
319 // when we descend into decl below
320 symtab.open_scope();
321 dispatch(*n.type);
322 symtab.close_scope();
323 }
324 if (n.from != nullptr)
325 dispatch(*n.from);
326 if (n.to != nullptr)
327 dispatch(*n.to);
328 if (n.step != nullptr)
329 dispatch(*n.step);
330
331 if (n.from != nullptr)
332 disambiguate(n.from);
333 if (n.to != nullptr)
334 disambiguate(n.to);
335 if (n.step != nullptr)
336 disambiguate(n.step);
337
338 // if the bounds for this iteration are now known to be constant, we can
339 // narrow its VarDecl
340 if (n.from != nullptr && n.from->constant() && n.to != nullptr &&
341 n.to->constant()) {
342 auto r = dynamic_cast<Range *>(n.decl->type.get());
343 assert(r != nullptr && "non-range type used for inferred loop decl");
344 // the range may have been given as either an up count or down count
345 if (n.from->constant_fold() <= n.to->constant_fold()) {
346 r->min = n.from;
347 r->max = n.to;
348 } else {
349 r->min = n.to;
350 r->max = n.from;
351 }
352 }
353
354 dispatch(*n.decl);
355
356 symtab.declare(n.name, n.decl);
357 }
358
visit_range(Range & n)359 void visit_range(Range &n) final {
360 dispatch(*n.min);
361 disambiguate(n.min);
362 dispatch(*n.max);
363 disambiguate(n.max);
364 }
365
visit_return(Return & n)366 void visit_return(Return &n) final {
367 if (n.expr != nullptr) {
368 dispatch(*n.expr);
369 disambiguate(n.expr);
370 }
371 }
372
visit_rsh(Rsh & n)373 void visit_rsh(Rsh &n) final { visit_bexpr(n); }
374
visit_ruleset(Ruleset & n)375 void visit_ruleset(Ruleset &n) final {
376 symtab.open_scope();
377 for (Quantifier &q : n.quantifiers)
378 dispatch(q);
379 for (auto &r : n.rules)
380 dispatch(*r);
381 symtab.close_scope();
382 }
383
visit_scalarset(Scalarset & n)384 void visit_scalarset(Scalarset &n) final {
385 dispatch(*n.bound);
386 disambiguate(n.bound);
387 }
388
visit_simplerule(SimpleRule & n)389 void visit_simplerule(SimpleRule &n) final {
390 symtab.open_scope();
391 for (Quantifier &q : n.quantifiers)
392 dispatch(q);
393 if (n.guard != nullptr)
394 dispatch(*n.guard);
395 for (auto &d : n.decls) {
396 dispatch(*d);
397 symtab.declare(d->name, d);
398 }
399 for (auto &s : n.body)
400 dispatch(*s);
401 symtab.close_scope();
402 if (n.guard != nullptr)
403 disambiguate(n.guard);
404 }
405
visit_startstate(StartState & n)406 void visit_startstate(StartState &n) final {
407 symtab.open_scope();
408 for (Quantifier &q : n.quantifiers)
409 dispatch(q);
410 for (auto &d : n.decls) {
411 dispatch(*d);
412 symtab.declare(d->name, d);
413 }
414 for (auto &s : n.body)
415 dispatch(*s);
416 symtab.close_scope();
417 }
418
visit_sub(Sub & n)419 void visit_sub(Sub &n) final { visit_bexpr(n); }
420
visit_switch(Switch & n)421 void visit_switch(Switch &n) final {
422 dispatch(*n.expr);
423 for (SwitchCase &c : n.cases)
424 dispatch(c);
425 disambiguate(n.expr);
426 }
427
visit_switchcase(SwitchCase & n)428 void visit_switchcase(SwitchCase &n) final {
429 for (auto &m : n.matches)
430 dispatch(*m);
431 for (auto &s : n.body)
432 dispatch(*s);
433 for (Ptr<Expr> &m : n.matches)
434 disambiguate(m);
435 }
436
visit_ternary(Ternary & n)437 void visit_ternary(Ternary &n) {
438 dispatch(*n.cond);
439 dispatch(*n.lhs);
440 dispatch(*n.rhs);
441 disambiguate(n.cond);
442 disambiguate(n.lhs);
443 disambiguate(n.rhs);
444 }
445
visit_typeexprid(TypeExprID & n)446 void visit_typeexprid(TypeExprID &n) final {
447 if (n.referent == nullptr) {
448 // This reference is unresolved
449
450 Ptr<TypeDecl> t = symtab.lookup<TypeDecl>(n.name, n.loc);
451 if (t == nullptr)
452 throw Error("unknown type symbol \"" + n.name + "\"", n.loc);
453
454 n.referent = t;
455 }
456 }
457
visit_undefine(Undefine & n)458 void visit_undefine(Undefine &n) final {
459 dispatch(*n.rhs);
460 disambiguate(n.rhs);
461 }
462
visit_while(While & n)463 void visit_while(While &n) final {
464 dispatch(*n.condition);
465 for (auto &s : n.body)
466 dispatch(*s);
467 disambiguate(n.condition);
468 }
469
visit_xor(Xor & n)470 void visit_xor(Xor &n) final { visit_bexpr(n); }
471
472 virtual ~Resolver() = default;
473
474 private:
visit_bexpr(BinaryExpr & n)475 void visit_bexpr(BinaryExpr &n) {
476 dispatch(*n.lhs);
477 dispatch(*n.rhs);
478 disambiguate(n.lhs);
479 disambiguate(n.rhs);
480 }
481
visit_uexpr(UnaryExpr & n)482 void visit_uexpr(UnaryExpr &n) {
483 dispatch(*n.rhs);
484 disambiguate(n.rhs);
485 }
486
487 // detect whether this is an ambiguous node and, if so, resolve it into its
488 // more precise AST node type
disambiguate(Ptr<Expr> & e)489 void disambiguate(Ptr<Expr> &e) {
490
491 if (auto a = dynamic_cast<const AmbiguousAmp *>(e.get())) {
492
493 // try to get the type of the left hand side
494 Ptr<TypeExpr> t;
495 try {
496 t = a->lhs->type();
497 } catch (Error &) {
498 // We failed because the left operand is somehow invalid. Silently
499 // ignore this, assuming it will be rediscovered during AST validation.
500 return;
501 }
502
503 // Form an unambiguous replacement node based on the type of the left
504 // operand. Note that the types of the left and right operands may be
505 // incompatible. However, this will cause an error during AST validation
506 // so we do not need to worry about that here.
507 Ptr<Expr> replacement;
508 if (isa<Range>(t)) {
509 replacement = Ptr<Band>::make(a->lhs, a->rhs, a->loc);
510 } else {
511 replacement = Ptr<And>::make(a->lhs, a->rhs, a->loc);
512 }
513
514 // also preserve the identifier which has already been set
515 replacement->unique_id = a->unique_id;
516
517 // replace the ambiguous node
518 e = replacement;
519
520 return;
521 }
522
523 if (auto o = dynamic_cast<const AmbiguousPipe *>(e.get())) {
524
525 // try to get the type of the left hand side
526 Ptr<TypeExpr> t;
527 try {
528 t = o->lhs->type();
529 } catch (Error &) {
530 // We failed because the left operand is somehow invalid. Silently
531 // ignore this, assuming it will be rediscovered during AST validation.
532 return;
533 }
534
535 // Form an unambiguous replacement node based on the type of the left
536 // operand. Note that the types of the left and right operands may be
537 // incompatible. However, this will cause an error during AST validation
538 // so we do not need to worry about that here.
539 Ptr<Expr> replacement;
540 if (isa<Range>(t)) {
541 replacement = Ptr<Bor>::make(o->lhs, o->rhs, o->loc);
542 } else {
543 replacement = Ptr<Or>::make(o->lhs, o->rhs, o->loc);
544 }
545
546 // also preserve the identifier which has already been set
547 replacement->unique_id = o->unique_id;
548
549 // replace the ambiguous node
550 e = replacement;
551
552 return;
553 }
554 }
555 };
556
557 } // namespace
558
resolve_symbols(Model & m)559 void rumur::resolve_symbols(Model &m) {
560 Resolver r;
561 r.dispatch(m);
562 }
563