1 
2 #include <cassert>
3 #include <map>
4 #include <string>
5 #include <vector>
6 #include "MNN_generated.h"
7 #include "PluginModule.hpp"
8 
9 #ifdef MNN_CODEGEN_LLVM
10 #include "llvm/IR/Type.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/ADT/APFloat.h"
13 #include "llvm/ADT/Optional.h"
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/IR/BasicBlock.h"
16 #include "llvm/IR/Constants.h"
17 #include "llvm/IR/DerivedTypes.h"
18 #include "llvm/IR/Instructions.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/LLVMContext.h"
21 #include "llvm/IR/LegacyPassManager.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/Verifier.h"
24 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
25 using namespace llvm;
26 using namespace llvm::orc;
27 #endif
28 
29 #ifdef MNN_CODEGEN_LLVM
30 class LLVMTarget {
31 public:
LLVMTarget(std::string name)32     LLVMTarget(std::string name) {
33         llvmContext.reset(new LLVMContext);
34         llvmBuilder = std::make_unique<IRBuilder<>>(*llvmContext.get());
35         llvmModule = std::make_unique<Module>(name, *llvmContext.get());
36         llvmModule->setTargetTriple("x86_64-apple-macosx11.0.0");
37     }
~LLVMTarget()38     ~LLVMTarget() {}
getModule()39     Module* getModule() {
40         return llvmModule.get();
41     }
getContext()42     LLVMContext& getContext() {
43         return *llvmContext.get();
44     }
getBuilder()45     IRBuilder<>* getBuilder() {
46         return llvmBuilder.get();
47     }
getThreadSafeModule()48     ThreadSafeModule getThreadSafeModule() {
49         return ThreadSafeModule(std::move(llvmModule), std::move(llvmContext));
50     }
51 private:
52     std::unique_ptr<LLVMContext> llvmContext;
53     std::unique_ptr<IRBuilder<>> llvmBuilder;
54     std::unique_ptr<Module> llvmModule;
55 };
56 #endif
57 
58 #ifdef MNN_CODEGEN_C
/// SourceTarget - Base state shared by source-text code generators.
///
/// Tracks the current indentation depth; each level renders as four spaces.
class SourceTarget {
public:
    SourceTarget() = default;
    // Virtual because targets are passed around as SourceTarget* (see the
    // ExprAST::codegen(SourceTarget*) hook), so a derived target such as
    // CTarget may be destroyed through a base pointer.
    virtual ~SourceTarget() = default;

    /// Increases the indentation depth by one level.
    void addIndent() { indent++; }

    /// Decreases the indentation depth by one level. Calls at depth 0 are
    /// ignored: a negative depth would make getIndent() request a string of
    /// (size_t)negative length and throw std::length_error.
    void subIndent() {
        if (indent > 0) {
            indent--;
        }
    }

    /// Returns the whitespace prefix for the current depth (4 spaces/level).
    std::string getIndent() const {
        return std::string(static_cast<std::size_t>(4 * indent), ' ');
    }

private:
    int indent = 0;  // Always >= 0.
};

/// CTarget - Source target that emits C code.
class CTarget : public SourceTarget {
public:
    /// @param name  Name of the generated unit; not used by this class.
    ///              Taken by const reference so string literals and
    ///              temporaries are accepted.
    CTarget(const std::string& /*name*/) {}
    ~CTarget() override = default;
};
77 #endif
78 
// Per-backend declarations of the codegen() override each concrete AST node
// provides. A macro expands to nothing when its backend is disabled, so the
// set of overrides declared always matches the pure virtuals that ExprAST
// declares under the same MNN_CODEGEN_* switches.
#ifdef MNN_CODEGEN_LLVM
#define LLVM_CODEGEN Value *codegen(LLVMTarget* target) override;
#else
#define LLVM_CODEGEN
#endif

#ifdef MNN_CODEGEN_C
#define C_CODEGEN std::string codegen(SourceTarget* target) override;
#else
#define C_CODEGEN
#endif

// Used inside each AST node's class body to declare all enabled overrides.
#define CODEGEN_FUNCS LLVM_CODEGEN C_CODEGEN
93 
94 namespace AST {
/// ExprAST - Abstract root of the expression-node hierarchy.
///
/// Declares one pure-virtual codegen() per enabled backend (LLVM and/or C).
/// Nodes are held through std::unique_ptr<ExprAST>, hence the virtual
/// destructor.
class ExprAST {
    // NOTE(review): an unqualified friend class declaration is looked up only
    // in the innermost enclosing namespace, so this befriends a PluginModule
    // declared in that namespace, not one declared outside it. ExprAST has no
    // non-public members, so it grants no extra access either way.
    friend class PluginModule;

public:
    virtual ~ExprAST() = default;

#ifdef MNN_CODEGEN_LLVM
    /// LLVM backend hook; returns the llvm::Value produced for this node.
    virtual Value *codegen(LLVMTarget* target) = 0;
#endif
#ifdef MNN_CODEGEN_C
    /// C backend hook; returns the generated source text for this node.
    virtual std::string codegen(SourceTarget* target) = 0;
#endif
};
109 
110 /// NumberExprAST - Expression class for numeric literals like "1.0".
111 class NumberExprAST : public ExprAST {
112 private:
113     union Val {
114         char chVal;
115         float f32Val;
116         double f64Val;
117         int8_t  i8Val;
118         int16_t i16Val;
119         int32_t i32Val;
120         int64_t i64Val;
121         uint8_t  ui8Val;
122         uint16_t ui16Val;
123         uint32_t ui32Val;
124         uint64_t ui64Val;
125     } mVal;
126     enum DataType {
127         CHAR = 0,
128         FP16,
129         FP32,
130         FP64,
131         INT1,
132         INT8,
133         INT16,
134         INT32,
135         INT64,
136         UINT1,
137         UINT8,
138         UINT16,
139         UINT32,
140         UINT64
141     };
142     DataType mType;
143 public:
NumberExprAST(float Val)144     NumberExprAST(float Val) : mType(FP32) { mVal.f32Val = Val; }
NumberExprAST(double Val)145     NumberExprAST(double Val) : mType(FP64) { mVal.f64Val = Val; }
NumberExprAST(int8_t Val)146     NumberExprAST(int8_t Val) : mType(INT8) { mVal.i8Val = Val; }
NumberExprAST(int16_t Val)147     NumberExprAST(int16_t Val) : mType(INT16) { mVal.i16Val = Val; }
NumberExprAST(int32_t Val)148     NumberExprAST(int32_t Val) : mType(INT32) { mVal.i32Val = Val; }
NumberExprAST(int64_t Val)149     NumberExprAST(int64_t Val) : mType(INT64) { mVal.i64Val = Val; }
NumberExprAST(uint8_t Val)150     NumberExprAST(uint8_t Val) : mType(UINT8) { mVal.ui8Val = Val; }
NumberExprAST(uint16_t Val)151     NumberExprAST(uint16_t Val) : mType(UINT16) { mVal.ui16Val = Val;}
NumberExprAST(uint32_t Val)152     NumberExprAST(uint32_t Val) : mType(UINT32) { mVal.ui32Val = Val;}
NumberExprAST(uint64_t Val)153     NumberExprAST(uint64_t Val) : mType(UINT64) { mVal.ui64Val = Val;}
154     CODEGEN_FUNCS
155 };
156 
157 /// VariableExprAST - Expression class for referencing a variable, like "a".
158 class VariableExprAST : public ExprAST {
159     std::string Name;
160 
161 public:
162     VariableExprAST() = default;
VariableExprAST(const std::string & Name)163     VariableExprAST(const std::string &Name) : Name(Name) {}
getName() const164     const std::string &getName() const { return Name; }
165 #ifdef MNN_CODEGEN_LLVM
166     virtual Value* getRef(LLVMTarget* target);
167 #endif
168     CODEGEN_FUNCS
169 };
170 
171 class SubscriptExprAST : public VariableExprAST {
172     std::unique_ptr<ExprAST> Base, Offset;
173 public:
SubscriptExprAST(std::unique_ptr<ExprAST> Base,std::unique_ptr<ExprAST> Offset)174     SubscriptExprAST(std::unique_ptr<ExprAST> Base, std::unique_ptr<ExprAST> Offset)
175         : Base(std::move(Base)), Offset(std::move(Offset)) {}
SubscriptExprAST(std::unique_ptr<ExprAST> Base,const std::string & Offset)176     SubscriptExprAST(std::unique_ptr<ExprAST> Base, const std::string& Offset)
177         : Base(std::move(Base)), Offset(std::make_unique<VariableExprAST>(Offset)) {}
SubscriptExprAST(std::unique_ptr<ExprAST> Base,int Offset)178     SubscriptExprAST(std::unique_ptr<ExprAST> Base, int Offset)
179         : Base(std::move(Base)), Offset(std::make_unique<NumberExprAST>(Offset)) {}
SubscriptExprAST(const std::string & Base,const std::string & Offset)180     SubscriptExprAST(const std::string& Base, const std::string& Offset)
181         : Base(std::make_unique<VariableExprAST>(Base)), Offset(std::make_unique<VariableExprAST>(Offset)) {}
SubscriptExprAST(const std::string & Base,int Offset)182     SubscriptExprAST(const std::string& Base, int Offset)
183         : Base(std::make_unique<VariableExprAST>(Base)), Offset(std::make_unique<NumberExprAST>(Offset)) {}
SubscriptExprAST(const std::string & Base,std::unique_ptr<ExprAST> Offset)184     SubscriptExprAST(const std::string& Base, std::unique_ptr<ExprAST> Offset)
185         : Base(std::make_unique<VariableExprAST>(Base)), Offset(std::move(Offset)) {}
186 #ifdef MNN_CODEGEN_LLVM
187     Value *getRef(LLVMTarget* target) override;
188 #endif
189     CODEGEN_FUNCS
190 };
191 
192 /// UnaryExprAST - Expression class for a unary operator.
193 class UnaryExprAST : public ExprAST {
194     MNN::UnaryOpOperation Op;
195     std::unique_ptr<ExprAST> Operand;
196 public:
UnaryExprAST(MNN::UnaryOpOperation Op,std::unique_ptr<ExprAST> Operand)197     UnaryExprAST(MNN::UnaryOpOperation Op, std::unique_ptr<ExprAST> Operand)
198         : Op(Op), Operand(std::move(Operand)) {}
199 
200     CODEGEN_FUNCS
201 };
202 
203 class ReluExprAST : public ExprAST {
204     float minVal, maxVal;
205     std::unique_ptr<ExprAST> Operand;
206 public:
ReluExprAST(float minVal,float maxVal,std::unique_ptr<ExprAST> Operand)207     ReluExprAST(float minVal, float maxVal, std::unique_ptr<ExprAST> Operand)
208         : minVal(minVal), maxVal(maxVal), Operand(std::move(Operand)) {}
209 
210     CODEGEN_FUNCS
211 };
212 
213 
214 /// BinaryExprAST - Expression class for a binary operator.
215 class BinaryExprAST : public ExprAST {
216     MNN::BinaryOpOperation Op;
217     std::unique_ptr<ExprAST> LHS, RHS;
218 
219 public:
BinaryExprAST(MNN::BinaryOpOperation Op,std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)220     BinaryExprAST(MNN::BinaryOpOperation Op, std::unique_ptr<ExprAST> LHS,
221                     std::unique_ptr<ExprAST> RHS)
222           : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
223     CODEGEN_FUNCS
224 };
225 
226 class AssignExprAST : public ExprAST {
227     std::unique_ptr<ExprAST> LHS, RHS;
228 public:
AssignExprAST(std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)229     AssignExprAST(std::unique_ptr<ExprAST> LHS,
230                   std::unique_ptr<ExprAST> RHS)
231         : LHS(std::move(LHS)), RHS(std::move(RHS)) {}
232 
233     CODEGEN_FUNCS
234 };
235 
236 /// CallExprAST - Expression class for function calls.
237 class CallExprAST : public ExprAST {
238     std::string Callee;
239     std::vector<std::unique_ptr<ExprAST>> Args;
240 
241 public:
CallExprAST(const std::string & Callee,std::vector<std::unique_ptr<ExprAST>> Args)242     CallExprAST(const std::string &Callee,
243                 std::vector<std::unique_ptr<ExprAST>> Args)
244         : Callee(Callee), Args(std::move(Args)) {}
245 
246     CODEGEN_FUNCS
247 };
248 
249 /// IfExprAST - Expression class for if/then/else.
250 class IfExprAST : public ExprAST {
251     std::unique_ptr<ExprAST> Cond, Then, Else;
252 
253 public:
IfExprAST(std::unique_ptr<ExprAST> Cond,std::unique_ptr<ExprAST> Then,std::unique_ptr<ExprAST> Else)254     IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
255               std::unique_ptr<ExprAST> Else)
256         : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
257 
258     CODEGEN_FUNCS
259 };
260 
261 /// ForExprAST - Expression class for for/in.
262 class ForExprAST : public ExprAST {
263     std::string VarName;
264     std::unique_ptr<ExprAST> Start, End, Step, Body;
265 
266 public:
ForExprAST(const std::string & VarName,std::unique_ptr<ExprAST> Start,std::unique_ptr<ExprAST> End,std::unique_ptr<ExprAST> Step,std::unique_ptr<ExprAST> Body)267     ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
268                std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
269                std::unique_ptr<ExprAST> Body)
270         : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
271             Step(std::move(Step)), Body(std::move(Body)) {}
272 
273     CODEGEN_FUNCS
274 };
275 
276 /// VarExprAST - Expression class for var/in
277 class VarExprAST : public ExprAST {
278     std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
279     std::unique_ptr<ExprAST> Body;
280 
281 public:
VarExprAST(std::vector<std::pair<std::string,std::unique_ptr<ExprAST>>> VarNames,std::unique_ptr<ExprAST> Body)282     VarExprAST(
283       std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames,
284       std::unique_ptr<ExprAST> Body)
285       : VarNames(std::move(VarNames)), Body(std::move(Body)) {}
286 
287     CODEGEN_FUNCS
288 };
289 
290 /// ListExprAST - Expression class for expr list
291 class ListExprAST : public ExprAST {
292     std::vector<std::unique_ptr<ExprAST>> exprs;
293 public:
294     ListExprAST() = default;
ListExprAST(std::vector<std::unique_ptr<ExprAST>> exprs)295     ListExprAST(std::vector<std::unique_ptr<ExprAST>> exprs)
296         : exprs(std::move(exprs)) {}
push_back(std::unique_ptr<ExprAST> expr)297     void push_back(std::unique_ptr<ExprAST> expr) {
298         exprs.emplace_back(std::move(expr));
299     }
300 
301     CODEGEN_FUNCS
302 };
303 
/// PrototypeAST - Signature of a generated function: its name plus the number
/// of input and output arguments.
class PrototypeAST {
public:
    PrototypeAST(const std::string &protoName, int inputNum, int outputNum)
        : Name(protoName), inputArgNum(inputNum), outputArgNum(outputNum) {}

    /// The function's name.
    const std::string &getName() const { return Name; }

#ifdef MNN_CODEGEN_LLVM
    /// LLVM backend: returns the llvm::Function for this prototype.
    Function *codegen(LLVMTarget* target);
#endif
#ifdef MNN_CODEGEN_C
    /// C backend: returns the generated source text for this prototype.
    std::string codegen(SourceTarget* target);
#endif

private:
    std::string Name;
    int inputArgNum;
    int outputArgNum;
};
320 
321 /// FunctionAST - This class represents a function definition itself.
322 class FunctionAST {
323     std::unique_ptr<PrototypeAST> Proto;
324     std::unique_ptr<ExprAST> Body;
325 
326 public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)327     FunctionAST(std::unique_ptr<PrototypeAST> Proto,
328                 std::unique_ptr<ExprAST> Body)
329         : Proto(std::move(Proto)), Body(std::move(Body)) {}
330 
331 #ifdef MNN_CODEGEN_LLVM
332     Function *codegen(LLVMTarget* target);
333 #endif
334 #ifdef MNN_CODEGEN_C
335     std::string codegen(SourceTarget* target);
336 #endif
337 };
338 } // end CodeGen namespace
339