1 //===- AST.h - Node definition for the Toy AST ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AST for the Toy language. It is optimized for
10 // simplicity, not efficiency. The AST forms a tree structure where each node
11 // references its children using std::unique_ptr<>.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_TUTORIAL_TOY_AST_H_
16 #define MLIR_TUTORIAL_TOY_AST_H_
17 
#include "toy/Lexer.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
24 
25 namespace toy {
26 
/// Type annotation attached to a variable: currently just its shape.
struct VarType {
  /// Size of each dimension of the variable, one entry per dimension.
  std::vector<int64_t> shape;
};
31 
32 /// Base class for all expression nodes.
33 class ExprAST {
34 public:
35   enum ExprASTKind {
36     Expr_VarDecl,
37     Expr_Return,
38     Expr_Num,
39     Expr_Literal,
40     Expr_Var,
41     Expr_BinOp,
42     Expr_Call,
43     Expr_Print,
44   };
45 
ExprAST(ExprASTKind kind,Location location)46   ExprAST(ExprASTKind kind, Location location)
47       : kind(kind), location(location) {}
48   virtual ~ExprAST() = default;
49 
getKind()50   ExprASTKind getKind() const { return kind; }
51 
loc()52   const Location &loc() { return location; }
53 
54 private:
55   const ExprASTKind kind;
56   Location location;
57 };
58 
/// A block-list of expressions: an ordered sequence of owned expression nodes
/// (for example, the body of a function).
using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
61 
62 /// Expression class for numeric literals like "1.0".
63 class NumberExprAST : public ExprAST {
64   double Val;
65 
66 public:
NumberExprAST(Location loc,double val)67   NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {}
68 
getValue()69   double getValue() { return Val; }
70 
71   /// LLVM style RTTI
classof(const ExprAST * c)72   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
73 };
74 
75 /// Expression class for a literal value.
76 class LiteralExprAST : public ExprAST {
77   std::vector<std::unique_ptr<ExprAST>> values;
78   std::vector<int64_t> dims;
79 
80 public:
LiteralExprAST(Location loc,std::vector<std::unique_ptr<ExprAST>> values,std::vector<int64_t> dims)81   LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
82                  std::vector<int64_t> dims)
83       : ExprAST(Expr_Literal, loc), values(std::move(values)),
84         dims(std::move(dims)) {}
85 
getValues()86   llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
getDims()87   llvm::ArrayRef<int64_t> getDims() { return dims; }
88 
89   /// LLVM style RTTI
classof(const ExprAST * c)90   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
91 };
92 
93 /// Expression class for referencing a variable, like "a".
94 class VariableExprAST : public ExprAST {
95   std::string name;
96 
97 public:
VariableExprAST(Location loc,llvm::StringRef name)98   VariableExprAST(Location loc, llvm::StringRef name)
99       : ExprAST(Expr_Var, loc), name(name) {}
100 
getName()101   llvm::StringRef getName() { return name; }
102 
103   /// LLVM style RTTI
classof(const ExprAST * c)104   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
105 };
106 
107 /// Expression class for defining a variable.
108 class VarDeclExprAST : public ExprAST {
109   std::string name;
110   VarType type;
111   std::unique_ptr<ExprAST> initVal;
112 
113 public:
VarDeclExprAST(Location loc,llvm::StringRef name,VarType type,std::unique_ptr<ExprAST> initVal)114   VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
115                  std::unique_ptr<ExprAST> initVal)
116       : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
117         initVal(std::move(initVal)) {}
118 
getName()119   llvm::StringRef getName() { return name; }
getInitVal()120   ExprAST *getInitVal() { return initVal.get(); }
getType()121   const VarType &getType() { return type; }
122 
123   /// LLVM style RTTI
classof(const ExprAST * c)124   static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
125 };
126 
127 /// Expression class for a return operator.
128 class ReturnExprAST : public ExprAST {
129   llvm::Optional<std::unique_ptr<ExprAST>> expr;
130 
131 public:
ReturnExprAST(Location loc,llvm::Optional<std::unique_ptr<ExprAST>> expr)132   ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
133       : ExprAST(Expr_Return, loc), expr(std::move(expr)) {}
134 
getExpr()135   llvm::Optional<ExprAST *> getExpr() {
136     if (expr.hasValue())
137       return expr->get();
138     return llvm::None;
139   }
140 
141   /// LLVM style RTTI
classof(const ExprAST * c)142   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
143 };
144 
145 /// Expression class for a binary operator.
146 class BinaryExprAST : public ExprAST {
147   char op;
148   std::unique_ptr<ExprAST> lhs, rhs;
149 
150 public:
getOp()151   char getOp() { return op; }
getLHS()152   ExprAST *getLHS() { return lhs.get(); }
getRHS()153   ExprAST *getRHS() { return rhs.get(); }
154 
BinaryExprAST(Location loc,char Op,std::unique_ptr<ExprAST> lhs,std::unique_ptr<ExprAST> rhs)155   BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs,
156                 std::unique_ptr<ExprAST> rhs)
157       : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)),
158         rhs(std::move(rhs)) {}
159 
160   /// LLVM style RTTI
classof(const ExprAST * c)161   static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
162 };
163 
164 /// Expression class for function calls.
165 class CallExprAST : public ExprAST {
166   std::string callee;
167   std::vector<std::unique_ptr<ExprAST>> args;
168 
169 public:
CallExprAST(Location loc,const std::string & callee,std::vector<std::unique_ptr<ExprAST>> args)170   CallExprAST(Location loc, const std::string &callee,
171               std::vector<std::unique_ptr<ExprAST>> args)
172       : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {}
173 
getCallee()174   llvm::StringRef getCallee() { return callee; }
getArgs()175   llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
176 
177   /// LLVM style RTTI
classof(const ExprAST * c)178   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
179 };
180 
181 /// Expression class for builtin print calls.
182 class PrintExprAST : public ExprAST {
183   std::unique_ptr<ExprAST> arg;
184 
185 public:
PrintExprAST(Location loc,std::unique_ptr<ExprAST> arg)186   PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
187       : ExprAST(Expr_Print, loc), arg(std::move(arg)) {}
188 
getArg()189   ExprAST *getArg() { return arg.get(); }
190 
191   /// LLVM style RTTI
classof(const ExprAST * c)192   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
193 };
194 
195 /// This class represents the "prototype" for a function, which captures its
196 /// name, and its argument names (thus implicitly the number of arguments the
197 /// function takes).
198 class PrototypeAST {
199   Location location;
200   std::string name;
201   std::vector<std::unique_ptr<VariableExprAST>> args;
202 
203 public:
PrototypeAST(Location location,const std::string & name,std::vector<std::unique_ptr<VariableExprAST>> args)204   PrototypeAST(Location location, const std::string &name,
205                std::vector<std::unique_ptr<VariableExprAST>> args)
206       : location(location), name(name), args(std::move(args)) {}
207 
loc()208   const Location &loc() { return location; }
getName()209   llvm::StringRef getName() const { return name; }
getArgs()210   llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; }
211 };
212 
213 /// This class represents a function definition itself.
214 class FunctionAST {
215   std::unique_ptr<PrototypeAST> proto;
216   std::unique_ptr<ExprASTList> body;
217 
218 public:
FunctionAST(std::unique_ptr<PrototypeAST> proto,std::unique_ptr<ExprASTList> body)219   FunctionAST(std::unique_ptr<PrototypeAST> proto,
220               std::unique_ptr<ExprASTList> body)
221       : proto(std::move(proto)), body(std::move(body)) {}
getProto()222   PrototypeAST *getProto() { return proto.get(); }
getBody()223   ExprASTList *getBody() { return body.get(); }
224 };
225 
226 /// This class represents a list of functions to be processed together
227 class ModuleAST {
228   std::vector<FunctionAST> functions;
229 
230 public:
ModuleAST(std::vector<FunctionAST> functions)231   ModuleAST(std::vector<FunctionAST> functions)
232       : functions(std::move(functions)) {}
233 
234   auto begin() -> decltype(functions.begin()) { return functions.begin(); }
235   auto end() -> decltype(functions.end()) { return functions.end(); }
236 };
237 
/// Dumps the given module's AST. Declared here and defined in another
/// translation unit; presumably it prints a textual tree of every function in
/// the module (confirm against the definition).
void dump(ModuleAST &);
239 
240 } // namespace toy
241 
242 #endif // MLIR_TUTORIAL_TOY_AST_H_
243