1 //===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AST dump for the Toy language.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "toy/AST.h"
14 
15 #include "mlir/ADT/TypeSwitch.h"
16 #include "mlir/Support/STLExtras.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Support/raw_ostream.h"
19 
20 using namespace toy;
21 
22 namespace {
23 
// RAII helper to manage increasing/decreasing the indentation as we traverse
// the AST: the referenced counter is incremented when the guard is created and
// decremented again when the guard is destroyed.
struct Indent {
  // `explicit` so that an int lvalue is never silently converted into a guard.
  explicit Indent(int &level) : level(level) { ++level; }
  ~Indent() { --level; }

  // A copied guard would decrement the counter a second time on destruction,
  // so copying is disabled.
  Indent(const Indent &) = delete;
  Indent &operator=(const Indent &) = delete;

  // The counter being adjusted; it lives outside of the guard.
  int &level;
};
31 
32 /// Helper class that implement the AST tree traversal and print the nodes along
33 /// the way. The only data member is the current indentation level.
34 class ASTDumper {
35 public:
36   void dump(ModuleAST *node);
37 
38 private:
39   void dump(const VarType &type);
40   void dump(VarDeclExprAST *varDecl);
41   void dump(ExprAST *expr);
42   void dump(ExprASTList *exprList);
43   void dump(NumberExprAST *num);
44   void dump(LiteralExprAST *node);
45   void dump(VariableExprAST *node);
46   void dump(ReturnExprAST *node);
47   void dump(BinaryExprAST *node);
48   void dump(CallExprAST *node);
49   void dump(PrintExprAST *node);
50   void dump(PrototypeAST *node);
51   void dump(FunctionAST *node);
52 
53   // Actually print spaces matching the current indentation level
indent()54   void indent() {
55     for (int i = 0; i < curIndent; i++)
56       llvm::errs() << "  ";
57   }
58   int curIndent = 0;
59 };
60 
61 } // namespace
62 
63 /// Return a formatted string for the location of any node
loc(T * node)64 template <typename T> static std::string loc(T *node) {
65   const auto &loc = node->loc();
66   return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
67           llvm::Twine(loc.col))
68       .str();
69 }
70 
// Helper macro used at the start of a dump routine: it creates an Indent guard
// that bumps the indentation level until the end of the enclosing scope, and
// then prints the leading spaces for that new level.
#define INDENT()                                                               \
  Indent indentGuard_(curIndent);                                              \
  indent();
76 
77 /// Dispatch to a generic expressions to the appropriate subclass using RTTI
dump(ExprAST * expr)78 void ASTDumper::dump(ExprAST *expr) {
79   mlir::TypeSwitch<ExprAST *>(expr)
80       .Case<BinaryExprAST, CallExprAST, LiteralExprAST, NumberExprAST,
81             PrintExprAST, ReturnExprAST, VarDeclExprAST, VariableExprAST>(
82           [&](auto *node) { this->dump(node); })
83       .Default([&](ExprAST *) {
84         // No match, fallback to a generic message
85         INDENT();
86         llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
87       });
88 }
89 
90 /// A variable declaration is printing the variable name, the type, and then
91 /// recurse in the initializer value.
dump(VarDeclExprAST * varDecl)92 void ASTDumper::dump(VarDeclExprAST *varDecl) {
93   INDENT();
94   llvm::errs() << "VarDecl " << varDecl->getName();
95   dump(varDecl->getType());
96   llvm::errs() << " " << loc(varDecl) << "\n";
97   dump(varDecl->getInitVal());
98 }
99 
100 /// A "block", or a list of expression
dump(ExprASTList * exprList)101 void ASTDumper::dump(ExprASTList *exprList) {
102   INDENT();
103   llvm::errs() << "Block {\n";
104   for (auto &expr : *exprList)
105     dump(expr.get());
106   indent();
107   llvm::errs() << "} // Block\n";
108 }
109 
110 /// A literal number, just print the value.
dump(NumberExprAST * num)111 void ASTDumper::dump(NumberExprAST *num) {
112   INDENT();
113   llvm::errs() << num->getValue() << " " << loc(num) << "\n";
114 }
115 
116 /// Helper to print recursively a literal. This handles nested array like:
117 ///    [ [ 1, 2 ], [ 3, 4 ] ]
118 /// We print out such array with the dimensions spelled out at every level:
119 ///    <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
printLitHelper(ExprAST * litOrNum)120 void printLitHelper(ExprAST *litOrNum) {
121   // Inside a literal expression we can have either a number or another literal
122   if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
123     llvm::errs() << num->getValue();
124     return;
125   }
126   auto *literal = llvm::cast<LiteralExprAST>(litOrNum);
127 
128   // Print the dimension for this literal first
129   llvm::errs() << "<";
130   mlir::interleaveComma(literal->getDims(), llvm::errs());
131   llvm::errs() << ">";
132 
133   // Now print the content, recursing on every element of the list
134   llvm::errs() << "[ ";
135   mlir::interleaveComma(literal->getValues(), llvm::errs(),
136                         [&](auto &elt) { printLitHelper(elt.get()); });
137   llvm::errs() << "]";
138 }
139 
140 /// Print a literal, see the recursive helper above for the implementation.
dump(LiteralExprAST * node)141 void ASTDumper::dump(LiteralExprAST *node) {
142   INDENT();
143   llvm::errs() << "Literal: ";
144   printLitHelper(node);
145   llvm::errs() << " " << loc(node) << "\n";
146 }
147 
148 /// Print a variable reference (just a name).
dump(VariableExprAST * node)149 void ASTDumper::dump(VariableExprAST *node) {
150   INDENT();
151   llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n";
152 }
153 
154 /// Return statement print the return and its (optional) argument.
dump(ReturnExprAST * node)155 void ASTDumper::dump(ReturnExprAST *node) {
156   INDENT();
157   llvm::errs() << "Return\n";
158   if (node->getExpr().hasValue())
159     return dump(*node->getExpr());
160   {
161     INDENT();
162     llvm::errs() << "(void)\n";
163   }
164 }
165 
166 /// Print a binary operation, first the operator, then recurse into LHS and RHS.
dump(BinaryExprAST * node)167 void ASTDumper::dump(BinaryExprAST *node) {
168   INDENT();
169   llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n";
170   dump(node->getLHS());
171   dump(node->getRHS());
172 }
173 
174 /// Print a call expression, first the callee name and the list of args by
175 /// recursing into each individual argument.
dump(CallExprAST * node)176 void ASTDumper::dump(CallExprAST *node) {
177   INDENT();
178   llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n";
179   for (auto &arg : node->getArgs())
180     dump(arg.get());
181   indent();
182   llvm::errs() << "]\n";
183 }
184 
185 /// Print a builtin print call, first the builtin name and then the argument.
dump(PrintExprAST * node)186 void ASTDumper::dump(PrintExprAST *node) {
187   INDENT();
188   llvm::errs() << "Print [ " << loc(node) << "\n";
189   dump(node->getArg());
190   indent();
191   llvm::errs() << "]\n";
192 }
193 
194 /// Print type: only the shape is printed in between '<' and '>'
dump(const VarType & type)195 void ASTDumper::dump(const VarType &type) {
196   llvm::errs() << "<";
197   mlir::interleaveComma(type.shape, llvm::errs());
198   llvm::errs() << ">";
199 }
200 
201 /// Print a function prototype, first the function name, and then the list of
202 /// parameters names.
dump(PrototypeAST * node)203 void ASTDumper::dump(PrototypeAST *node) {
204   INDENT();
205   llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "'\n";
206   indent();
207   llvm::errs() << "Params: [";
208   mlir::interleaveComma(node->getArgs(), llvm::errs(),
209                         [](auto &arg) { llvm::errs() << arg->getName(); });
210   llvm::errs() << "]\n";
211 }
212 
213 /// Print a function, first the prototype and then the body.
dump(FunctionAST * node)214 void ASTDumper::dump(FunctionAST *node) {
215   INDENT();
216   llvm::errs() << "Function \n";
217   dump(node->getProto());
218   dump(node->getBody());
219 }
220 
221 /// Print a module, actually loop over the functions and print them in sequence.
dump(ModuleAST * node)222 void ASTDumper::dump(ModuleAST *node) {
223   INDENT();
224   llvm::errs() << "Module:\n";
225   for (auto &f : *node)
226     dump(&f);
227 }
228 
229 namespace toy {
230 
231 // Public API
dump(ModuleAST & module)232 void dump(ModuleAST &module) { ASTDumper().dump(&module); }
233 
234 } // namespace toy
235