1 //===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AST dump for the Toy language.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "toy/AST.h"
14 
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/raw_ostream.h"
18 
19 using namespace toy;
20 
21 namespace {
22 
/// Scope guard that keeps an indentation counter in sync with the AST
/// traversal: the referenced counter is incremented when the guard is
/// constructed and decremented again when the guard is destroyed.
struct Indent {
  /// Bump `level` by one for the lifetime of this guard. The constructor is
  /// `explicit` so an `int` lvalue can never be silently converted into a
  /// guard (which would modify the counter as a side effect).
  explicit Indent(int &level) : level(level) { ++level; }

  /// Undo the increment performed by the constructor.
  ~Indent() { --level; }

  // A copied guard would decrement the counter once more than it was
  // incremented, so copying is disallowed.
  Indent(const Indent &) = delete;
  Indent &operator=(const Indent &) = delete;

  /// Counter being managed; it is owned by the caller and must outlive the
  /// guard.
  int &level;
};
30 
31 /// Helper class that implement the AST tree traversal and print the nodes along
32 /// the way. The only data member is the current indentation level.
33 class ASTDumper {
34 public:
35   void dump(ModuleAST *node);
36 
37 private:
38   void dump(const VarType &type);
39   void dump(VarDeclExprAST *varDecl);
40   void dump(ExprAST *expr);
41   void dump(ExprASTList *exprList);
42   void dump(NumberExprAST *num);
43   void dump(LiteralExprAST *node);
44   void dump(VariableExprAST *node);
45   void dump(ReturnExprAST *node);
46   void dump(BinaryExprAST *node);
47   void dump(CallExprAST *node);
48   void dump(PrintExprAST *node);
49   void dump(PrototypeAST *node);
50   void dump(FunctionAST *node);
51 
52   // Actually print spaces matching the current indentation level
indent()53   void indent() {
54     for (int i = 0; i < curIndent; i++)
55       llvm::errs() << "  ";
56   }
57   int curIndent = 0;
58 };
59 
60 } // namespace
61 
62 /// Return a formatted string for the location of any node
loc(T * node)63 template <typename T> static std::string loc(T *node) {
64   const auto &loc = node->loc();
65   return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
66           llvm::Twine(loc.col))
67       .str();
68 }
69 
// Convenience macro for the dump methods: opens one more indentation level for
// the rest of the enclosing scope (through an `Indent` guard) and prints the
// leading spaces for the line about to be emitted. It expands to two
// statements, so it may only be used where a declaration is allowed.
#define INDENT()                                                               \
  Indent indentGuard_(curIndent);                                              \
  indent();
75 
76 /// Dispatch to a generic expressions to the appropriate subclass using RTTI
dump(ExprAST * expr)77 void ASTDumper::dump(ExprAST *expr) {
78   llvm::TypeSwitch<ExprAST *>(expr)
79       .Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
80             PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
81           [&](auto *node) { this->dump(node); })
82       .Default([&](ExprAST *) {
83         // No match, fallback to a generic message
84         INDENT();
85         llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
86       });
87 }
88 
89 /// A variable declaration is printing the variable name, the type, and then
90 /// recurse in the initializer value.
dump(VarDeclExprAST * varDecl)91 void ASTDumper::dump(VarDeclExprAST *varDecl) {
92   INDENT();
93   llvm::errs() << "VarDecl " << varDecl->getName();
94   dump(varDecl->getType());
95   llvm::errs() << " " << loc(varDecl) << "\n";
96   dump(varDecl->getInitVal());
97 }
98 
99 /// A "block", or a list of expression
dump(ExprASTList * exprList)100 void ASTDumper::dump(ExprASTList *exprList) {
101   INDENT();
102   llvm::errs() << "Block {\n";
103   for (auto &expr : *exprList)
104     dump(expr.get());
105   indent();
106   llvm::errs() << "} // Block\n";
107 }
108 
109 /// A literal number, just print the value.
dump(NumberExprAST * num)110 void ASTDumper::dump(NumberExprAST *num) {
111   INDENT();
112   llvm::errs() << num->getValue() << " " << loc(num) << "\n";
113 }
114 
115 /// Helper to print recursively a literal. This handles nested array like:
116 ///    [ [ 1, 2 ], [ 3, 4 ] ]
117 /// We print out such array with the dimensions spelled out at every level:
118 ///    <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
printLitHelper(ExprAST * litOrNum)119 void printLitHelper(ExprAST *litOrNum) {
120   // Inside a literal expression we can have either a number or another literal
121   if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
122     llvm::errs() << num->getValue();
123     return;
124   }
125   auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
126 
127   // Print the dimension for this literal first
128   llvm::errs() << "<";
129   llvm::interleaveComma(literal->getDims(), llvm::errs());
130   llvm::errs() << ">";
131 
132   // Now print the content, recursing on every element of the list
133   llvm::errs() << "[ ";
134   llvm::interleaveComma(literal->getValues(), llvm::errs(),
135                         [&](auto &elt) { printLitHelper(elt.get()); });
136   llvm::errs() << "]";
137 }
138 
139 /// Print a literal, see the recursive helper above for the implementation.
dump(LiteralExprAST * node)140 void ASTDumper::dump(LiteralExprAST *node) {
141   INDENT();
142   llvm::errs() << "Literal: ";
143   printLitHelper(node);
144   llvm::errs() << " " << loc(node) << "\n";
145 }
146 
147 /// Print a variable reference (just a name).
dump(VariableExprAST * node)148 void ASTDumper::dump(VariableExprAST *node) {
149   INDENT();
150   llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
151 }
152 
153 /// Return statement print the return and its (optional) argument.
dump(ReturnExprAST * node)154 void ASTDumper::dump(ReturnExprAST *node) {
155   INDENT();
156   llvm::errs() << "Return\n";
157   if (node->getExpr().hasValue())
158     return dump(*node->getExpr());
159   {
160     INDENT();
161     llvm::errs() << "(void)\n";
162   }
163 }
164 
165 /// Print a binary operation, first the operator, then recurse into LHS and RHS.
dump(BinaryExprAST * node)166 void ASTDumper::dump(BinaryExprAST *node) {
167   INDENT();
168   llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
169   dump(node->getLHS());
170   dump(node->getRHS());
171 }
172 
173 /// Print a call expression, first the callee name and the list of args by
174 /// recursing into each individual argument.
dump(CallExprAST * node)175 void ASTDumper::dump(CallExprAST *node) {
176   INDENT();
177   llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
178   for (auto &arg : node->getArgs())
179     dump(arg.get());
180   indent();
181   llvm::errs() << "]\n";
182 }
183 
184 /// Print a builtin print call, first the builtin name and then the argument.
dump(PrintExprAST * node)185 void ASTDumper::dump(PrintExprAST *node) {
186   INDENT();
187   llvm::errs() << "Print [ " << loc(node) << "\n";
188   dump(node->getArg());
189   indent();
190   llvm::errs() << "]\n";
191 }
192 
193 /// Print type: only the shape is printed in between '<' and '>'
dump(const VarType & type)194 void ASTDumper::dump(const VarType &type) {
195   llvm::errs() << "<";
196   llvm::interleaveComma(type.shape, llvm::errs());
197   llvm::errs() << ">";
198 }
199 
200 /// Print a function prototype, first the function name, and then the list of
201 /// parameters names.
dump(PrototypeAST * node)202 void ASTDumper::dump(PrototypeAST *node) {
203   INDENT();
204   llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n";
205   indent();
206   llvm::errs() << "Params: [";
207   llvm::interleaveComma(node->getArgs(), llvm::errs(),
208                         [](auto &arg) { llvm::errs() << arg->getName(); });
209   llvm::errs() << "]\n";
210 }
211 
212 /// Print a function, first the prototype and then the body.
dump(FunctionAST * node)213 void ASTDumper::dump(FunctionAST *node) {
214   INDENT();
215   llvm::errs() << "Function \n";
216   dump(node->getProto());
217   dump(node->getBody());
218 }
219 
220 /// Print a module, actually loop over the functions and print them in sequence.
dump(ModuleAST * node)221 void ASTDumper::dump(ModuleAST *node) {
222   INDENT();
223   llvm::errs() << "Module:\n";
224   for (auto &f : *node)
225     dump(&f);
226 }
227 
228 namespace toy {
229 
230 // Public API
dump(ModuleAST & module)231 void dump(ModuleAST &module) { ASTDumper().dump(&module); }
232 
233 } // namespace toy
234