1 2 #include <cassert> 3 #include <map> 4 #include <string> 5 #include <vector> 6 #include "MNN_generated.h" 7 #include "PluginModule.hpp" 8 9 #ifdef MNN_CODEGEN_LLVM 10 #include "llvm/IR/Type.h" 11 #include "llvm/IR/Function.h" 12 #include "llvm/ADT/APFloat.h" 13 #include "llvm/ADT/Optional.h" 14 #include "llvm/ADT/STLExtras.h" 15 #include "llvm/IR/BasicBlock.h" 16 #include "llvm/IR/Constants.h" 17 #include "llvm/IR/DerivedTypes.h" 18 #include "llvm/IR/Instructions.h" 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/LLVMContext.h" 21 #include "llvm/IR/LegacyPassManager.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/IR/Verifier.h" 24 #include "llvm/ExecutionEngine/Orc/LLJIT.h" 25 using namespace llvm; 26 using namespace llvm::orc; 27 #endif 28 29 #ifdef MNN_CODEGEN_LLVM 30 class LLVMTarget { 31 public: LLVMTarget(std::string name)32 LLVMTarget(std::string name) { 33 llvmContext.reset(new LLVMContext); 34 llvmBuilder = std::make_unique<IRBuilder<>>(*llvmContext.get()); 35 llvmModule = std::make_unique<Module>(name, *llvmContext.get()); 36 llvmModule->setTargetTriple("x86_64-apple-macosx11.0.0"); 37 } ~LLVMTarget()38 ~LLVMTarget() {} getModule()39 Module* getModule() { 40 return llvmModule.get(); 41 } getContext()42 LLVMContext& getContext() { 43 return *llvmContext.get(); 44 } getBuilder()45 IRBuilder<>* getBuilder() { 46 return llvmBuilder.get(); 47 } getThreadSafeModule()48 ThreadSafeModule getThreadSafeModule() { 49 return ThreadSafeModule(std::move(llvmModule), std::move(llvmContext)); 50 } 51 private: 52 std::unique_ptr<LLVMContext> llvmContext; 53 std::unique_ptr<IRBuilder<>> llvmBuilder; 54 std::unique_ptr<Module> llvmModule; 55 }; 56 #endif 57 58 #ifdef MNN_CODEGEN_C 59 class SourceTarget { 60 public: SourceTarget()61 SourceTarget() {} ~SourceTarget()62 ~SourceTarget() {} addIndent()63 void addIndent() { indent++; } subIndent()64 void subIndent() { indent--; } getIndent()65 std::string getIndent() { 66 return std::string(4 * indent, ' '); 67 } 68 private: 69 int indent = 0; 70 }; 71 72 class CTarget : public SourceTarget { 73 public: CTarget(std::string & name)74 CTarget(std::string& name) {} ~CTarget()75 ~CTarget() {} 76 }; 77 #endif 78 79 #ifdef MNN_CODEGEN_LLVM 80 #define LLVM_CODEGEN Value *codegen(LLVMTarget* target) override; 81 #else 82 #define LLVM_CODEGEN 83 #endif 84 #ifdef MNN_CODEGEN_C 85 #define C_CODEGEN std::string codegen(SourceTarget* target) override; 86 #else 87 #define C_CODEGEN 88 #endif 89 90 #define CODEGEN_FUNCS \ 91 LLVM_CODEGEN \ 92 C_CODEGEN 93 94 namespace AST { 95 /// ExprAST - Base class for all expression nodes. 96 class ExprAST { 97 public: 98 virtual ~ExprAST() = default; 99 100 #ifdef MNN_CODEGEN_LLVM 101 virtual Value *codegen(LLVMTarget* target) = 0; 102 #endif 103 #ifdef MNN_CODEGEN_C 104 virtual std::string codegen(SourceTarget* target) = 0; 105 #endif 106 private: 107 friend class PluginModule; 108 }; 109 110 /// NumberExprAST - Expression class for numeric literals like "1.0". 111 class NumberExprAST : public ExprAST { 112 private: 113 union Val { 114 char chVal; 115 float f32Val; 116 double f64Val; 117 int8_t i8Val; 118 int16_t i16Val; 119 int32_t i32Val; 120 int64_t i64Val; 121 uint8_t ui8Val; 122 uint16_t ui16Val; 123 uint32_t ui32Val; 124 uint64_t ui64Val; 125 } mVal; 126 enum DataType { 127 CHAR = 0, 128 FP16, 129 FP32, 130 FP64, 131 INT1, 132 INT8, 133 INT16, 134 INT32, 135 INT64, 136 UINT1, 137 UINT8, 138 UINT16, 139 UINT32, 140 UINT64 141 }; 142 DataType mType; 143 public: NumberExprAST(float Val)144 NumberExprAST(float Val) : mType(FP32) { mVal.f32Val = Val; } NumberExprAST(double Val)145 NumberExprAST(double Val) : mType(FP64) { mVal.f64Val = Val; } NumberExprAST(int8_t Val)146 NumberExprAST(int8_t Val) : mType(INT8) { mVal.i8Val = Val; } NumberExprAST(int16_t Val)147 NumberExprAST(int16_t Val) : mType(INT16) { mVal.i16Val = Val; } NumberExprAST(int32_t Val)148 NumberExprAST(int32_t Val) : mType(INT32) { mVal.i32Val = Val; } NumberExprAST(int64_t Val)149 NumberExprAST(int64_t Val) : mType(INT64) { mVal.i64Val = Val; } NumberExprAST(uint8_t Val)150 NumberExprAST(uint8_t Val) : mType(UINT8) { mVal.ui8Val = Val; } NumberExprAST(uint16_t Val)151 NumberExprAST(uint16_t Val) : mType(UINT16) { mVal.ui16Val = Val;} NumberExprAST(uint32_t Val)152 NumberExprAST(uint32_t Val) : mType(UINT32) { mVal.ui32Val = Val;} NumberExprAST(uint64_t Val)153 NumberExprAST(uint64_t Val) : mType(UINT64) { mVal.ui64Val = Val;} 154 CODEGEN_FUNCS 155 }; 156 157 /// VariableExprAST - Expression class for referencing a variable, like "a". 158 class VariableExprAST : public ExprAST { 159 std::string Name; 160 161 public: 162 VariableExprAST() = default; VariableExprAST(const std::string & Name)163 VariableExprAST(const std::string &Name) : Name(Name) {} getName() const164 const std::string &getName() const { return Name; } 165 #ifdef MNN_CODEGEN_LLVM 166 virtual Value* getRef(LLVMTarget* target); 167 #endif 168 CODEGEN_FUNCS 169 }; 170 171 class SubscriptExprAST : public VariableExprAST { 172 std::unique_ptr<ExprAST> Base, Offset; 173 public: SubscriptExprAST(std::unique_ptr<ExprAST> Base,std::unique_ptr<ExprAST> Offset)174 SubscriptExprAST(std::unique_ptr<ExprAST> Base, std::unique_ptr<ExprAST> Offset) 175 : Base(std::move(Base)), Offset(std::move(Offset)) {} SubscriptExprAST(std::unique_ptr<ExprAST> Base,const std::string & Offset)176 SubscriptExprAST(std::unique_ptr<ExprAST> Base, const std::string& Offset) 177 : Base(std::move(Base)), Offset(std::make_unique<VariableExprAST>(Offset)) {} SubscriptExprAST(std::unique_ptr<ExprAST> Base,int Offset)178 SubscriptExprAST(std::unique_ptr<ExprAST> Base, int Offset) 179 : Base(std::move(Base)), Offset(std::make_unique<NumberExprAST>(Offset)) {} SubscriptExprAST(const std::string & Base,const std::string & Offset)180 SubscriptExprAST(const std::string& Base, const std::string& Offset) 181 : Base(std::make_unique<VariableExprAST>(Base)), Offset(std::make_unique<VariableExprAST>(Offset)) {} SubscriptExprAST(const std::string & Base,int Offset)182 SubscriptExprAST(const std::string& Base, int Offset) 183 : Base(std::make_unique<VariableExprAST>(Base)), Offset(std::make_unique<NumberExprAST>(Offset)) {} SubscriptExprAST(const std::string & Base,std::unique_ptr<ExprAST> Offset)184 SubscriptExprAST(const std::string& Base, std::unique_ptr<ExprAST> Offset) 185 : Base(std::make_unique<VariableExprAST>(Base)), Offset(std::move(Offset)) {} 186 #ifdef MNN_CODEGEN_LLVM 187 Value *getRef(LLVMTarget* target) override; 188 #endif 189 CODEGEN_FUNCS 190 }; 191 192 /// UnaryExprAST - Expression class for a unary operator. 193 class UnaryExprAST : public ExprAST { 194 MNN::UnaryOpOperation Op; 195 std::unique_ptr<ExprAST> Operand; 196 public: UnaryExprAST(MNN::UnaryOpOperation Op,std::unique_ptr<ExprAST> Operand)197 UnaryExprAST(MNN::UnaryOpOperation Op, std::unique_ptr<ExprAST> Operand) 198 : Op(Op), Operand(std::move(Operand)) {} 199 200 CODEGEN_FUNCS 201 }; 202 203 class ReluExprAST : public ExprAST { 204 float minVal, maxVal; 205 std::unique_ptr<ExprAST> Operand; 206 public: ReluExprAST(float minVal,float maxVal,std::unique_ptr<ExprAST> Operand)207 ReluExprAST(float minVal, float maxVal, std::unique_ptr<ExprAST> Operand) 208 : minVal(minVal), maxVal(maxVal), Operand(std::move(Operand)) {} 209 210 CODEGEN_FUNCS 211 }; 212 213 214 /// BinaryExprAST - Expression class for a binary operator. 215 class BinaryExprAST : public ExprAST { 216 MNN::BinaryOpOperation Op; 217 std::unique_ptr<ExprAST> LHS, RHS; 218 219 public: BinaryExprAST(MNN::BinaryOpOperation Op,std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)220 BinaryExprAST(MNN::BinaryOpOperation Op, std::unique_ptr<ExprAST> LHS, 221 std::unique_ptr<ExprAST> RHS) 222 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {} 223 CODEGEN_FUNCS 224 }; 225 226 class AssignExprAST : public ExprAST { 227 std::unique_ptr<ExprAST> LHS, RHS; 228 public: AssignExprAST(std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)229 AssignExprAST(std::unique_ptr<ExprAST> LHS, 230 std::unique_ptr<ExprAST> RHS) 231 : LHS(std::move(LHS)), RHS(std::move(RHS)) {} 232 233 CODEGEN_FUNCS 234 }; 235 236 /// CallExprAST - Expression class for function calls. 237 class CallExprAST : public ExprAST { 238 std::string Callee; 239 std::vector<std::unique_ptr<ExprAST>> Args; 240 241 public: CallExprAST(const std::string & Callee,std::vector<std::unique_ptr<ExprAST>> Args)242 CallExprAST(const std::string &Callee, 243 std::vector<std::unique_ptr<ExprAST>> Args) 244 : Callee(Callee), Args(std::move(Args)) {} 245 246 CODEGEN_FUNCS 247 }; 248 249 /// IfExprAST - Expression class for if/then/else. 250 class IfExprAST : public ExprAST { 251 std::unique_ptr<ExprAST> Cond, Then, Else; 252 253 public: IfExprAST(std::unique_ptr<ExprAST> Cond,std::unique_ptr<ExprAST> Then,std::unique_ptr<ExprAST> Else)254 IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then, 255 std::unique_ptr<ExprAST> Else) 256 : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {} 257 258 CODEGEN_FUNCS 259 }; 260 261 /// ForExprAST - Expression class for for/in. 262 class ForExprAST : public ExprAST { 263 std::string VarName; 264 std::unique_ptr<ExprAST> Start, End, Step, Body; 265 266 public: ForExprAST(const std::string & VarName,std::unique_ptr<ExprAST> Start,std::unique_ptr<ExprAST> End,std::unique_ptr<ExprAST> Step,std::unique_ptr<ExprAST> Body)267 ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start, 268 std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step, 269 std::unique_ptr<ExprAST> Body) 270 : VarName(VarName), Start(std::move(Start)), End(std::move(End)), 271 Step(std::move(Step)), Body(std::move(Body)) {} 272 273 CODEGEN_FUNCS 274 }; 275 276 /// VarExprAST - Expression class for var/in 277 class VarExprAST : public ExprAST { 278 std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames; 279 std::unique_ptr<ExprAST> Body; 280 281 public: VarExprAST(std::vector<std::pair<std::string,std::unique_ptr<ExprAST>>> VarNames,std::unique_ptr<ExprAST> Body)282 VarExprAST( 283 std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames, 284 std::unique_ptr<ExprAST> Body) 285 : VarNames(std::move(VarNames)), Body(std::move(Body)) {} 286 287 CODEGEN_FUNCS 288 }; 289 290 /// ListExprAST - Expression class for expr list 291 class ListExprAST : public ExprAST { 292 std::vector<std::unique_ptr<ExprAST>> exprs; 293 public: 294 ListExprAST() = default; ListExprAST(std::vector<std::unique_ptr<ExprAST>> exprs)295 ListExprAST(std::vector<std::unique_ptr<ExprAST>> exprs) 296 : exprs(std::move(exprs)) {} push_back(std::unique_ptr<ExprAST> expr)297 void push_back(std::unique_ptr<ExprAST> expr) { 298 exprs.emplace_back(std::move(expr)); 299 } 300 301 CODEGEN_FUNCS 302 }; 303 304 class PrototypeAST { 305 std::string Name; 306 int inputArgNum, outputArgNum; 307 public: PrototypeAST(const std::string & Name,int inputNum,int outputNum)308 PrototypeAST(const std::string &Name, int inputNum, int outputNum) 309 : Name(Name), inputArgNum(inputNum), outputArgNum(outputNum) {} 310 getName() const311 const std::string &getName() const { return Name; } 312 313 #ifdef MNN_CODEGEN_LLVM 314 Function *codegen(LLVMTarget* target); 315 #endif 316 #ifdef MNN_CODEGEN_C 317 std::string codegen(SourceTarget* target); 318 #endif 319 }; 320 321 /// FunctionAST - This class represents a function definition itself. 322 class FunctionAST { 323 std::unique_ptr<PrototypeAST> Proto; 324 std::unique_ptr<ExprAST> Body; 325 326 public: FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)327 FunctionAST(std::unique_ptr<PrototypeAST> Proto, 328 std::unique_ptr<ExprAST> Body) 329 : Proto(std::move(Proto)), Body(std::move(Body)) {} 330 331 #ifdef MNN_CODEGEN_LLVM 332 Function *codegen(LLVMTarget* target); 333 #endif 334 #ifdef MNN_CODEGEN_C 335 std::string codegen(SourceTarget* target); 336 #endif 337 }; 338 } // end CodeGen namespace 339