1 #include "cpu/CPUAst.hpp"
2 
3 using namespace AST;
4 
5 static std::map<std::string, AllocaInst *> NamedValues;
6 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
7 
LogError(const char * Str)8 std::unique_ptr<ExprAST> LogError(const char *Str) {
9     fprintf(stderr, "Error: %s\n", Str);
10     return nullptr;
11 }
12 
LogErrorV(const char * Str)13 Value *LogErrorV(const char *Str) {
14     LogError(Str);
15     return nullptr;
16 }
getValueByName(const std::string & name)17 static Value* getValueByName(const std::string& name) {
18     if (name.empty()) {
19         LogErrorV(std::string("Variable name: " + name + "is empty!").c_str());
20     }
21     if (NamedValues.find(name) == NamedValues.end()) {
22         LogErrorV(std::string("Unknown variable name: " + name).c_str());
23     }
24     return NamedValues[name];
25 }
26 
getFunction(LLVMTarget * target,std::string Name)27 static Function *getFunction(LLVMTarget* target, std::string Name) {
28     // First, see if the function has already been added to the current module.
29     if (auto *F = target->getModule()->getFunction(Name)) {
30         return F;
31     }
32 
33     // If not, check whether we can codegen the declaration from some existing
34     // prototype.
35     auto FI = FunctionProtos.find(Name);
36     if (FI != FunctionProtos.end()) {
37         return FI->second->codegen(target);
38     }
39 
40     // If no existing prototype exists, return null.
41     return nullptr;
42 }
43 
44 /// CreateEntryBlockAlloca - Create an alloca instruction in the entry block of
45 /// the function.  This is used for mutable variables etc.
CreateEntryBlockAlloca(Function * TheFunction,StringRef VarName,Type * type)46 static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction,
47                                           StringRef VarName, Type* type) {
48     IRBuilder<> TmpB(&TheFunction->getEntryBlock(),
49                      TheFunction->getEntryBlock().begin());
50     return TmpB.CreateAlloca(type, nullptr, VarName);
51 }
52 
codegen(LLVMTarget * target)53 Value *NumberExprAST::codegen(LLVMTarget* target) {
54     switch (mType) {
55         case FP32:
56             return ConstantFP::get(target->getContext(), APFloat(mVal.f32Val));
57         case FP64:
58             return ConstantFP::get(target->getContext(), APFloat(mVal.f64Val));
59         case INT32:
60             return ConstantInt::get(target->getContext(), APInt(32, mVal.i32Val, true));
61         case INT64:
62             return ConstantInt::get(target->getContext(), APInt(64, mVal.i64Val, true));
63         default:
64             return nullptr;
65     }
66 }
getRef(LLVMTarget * target)67 Value *VariableExprAST::getRef(LLVMTarget* target) {
68     return getValueByName(Name);
69 }
codegen(LLVMTarget * target)70 Value *VariableExprAST::codegen(LLVMTarget* target) {
71     return target->getBuilder()->CreateLoad(getRef(target), Name.c_str());
72 }
73 
getRef(LLVMTarget * target)74 Value* SubscriptExprAST::getRef(LLVMTarget* target) {
75     return target->getBuilder()->CreateGEP(Base->codegen(target), Offset->codegen(target));
76 }
codegen(LLVMTarget * target)77 Value* SubscriptExprAST::codegen(LLVMTarget* target) {
78     return target->getBuilder()->CreateLoad(getRef(target));
79 }
codegen(LLVMTarget * target)80 Value *ReluExprAST::codegen(LLVMTarget* target) {
81     Value *V = Operand->codegen(target);
82     if (!V) {
83         return nullptr;
84     }
85     auto builder = target->getBuilder();
86     auto& llvmContext = target->getContext();
87     if (maxVal == 0.f) {
88         // relu(x) = (x <= 0) * slope * x + (x > 0) * x
89         auto aV = builder->CreateFMul(V, ConstantFP::get(llvmContext, APFloat((float)minVal)));
90         auto gtz = builder->CreateUIToFP(builder->CreateFCmpUGT(V, ConstantFP::get(llvmContext, APFloat((float)0))), Type::getFloatTy(llvmContext));
91         auto ltz = builder->CreateFSub(ConstantFP::get(llvmContext, APFloat((float)1)), gtz);
92         auto l = builder->CreateFMul(ltz, aV);
93         auto r = builder->CreateFMul(gtz, V);
94         return builder->CreateFAdd(l, r);
95     }
96     // relu6(x) = min(max(x, minv), maxv)
97     V = builder->CreateMaxNum(V, ConstantFP::get(llvmContext, APFloat((float)minVal)));
98     return builder->CreateMinNum(V, ConstantFP::get(llvmContext, APFloat((float)maxVal)));
99 }
codegen(LLVMTarget * target)100 Value *UnaryExprAST::codegen(LLVMTarget* target) {
101     Value *V = Operand->codegen(target);
102     if (!V) {
103         return nullptr;
104     }
105     auto builder = target->getBuilder();
106     auto& llvmContext = target->getContext();
107     // llvm intrinsic suppport: abs, floor, ceil, sqrt, exp, log, sin, cos, round
108     //             not support: neg, square, rsqrt, tan, asin, acos, atan, reciprocal, log1p,
109     //                          bnll, acosh, sinh, asinh, atanh, sign, cosh, erf, erfc,
110     //                          erfinv, expm1, tanh, sigmoid,
111     switch (Op) {
112         // llvm intrinsic
113         case MNN::UnaryOpOperation_ABS:
114             return builder->CreateUnaryIntrinsic(Intrinsic::abs, V);
115         case MNN::UnaryOpOperation_FLOOR:
116             return builder->CreateUnaryIntrinsic(Intrinsic::floor, V);
117         case MNN::UnaryOpOperation_CEIL:
118             return builder->CreateUnaryIntrinsic(Intrinsic::ceil, V);
119         case MNN::UnaryOpOperation_SQRT:
120             return builder->CreateUnaryIntrinsic(Intrinsic::sqrt, V);
121         case MNN::UnaryOpOperation_EXP:
122             return builder->CreateUnaryIntrinsic(Intrinsic::exp, V);
123         case MNN::UnaryOpOperation_LOG:
124             return builder->CreateUnaryIntrinsic(Intrinsic::log, V);
125         case MNN::UnaryOpOperation_SIN:
126             return builder->CreateUnaryIntrinsic(Intrinsic::sin, V);
127         case MNN::UnaryOpOperation_COS:
128             return builder->CreateUnaryIntrinsic(Intrinsic::cos, V);
129         case MNN::UnaryOpOperation_ROUND:
130             return builder->CreateUnaryIntrinsic(Intrinsic::round, V);
131         // other
132         case MNN::UnaryOpOperation_NEG:
133             return builder->CreateFNeg(V);
134         case MNN::UnaryOpOperation_SQUARE:
135             return builder->CreateFMul(V, V);
136         case MNN::UnaryOpOperation_RSQRT:
137             V = builder->CreateUnaryIntrinsic(Intrinsic::sqrt, V);
138             return builder->CreateFDiv(ConstantFP::get(llvmContext, APFloat((float)1.0)), V);
139         case MNN::UnaryOpOperation_RECIPROCAL:
140             return builder->CreateFDiv(ConstantFP::get(llvmContext, APFloat((float)1.0)), V);
141         case MNN::UnaryOpOperation_SIGMOID:
142         {
143             V = builder->CreateFNeg(V);
144             V = builder->CreateUnaryIntrinsic(Intrinsic::exp, V);
145             V = builder->CreateFAdd(ConstantFP::get(llvmContext, APFloat((float)1.0)), V);
146             return builder->CreateFDiv(ConstantFP::get(llvmContext, APFloat((float)1.0)), V);
147             // Type* type = V->getType();
148             // FunctionCallee func = TheModule->getOrInsertFunction("sigmoid", FunctionType::get(type, {type}, false));
149             // return builder->CreateCall(func, {V});
150         }
151         // function call
152         case MNN::UnaryOpOperation_TANH:
153         {
154 #ifdef USE_FUNC_CALL
155             Type* type = V->getType();
156             FunctionCallee func = TheModule->getOrInsertFunction("tanhf", FunctionType::get(type, {type}, false));
157             return builder->CreateCall(func, {V});
158 #else
159             auto exp = builder->CreateUnaryIntrinsic(Intrinsic::exp, V);
160             auto nexp = builder->CreateFDiv(ConstantFP::get(llvmContext, APFloat((float)1.0)), exp);
161             return builder->CreateFDiv(builder->CreateFSub(exp, nexp), builder->CreateFAdd(exp, nexp));
162 #endif
163         }
164         default:
165             return LogErrorV(std::string("Unknown unary operator: " + std::string(MNN::EnumNameUnaryOpOperation(Op))).c_str());
166     }
167 }
168 
codegen(LLVMTarget * target)169 Value *BinaryExprAST::codegen(LLVMTarget* target) {
170     Value *L = LHS->codegen(target);
171     Value *R = RHS->codegen(target);
172     if (!L || !R || L->getType() != R->getType()) {
173         return nullptr;
174     }
175     auto builder = target->getBuilder();
176     auto& llvmContext = target->getContext();
177     bool isInt = L->getType()->isIntegerTy();
178     switch (Op) {
179         case MNN::BinaryOpOperation_ADD:
180             return isInt ? builder->CreateAdd(L, R) : builder->CreateFAdd(L, R);
181         case MNN::BinaryOpOperation_SUB:
182             return isInt ? builder->CreateSub(L, R) : builder->CreateFSub(L, R);
183         case MNN::BinaryOpOperation_MUL:
184             return isInt ? builder->CreateMul(L, R) : builder->CreateFMul(L, R);
185         case MNN::BinaryOpOperation_DIV:
186         case MNN::BinaryOpOperation_REALDIV:
187         case MNN::BinaryOpOperation_FLOORDIV:
188             return builder->CreateFDiv(L, R);
189         case MNN::BinaryOpOperation_POW:
190             return builder->CreateBinaryIntrinsic(Intrinsic::pow, L, R);
191         case MNN::BinaryOpOperation_MINIMUM:
192             // Minimun and Maximun Intrinsic will meet bug when LLVM backend select
193             // so use MinNum and MaxNum instead
194             return builder->CreateMinNum(L, R);
195         case MNN::BinaryOpOperation_MAXIMUM:
196             return builder->CreateMaxNum(L, R);
197         case MNN::BinaryOpOperation_GREATER:
198             L = builder->CreateFCmpUGT(L, R);
199             return builder->CreateUIToFP(L, Type::getFloatTy(llvmContext));
200         case MNN::BinaryOpOperation_GREATER_EQUAL:
201             L = builder->CreateFCmpUGE(L, R);
202             return builder->CreateUIToFP(L, Type::getFloatTy(llvmContext));
203         case MNN::BinaryOpOperation_LESS:
204             L = builder->CreateFCmpULT(L, R);
205             // Convert bool 0/1 to double 0.0 or 1.0
206             return builder->CreateUIToFP(L, Type::getFloatTy(llvmContext));
207         case MNN::BinaryOpOperation_LESS_EQUAL:
208             L = builder->CreateFCmpULE(L, R);
209             // Convert bool 0/1 to double 0.0 or 1.0
210             return builder->CreateUIToFP(L, Type::getFloatTy(llvmContext));
211         case MNN::BinaryOpOperation_EQUAL:
212             L = builder->CreateFCmpUEQ(L, R);
213             // Convert bool 0/1 to double 0.0 or 1.0
214             return builder->CreateUIToFP(L, Type::getFloatTy(llvmContext));
215         default:
216             return LogErrorV(std::string("Unknown Binary operator: " + std::string(MNN::EnumNameBinaryOpOperation(Op))).c_str());
217     }
218 }
219 
codegen(LLVMTarget * target)220 Value *AssignExprAST::codegen(LLVMTarget* target) {
221     VariableExprAST *LHSV = static_cast<VariableExprAST *>(LHS.get());
222     if (!LHSV) {
223         return LogErrorV("destination of '=' must be a variable");
224     }
225     Value *Variable = LHSV->getRef(target);
226     Value *Val = RHS->codegen(target);
227     // if (!Val || !Variable) {
228     if (!Val) {
229         return nullptr;
230     }
231     target->getBuilder()->CreateStore(Val, Variable);
232     return Val;
233 }
234 
codegen(LLVMTarget * target)235 Value *CallExprAST::codegen(LLVMTarget* target) {
236     // Look up the name in the global module table.
237     Function *CalleeF = getFunction(target, Callee);
238     if (!CalleeF) {
239         return LogErrorV("Unknown function referenced");
240     }
241 
242     // If argument mismatch error.
243     if (CalleeF->arg_size() != Args.size()) {
244         return LogErrorV("Incorrect # arguments passed");
245     }
246 
247     std::vector<Value *> ArgsV;
248     for (unsigned i = 0, e = Args.size(); i != e; ++i) {
249         ArgsV.push_back(Args[i]->codegen(target));
250         if (!ArgsV.back()) {
251             return nullptr;
252         }
253     }
254     return target->getBuilder()->CreateCall(CalleeF, ArgsV);
255 }
256 
codegen(LLVMTarget * target)257 Value *IfExprAST::codegen(LLVMTarget* target) {
258     Value *CondV = Cond->codegen(target);
259     if (!CondV) {
260         return nullptr;
261     }
262     auto builder = target->getBuilder();
263     auto& llvmContext = target->getContext();
264 
265     // Convert condition to a bool by comparing non-equal to 0.0.
266     CondV = builder->CreateFCmpONE(CondV, ConstantFP::get(llvmContext, APFloat(0.0)));
267 
268     Function *TheFunction = builder->GetInsertBlock()->getParent();
269 
270     // Create blocks for the then and else cases.  Insert the 'then' block at the
271     // end of the function.
272     BasicBlock *ThenBB = BasicBlock::Create(llvmContext, "then", TheFunction);
273     BasicBlock *ElseBB = BasicBlock::Create(llvmContext, "else");
274     BasicBlock *MergeBB = BasicBlock::Create(llvmContext, "ifcont");
275 
276     builder->CreateCondBr(CondV, ThenBB, ElseBB);
277 
278     // Emit then value.
279     builder->SetInsertPoint(ThenBB);
280 
281     Value *ThenV = Then->codegen(target);
282     if (!ThenV) {
283         return nullptr;
284     }
285 
286     builder->CreateBr(MergeBB);
287     // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
288     ThenBB = builder->GetInsertBlock();
289 
290     // Emit else block.
291     TheFunction->getBasicBlockList().push_back(ElseBB);
292     builder->SetInsertPoint(ElseBB);
293 
294     Value *ElseV = Else->codegen(target);
295     if (!ElseV) {
296         return nullptr;
297     }
298 
299     builder->CreateBr(MergeBB);
300     // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
301     ElseBB = builder->GetInsertBlock();
302 
303     // Emit merge block.
304     TheFunction->getBasicBlockList().push_back(MergeBB);
305     builder->SetInsertPoint(MergeBB);
306     PHINode *PN = builder->CreatePHI(Type::getFloatTy(llvmContext), 2, "iftmp");
307 
308     PN->addIncoming(ThenV, ThenBB);
309     PN->addIncoming(ElseV, ElseBB);
310     return PN;
311 }
312 
313 // Output for-loop as:
314 //   var = alloca double
315 //   ...
316 //   start = startexpr
317 //   store start -> var
318 //   goto loop
319 // loop:
320 //   ...
321 //   bodyexpr
322 //   ...
323 // loopend:
324 //   step = stepexpr
325 //   endcond = endexpr
326 //
327 //   curvar = load var
328 //   nextvar = curvar + step
329 //   store nextvar -> var
330 //   br endcond, loop, endloop
331 // outloop:
codegen(LLVMTarget * target)332 Value *ForExprAST::codegen(LLVMTarget* target) {
333     auto builder = target->getBuilder();
334     auto& llvmContext = target->getContext();
335 
336     Function *TheFunction = builder->GetInsertBlock()->getParent();
337 
338     // Create an alloca for the variable in the entry block.
339     AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName, Type::getInt32Ty(llvmContext));
340 
341     // Emit the start code first, without 'variable' in scope.
342     Value *StartVal = Start->codegen(target);
343     if (!StartVal) {
344         return nullptr;
345     }
346 
347     // Store the value into the alloca.
348     builder->CreateStore(StartVal, Alloca);
349 
350     // Make the new basic block for the loop header, inserting after current
351     // block.
352     BasicBlock *LoopBB = BasicBlock::Create(llvmContext, "loop", TheFunction);
353 
354     // Insert an explicit fall through from the current block to the LoopBB.
355     builder->CreateBr(LoopBB);
356 
357     // Start insertion in LoopBB.
358     builder->SetInsertPoint(LoopBB);
359 
360     // Within the loop, the variable is defined equal to the PHI node.  If it
361     // shadows an existing variable, we have to restore it, so save it now.
362     AllocaInst *OldVal = NamedValues[VarName];
363     NamedValues[VarName] = Alloca;
364 
365     // Emit the body of the loop.  This, like any other expr, can change the
366     // current BB.  Note that we ignore the value computed by the body, but don't
367     // allow an error.
368     if (!Body->codegen(target)) {
369         return nullptr;
370     }
371 
372     // Emit the step value.
373     Value *StepVal = nullptr;
374     if (Step) {
375         StepVal = Step->codegen(target);
376         if (!StepVal) {
377             return nullptr;
378         }
379     } else {
380         // If not specified, use 1.0.
381         StepVal = ConstantFP::get(llvmContext, APFloat(1.0));
382     }
383 
384     // Compute the end condition.
385     Value *EndVar = End->codegen(target);
386     if (!EndVar) {
387         return nullptr;
388     }
389 
390     // Reload, increment, and restore the alloca.  This handles the case where
391     // the body of the loop mutates the variable.
392     Value *CurVar = builder->CreateLoad(Alloca, VarName.c_str());
393     Value *NextVar = builder->CreateAdd(CurVar, StepVal, "nextvar");
394 
395     builder->CreateStore(NextVar, Alloca);
396 
397     // Convert condition to a bool by comparing non-equal to 0.0.
398     Value *EndCond = builder->CreateICmpSLT(NextVar, EndVar, "loopcond");
399 
400     // Create the "after loop" block and insert it.
401     BasicBlock *AfterBB = BasicBlock::Create(llvmContext, "afterloop", TheFunction);
402 
403     // Insert the conditional branch into the end of LoopEndBB.
404     builder->CreateCondBr(EndCond, LoopBB, AfterBB);
405 
406     // Any new code will be inserted in AfterBB.
407     builder->SetInsertPoint(AfterBB);
408 
409     // Restore the unshadowed variable.
410     if (OldVal) {
411         NamedValues[VarName] = OldVal;
412     } else {
413         NamedValues.erase(VarName);
414     }
415 
416     // for expr always returns 0.0.
417     return Constant::getNullValue(Type::getFloatTy(llvmContext));
418 }
codegen(LLVMTarget * target)419 Value *ListExprAST::codegen(LLVMTarget* target) {
420     for (auto& expr : exprs) {
421         expr->codegen(target);
422     }
423     // list exprs always returns 0.0.
424     return Constant::getNullValue(Type::getFloatTy(target->getContext()));
425 }
426 
codegen(LLVMTarget * target)427 Value *VarExprAST::codegen(LLVMTarget* target) {
428     std::vector<AllocaInst *> OldBindings;
429 
430     Function *TheFunction = target->getBuilder()->GetInsertBlock()->getParent();
431 
432     // Register all variables and emit their initializer.
433     for (unsigned i = 0, e = VarNames.size(); i != e; ++i) {
434         const std::string &VarName = VarNames[i].first;
435         ExprAST *Init = VarNames[i].second.get();
436 
437         // Emit the initializer before adding the variable to scope, this prevents
438         // the initializer from referencing the variable itself, and permits stuff
439         // like this:
440         //  var a = 1 in
441         //    var a = a in ...   # refers to outer 'a'.
442         Value *InitVal;
443         if (Init) {
444             InitVal = Init->codegen(target);
445             if (!InitVal) {
446                 return nullptr;
447             }
448         } else { // If not specified, use 0.0.
449             InitVal = ConstantFP::get(target->getContext(), APFloat(0.0));
450         }
451 
452         AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName, Type::getFloatTy(target->getContext()));
453         target->getBuilder()->CreateStore(InitVal, Alloca);
454 
455         // Remember the old variable binding so that we can restore the binding when
456         // we unrecurse.
457         OldBindings.push_back(NamedValues[VarName]);
458 
459         // Remember this binding.
460         NamedValues[VarName] = Alloca;
461     }
462 
463     // Codegen the body, now that all vars are in scope.
464     Value *BodyVal = Body->codegen(target);
465     if (!BodyVal) {
466         return nullptr;
467     }
468 
469     // Pop all our variables from scope.
470     for (unsigned i = 0, e = VarNames.size(); i != e; ++i) {
471         NamedValues[VarNames[i].first] = OldBindings[i];
472     }
473 
474     // Return the body computation.
475     return BodyVal;
476 }
477 
codegen(LLVMTarget * target)478 Function *PrototypeAST::codegen(LLVMTarget* target) {
479     // Make the function type:  double(double,double) etc.
480     std::vector<std::string> Args {"inputs", "outputs"};
481     std::vector<Type*> Types {PointerType::getUnqual(Type::getFloatPtrTy(target->getContext())), PointerType::getUnqual(Type::getFloatPtrTy(target->getContext()))};
482     FunctionType *FT = FunctionType::get(Type::getVoidTy(target->getContext()), Types, false);
483 
484     Function *F = Function::Create(FT, Function::ExternalLinkage, Name, target->getModule());
485     // Set names for all arguments.
486     unsigned Idx = 0;
487     for (auto &Arg : F->args()) {
488         F->addParamAttr(Idx, Attribute::NoAlias);
489         Arg.setName(Args[Idx++]);
490     }
491 
492     return F;
493 }
494 
codegen(LLVMTarget * target)495 Function *FunctionAST::codegen(LLVMTarget* target) {
496     // Transfer ownership of the prototype to the FunctionProtos map, but keep a
497     // reference to it for use below.
498     auto &P = *Proto;
499     FunctionProtos[Proto->getName()] = std::move(Proto);
500     Function *TheFunction = getFunction(target, P.getName());
501     if (!TheFunction) {
502         return nullptr;
503     }
504 
505     // Create a new basic block to start insertion into.
506     BasicBlock *BB = BasicBlock::Create(target->getContext(), "entry", TheFunction);
507     target->getBuilder()->SetInsertPoint(BB);
508 
509     // Record the function arguments in the NamedValues map.
510     NamedValues.clear();
511 
512     for (auto &Arg : TheFunction->args()) {
513         // Create an alloca for this variable.
514         AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, Arg.getName(), Arg.getType());
515 
516         // Store the initial value into the alloca.
517         target->getBuilder()->CreateStore(&Arg, Alloca);
518 
519         // Add arguments to variable symbol table.
520         NamedValues[std::string(Arg.getName())] = Alloca;
521     }
522 
523     if (Value *RetVal = Body->codegen(target)) {
524         // ret
525         target->getBuilder()->CreateRetVoid();
526         // Validate the generated code, checking for consistency.
527         verifyFunction(*TheFunction);
528 
529         return TheFunction;
530     }
531 
532     // Error reading body, remove function.
533     TheFunction->eraseFromParent();
534 
535     return nullptr;
536 }
537