1 //===- AST.h - Node definition for the Toy AST ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the AST for the Toy language. It is optimized for 10 // simplicity, not efficiency. The AST forms a tree structure where each node 11 // references its children using std::unique_ptr<>. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef MLIR_TUTORIAL_TOY_AST_H_ 16 #define MLIR_TUTORIAL_TOY_AST_H_ 17 18 #include "toy/Lexer.h" 19 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/StringRef.h" 22 #include "llvm/Support/Casting.h" 23 #include <vector> 24 25 namespace toy { 26 27 /// A variable type with shape information. 28 struct VarType { 29 std::vector<int64_t> shape; 30 }; 31 32 /// Base class for all expression nodes. 33 class ExprAST { 34 public: 35 enum ExprASTKind { 36 Expr_VarDecl, 37 Expr_Return, 38 Expr_Num, 39 Expr_Literal, 40 Expr_Var, 41 Expr_BinOp, 42 Expr_Call, 43 Expr_Print, 44 }; 45 ExprAST(ExprASTKind kind,Location location)46 ExprAST(ExprASTKind kind, Location location) 47 : kind(kind), location(location) {} 48 virtual ~ExprAST() = default; 49 getKind()50 ExprASTKind getKind() const { return kind; } 51 loc()52 const Location &loc() { return location; } 53 54 private: 55 const ExprASTKind kind; 56 Location location; 57 }; 58 59 /// A block-list of expressions. 60 using ExprASTList = std::vector<std::unique_ptr<ExprAST>>; 61 62 /// Expression class for numeric literals like "1.0". 63 class NumberExprAST : public ExprAST { 64 double Val; 65 66 public: NumberExprAST(Location loc,double val)67 NumberExprAST(Location loc, double val) : ExprAST(Expr_Num, loc), Val(val) {} 68 getValue()69 double getValue() { return Val; } 70 71 /// LLVM style RTTI classof(const ExprAST * c)72 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } 73 }; 74 75 /// Expression class for a literal value. 76 class LiteralExprAST : public ExprAST { 77 std::vector<std::unique_ptr<ExprAST>> values; 78 std::vector<int64_t> dims; 79 80 public: LiteralExprAST(Location loc,std::vector<std::unique_ptr<ExprAST>> values,std::vector<int64_t> dims)81 LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values, 82 std::vector<int64_t> dims) 83 : ExprAST(Expr_Literal, loc), values(std::move(values)), 84 dims(std::move(dims)) {} 85 getValues()86 llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } getDims()87 llvm::ArrayRef<int64_t> getDims() { return dims; } 88 89 /// LLVM style RTTI classof(const ExprAST * c)90 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } 91 }; 92 93 /// Expression class for referencing a variable, like "a". 94 class VariableExprAST : public ExprAST { 95 std::string name; 96 97 public: VariableExprAST(Location loc,llvm::StringRef name)98 VariableExprAST(Location loc, llvm::StringRef name) 99 : ExprAST(Expr_Var, loc), name(name) {} 100 getName()101 llvm::StringRef getName() { return name; } 102 103 /// LLVM style RTTI classof(const ExprAST * c)104 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } 105 }; 106 107 /// Expression class for defining a variable. 108 class VarDeclExprAST : public ExprAST { 109 std::string name; 110 VarType type; 111 std::unique_ptr<ExprAST> initVal; 112 113 public: VarDeclExprAST(Location loc,llvm::StringRef name,VarType type,std::unique_ptr<ExprAST> initVal)114 VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, 115 std::unique_ptr<ExprAST> initVal) 116 : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), 117 initVal(std::move(initVal)) {} 118 getName()119 llvm::StringRef getName() { return name; } getInitVal()120 ExprAST *getInitVal() { return initVal.get(); } getType()121 const VarType &getType() { return type; } 122 123 /// LLVM style RTTI classof(const ExprAST * c)124 static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } 125 }; 126 127 /// Expression class for a return operator. 128 class ReturnExprAST : public ExprAST { 129 llvm::Optional<std::unique_ptr<ExprAST>> expr; 130 131 public: ReturnExprAST(Location loc,llvm::Optional<std::unique_ptr<ExprAST>> expr)132 ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr) 133 : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} 134 getExpr()135 llvm::Optional<ExprAST *> getExpr() { 136 if (expr.hasValue()) 137 return expr->get(); 138 return llvm::None; 139 } 140 141 /// LLVM style RTTI classof(const ExprAST * c)142 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } 143 }; 144 145 /// Expression class for a binary operator. 146 class BinaryExprAST : public ExprAST { 147 char op; 148 std::unique_ptr<ExprAST> lhs, rhs; 149 150 public: getOp()151 char getOp() { return op; } getLHS()152 ExprAST *getLHS() { return lhs.get(); } getRHS()153 ExprAST *getRHS() { return rhs.get(); } 154 BinaryExprAST(Location loc,char Op,std::unique_ptr<ExprAST> lhs,std::unique_ptr<ExprAST> rhs)155 BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> lhs, 156 std::unique_ptr<ExprAST> rhs) 157 : ExprAST(Expr_BinOp, loc), op(Op), lhs(std::move(lhs)), 158 rhs(std::move(rhs)) {} 159 160 /// LLVM style RTTI classof(const ExprAST * c)161 static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } 162 }; 163 164 /// Expression class for function calls. 165 class CallExprAST : public ExprAST { 166 std::string callee; 167 std::vector<std::unique_ptr<ExprAST>> args; 168 169 public: CallExprAST(Location loc,const std::string & callee,std::vector<std::unique_ptr<ExprAST>> args)170 CallExprAST(Location loc, const std::string &callee, 171 std::vector<std::unique_ptr<ExprAST>> args) 172 : ExprAST(Expr_Call, loc), callee(callee), args(std::move(args)) {} 173 getCallee()174 llvm::StringRef getCallee() { return callee; } getArgs()175 llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; } 176 177 /// LLVM style RTTI classof(const ExprAST * c)178 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } 179 }; 180 181 /// Expression class for builtin print calls. 182 class PrintExprAST : public ExprAST { 183 std::unique_ptr<ExprAST> arg; 184 185 public: PrintExprAST(Location loc,std::unique_ptr<ExprAST> arg)186 PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg) 187 : ExprAST(Expr_Print, loc), arg(std::move(arg)) {} 188 getArg()189 ExprAST *getArg() { return arg.get(); } 190 191 /// LLVM style RTTI classof(const ExprAST * c)192 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } 193 }; 194 195 /// This class represents the "prototype" for a function, which captures its 196 /// name, and its argument names (thus implicitly the number of arguments the 197 /// function takes). 198 class PrototypeAST { 199 Location location; 200 std::string name; 201 std::vector<std::unique_ptr<VariableExprAST>> args; 202 203 public: PrototypeAST(Location location,const std::string & name,std::vector<std::unique_ptr<VariableExprAST>> args)204 PrototypeAST(Location location, const std::string &name, 205 std::vector<std::unique_ptr<VariableExprAST>> args) 206 : location(location), name(name), args(std::move(args)) {} 207 loc()208 const Location &loc() { return location; } getName()209 llvm::StringRef getName() const { return name; } getArgs()210 llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() { return args; } 211 }; 212 213 /// This class represents a function definition itself. 214 class FunctionAST { 215 std::unique_ptr<PrototypeAST> proto; 216 std::unique_ptr<ExprASTList> body; 217 218 public: FunctionAST(std::unique_ptr<PrototypeAST> proto,std::unique_ptr<ExprASTList> body)219 FunctionAST(std::unique_ptr<PrototypeAST> proto, 220 std::unique_ptr<ExprASTList> body) 221 : proto(std::move(proto)), body(std::move(body)) {} getProto()222 PrototypeAST *getProto() { return proto.get(); } getBody()223 ExprASTList *getBody() { return body.get(); } 224 }; 225 226 /// This class represents a list of functions to be processed together 227 class ModuleAST { 228 std::vector<FunctionAST> functions; 229 230 public: ModuleAST(std::vector<FunctionAST> functions)231 ModuleAST(std::vector<FunctionAST> functions) 232 : functions(std::move(functions)) {} 233 234 auto begin() -> decltype(functions.begin()) { return functions.begin(); } 235 auto end() -> decltype(functions.end()) { return functions.end(); } 236 }; 237 238 void dump(ModuleAST &); 239 240 } // namespace toy 241 242 #endif // MLIR_TUTORIAL_TOY_AST_H_ 243