1 #include "cpu/CPUAst.hpp"
2
3 using namespace AST;
4
5 static std::map<std::string, AllocaInst *> NamedValues;
6 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
7
LogError(const char * Str)8 std::unique_ptr<ExprAST> LogError(const char *Str) {
9 fprintf(stderr, "Error: %s\n", Str);
10 return nullptr;
11 }
12
LogErrorV(const char * Str)13 Value *LogErrorV(const char *Str) {
14 LogError(Str);
15 return nullptr;
16 }
getValueByName(const std::string & name)17 static Value* getValueByName(const std::string& name) {
18 if (name.empty()) {
19 LogErrorV(std::string("Variable name: " + name + "is empty!").c_str());
20 }
21 if (NamedValues.find(name) == NamedValues.end()) {
22 LogErrorV(std::string("Unknown variable name: " + name).c_str());
23 }
24 return NamedValues[name];
25 }
26
getFunction(LLVMTarget * target,std::string Name)27 static Function *getFunction(LLVMTarget* target, std::string Name) {
28 // First, see if the function has already been added to the current module.
29 if (auto *F = target->getModule()->getFunction(Name)) {
30 return F;
31 }
32
33 // If not, check whether we can codegen the declaration from some existing
34 // prototype.
35 auto FI = FunctionProtos.find(Name);
36 if (FI != FunctionProtos.end()) {
37 return FI->second->codegen(target);
38 }
39
40 // If no existing prototype exists, return null.
41 return nullptr;
42 }
43
44 /// CreateEntryBlockAlloca - Create an alloca instruction in the entry block of
45 /// the function. This is used for mutable variables etc.
CreateEntryBlockAlloca(Function * TheFunction,StringRef VarName,Type * type)46 static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction,
47 StringRef VarName, Type* type) {
48 IRBuilder<> TmpB(&TheFunction->getEntryBlock(),
49 TheFunction->getEntryBlock().begin());
50 return TmpB.CreateAlloca(type, nullptr, VarName);
51 }
52
codegen(LLVMTarget * target)53 Value *NumberExprAST::codegen(LLVMTarget* target) {
54 switch (mType) {
55 case FP32:
56 return ConstantFP::get(target->getContext(), APFloat(mVal.f32Val));
57 case FP64:
58 return ConstantFP::get(target->getContext(), APFloat(mVal.f64Val));
59 case INT32:
60 return ConstantInt::get(target->getContext(), APInt(32, mVal.i32Val, true));
61 case INT64:
62 return ConstantInt::get(target->getContext(), APInt(64, mVal.i64Val, true));
63 default:
64 return nullptr;
65 }
66 }
getRef(LLVMTarget * target)67 Value *VariableExprAST::getRef(LLVMTarget* target) {
68 return getValueByName(Name);
69 }
codegen(LLVMTarget * target)70 Value *VariableExprAST::codegen(LLVMTarget* target) {
71 return target->getBuilder()->CreateLoad(getRef(target), Name.c_str());
72 }
73
getRef(LLVMTarget * target)74 Value* SubscriptExprAST::getRef(LLVMTarget* target) {
75 return target->getBuilder()->CreateGEP(Base->codegen(target), Offset->codegen(target));
76 }
codegen(LLVMTarget * target)77 Value* SubscriptExprAST::codegen(LLVMTarget* target) {
78 return target->getBuilder()->CreateLoad(getRef(target));
79 }
codegen(LLVMTarget * target)80 Value *ReluExprAST::codegen(LLVMTarget* target) {
81 Value *V = Operand->codegen(target);
82 if (!V) {
83 return nullptr;
84 }
85 auto builder = target->getBuilder();
86 auto& llvmContext = target->getContext();
87 if (maxVal == 0.f) {
88 // relu(x) = (x <= 0) * slope * x + (x > 0) * x
89 auto aV = builder->CreateFMul(V, ConstantFP::get(llvmContext, APFloat((float)minVal)));
90 auto gtz = builder->CreateUIToFP(builder->CreateFCmpUGT(V, ConstantFP::get(llvmContext, APFloat((float)0))), Type::getFloatTy(llvmContext));
91 auto ltz = builder->CreateFSub(ConstantFP::get(llvmContext, APFloat((float)1)), gtz);
92 auto l = builder->CreateFMul(ltz, aV);
93 auto r = builder->CreateFMul(gtz, V);
94 return builder->CreateFAdd(l, r);
95 }
96 // relu6(x) = min(max(x, minv), maxv)
97 V = builder->CreateMaxNum(V, ConstantFP::get(llvmContext, APFloat((float)minVal)));
98 return builder->CreateMinNum(V, ConstantFP::get(llvmContext, APFloat((float)maxVal)));
99 }
codegen(LLVMTarget * target)100 Value *UnaryExprAST::codegen(LLVMTarget* target) {
101 Value *V = Operand->codegen(target);
102 if (!V) {
103 return nullptr;
104 }
105 auto builder = target->getBuilder();
106 auto& llvmContext = target->getContext();
107 // llvm intrinsic suppport: abs, floor, ceil, sqrt, exp, log, sin, cos, round
108 // not support: neg, square, rsqrt, tan, asin, acos, atan, reciprocal, log1p,
109 // bnll, acosh, sinh, asinh, atanh, sign, cosh, erf, erfc,
110 // erfinv, expm1, tanh, sigmoid,
111 switch (Op) {
112 // llvm intrinsic
113 case MNN::UnaryOpOperation_ABS:
114 return builder->CreateUnaryIntrinsic(Intrinsic::abs, V);
115 case MNN::UnaryOpOperation_FLOOR:
116 return builder->CreateUnaryIntrinsic(Intrinsic::floor, V);
117 case MNN::UnaryOpOperation_CEIL:
118 return builder->CreateUnaryIntrinsic(Intrinsic::ceil, V);
119 case MNN::UnaryOpOperation_SQRT:
120 return builder->CreateUnaryIntrinsic(Intrinsic::sqrt, V);
121 case MNN::UnaryOpOperation_EXP:
122 return builder->CreateUnaryIntrinsic(Intrinsic::exp, V);
123 case MNN::UnaryOpOperation_LOG:
124 return builder->CreateUnaryIntrinsic(Intrinsic::log, V);
125 case MNN::UnaryOpOperation_SIN:
126 return builder->CreateUnaryIntrinsic(Intrinsic::sin, V);
127 case MNN::UnaryOpOperation_COS:
128 return builder->CreateUnaryIntrinsic(Intrinsic::cos, V);
129 case MNN::UnaryOpOperation_ROUND:
130 return builder->CreateUnaryIntrinsic(Intrinsic::round, V);
131 // other
132 case MNN::UnaryOpOperation_NEG:
133 return builder->CreateFNeg(V);
134 case MNN::UnaryOpOperation_SQUARE:
135 return builder->CreateFMul(V, V);
136 case MNN::UnaryOpOperation_RSQRT:
137 V = builder->CreateUnaryIntrinsic(Intrinsic::sqrt, V);
138 return builder->CreateFDiv(ConstantFP::get(llvmContext, APFloat((float)1.0)), V);
139 case MNN::UnaryOpOperation_RECIPROCAL:
140 return builder->CreateFDiv(ConstantFP::get(llvmContext, APFloat((float)1.0)), V);
141 case MNN::UnaryOpOperation_SIGMOID:
142 {
143 V = builder->CreateFNeg(V);
144 V = builder->CreateUnaryIntrinsic(Intrinsic::exp, V);
145 V = builder->CreateFAdd(ConstantFP::get(llvmContext, APFloat((float)1.0)), V);
146 return builder->CreateFDiv(ConstantFP::get(llvmContext, APFloat((float)1.0)), V);
147 // Type* type = V->getType();
148 // FunctionCallee func = TheModule->getOrInsertFunction("sigmoid", FunctionType::get(type, {type}, false));
149 // return builder->CreateCall(func, {V});
150 }
151 // function call
152 case MNN::UnaryOpOperation_TANH:
153 {
154 #ifdef USE_FUNC_CALL
155 Type* type = V->getType();
156 FunctionCallee func = TheModule->getOrInsertFunction("tanhf", FunctionType::get(type, {type}, false));
157 return builder->CreateCall(func, {V});
158 #else
159 auto exp = builder->CreateUnaryIntrinsic(Intrinsic::exp, V);
160 auto nexp = builder->CreateFDiv(ConstantFP::get(llvmContext, APFloat((float)1.0)), exp);
161 return builder->CreateFDiv(builder->CreateFSub(exp, nexp), builder->CreateFAdd(exp, nexp));
162 #endif
163 }
164 default:
165 return LogErrorV(std::string("Unknown unary operator: " + std::string(MNN::EnumNameUnaryOpOperation(Op))).c_str());
166 }
167 }
168
codegen(LLVMTarget * target)169 Value *BinaryExprAST::codegen(LLVMTarget* target) {
170 Value *L = LHS->codegen(target);
171 Value *R = RHS->codegen(target);
172 if (!L || !R || L->getType() != R->getType()) {
173 return nullptr;
174 }
175 auto builder = target->getBuilder();
176 auto& llvmContext = target->getContext();
177 bool isInt = L->getType()->isIntegerTy();
178 switch (Op) {
179 case MNN::BinaryOpOperation_ADD:
180 return isInt ? builder->CreateAdd(L, R) : builder->CreateFAdd(L, R);
181 case MNN::BinaryOpOperation_SUB:
182 return isInt ? builder->CreateSub(L, R) : builder->CreateFSub(L, R);
183 case MNN::BinaryOpOperation_MUL:
184 return isInt ? builder->CreateMul(L, R) : builder->CreateFMul(L, R);
185 case MNN::BinaryOpOperation_DIV:
186 case MNN::BinaryOpOperation_REALDIV:
187 case MNN::BinaryOpOperation_FLOORDIV:
188 return builder->CreateFDiv(L, R);
189 case MNN::BinaryOpOperation_POW:
190 return builder->CreateBinaryIntrinsic(Intrinsic::pow, L, R);
191 case MNN::BinaryOpOperation_MINIMUM:
192 // Minimun and Maximun Intrinsic will meet bug when LLVM backend select
193 // so use MinNum and MaxNum instead
194 return builder->CreateMinNum(L, R);
195 case MNN::BinaryOpOperation_MAXIMUM:
196 return builder->CreateMaxNum(L, R);
197 case MNN::BinaryOpOperation_GREATER:
198 L = builder->CreateFCmpUGT(L, R);
199 return builder->CreateUIToFP(L, Type::getFloatTy(llvmContext));
200 case MNN::BinaryOpOperation_GREATER_EQUAL:
201 L = builder->CreateFCmpUGE(L, R);
202 return builder->CreateUIToFP(L, Type::getFloatTy(llvmContext));
203 case MNN::BinaryOpOperation_LESS:
204 L = builder->CreateFCmpULT(L, R);
205 // Convert bool 0/1 to double 0.0 or 1.0
206 return builder->CreateUIToFP(L, Type::getFloatTy(llvmContext));
207 case MNN::BinaryOpOperation_LESS_EQUAL:
208 L = builder->CreateFCmpULE(L, R);
209 // Convert bool 0/1 to double 0.0 or 1.0
210 return builder->CreateUIToFP(L, Type::getFloatTy(llvmContext));
211 case MNN::BinaryOpOperation_EQUAL:
212 L = builder->CreateFCmpUEQ(L, R);
213 // Convert bool 0/1 to double 0.0 or 1.0
214 return builder->CreateUIToFP(L, Type::getFloatTy(llvmContext));
215 default:
216 return LogErrorV(std::string("Unknown Binary operator: " + std::string(MNN::EnumNameBinaryOpOperation(Op))).c_str());
217 }
218 }
219
codegen(LLVMTarget * target)220 Value *AssignExprAST::codegen(LLVMTarget* target) {
221 VariableExprAST *LHSV = static_cast<VariableExprAST *>(LHS.get());
222 if (!LHSV) {
223 return LogErrorV("destination of '=' must be a variable");
224 }
225 Value *Variable = LHSV->getRef(target);
226 Value *Val = RHS->codegen(target);
227 // if (!Val || !Variable) {
228 if (!Val) {
229 return nullptr;
230 }
231 target->getBuilder()->CreateStore(Val, Variable);
232 return Val;
233 }
234
codegen(LLVMTarget * target)235 Value *CallExprAST::codegen(LLVMTarget* target) {
236 // Look up the name in the global module table.
237 Function *CalleeF = getFunction(target, Callee);
238 if (!CalleeF) {
239 return LogErrorV("Unknown function referenced");
240 }
241
242 // If argument mismatch error.
243 if (CalleeF->arg_size() != Args.size()) {
244 return LogErrorV("Incorrect # arguments passed");
245 }
246
247 std::vector<Value *> ArgsV;
248 for (unsigned i = 0, e = Args.size(); i != e; ++i) {
249 ArgsV.push_back(Args[i]->codegen(target));
250 if (!ArgsV.back()) {
251 return nullptr;
252 }
253 }
254 return target->getBuilder()->CreateCall(CalleeF, ArgsV);
255 }
256
codegen(LLVMTarget * target)257 Value *IfExprAST::codegen(LLVMTarget* target) {
258 Value *CondV = Cond->codegen(target);
259 if (!CondV) {
260 return nullptr;
261 }
262 auto builder = target->getBuilder();
263 auto& llvmContext = target->getContext();
264
265 // Convert condition to a bool by comparing non-equal to 0.0.
266 CondV = builder->CreateFCmpONE(CondV, ConstantFP::get(llvmContext, APFloat(0.0)));
267
268 Function *TheFunction = builder->GetInsertBlock()->getParent();
269
270 // Create blocks for the then and else cases. Insert the 'then' block at the
271 // end of the function.
272 BasicBlock *ThenBB = BasicBlock::Create(llvmContext, "then", TheFunction);
273 BasicBlock *ElseBB = BasicBlock::Create(llvmContext, "else");
274 BasicBlock *MergeBB = BasicBlock::Create(llvmContext, "ifcont");
275
276 builder->CreateCondBr(CondV, ThenBB, ElseBB);
277
278 // Emit then value.
279 builder->SetInsertPoint(ThenBB);
280
281 Value *ThenV = Then->codegen(target);
282 if (!ThenV) {
283 return nullptr;
284 }
285
286 builder->CreateBr(MergeBB);
287 // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
288 ThenBB = builder->GetInsertBlock();
289
290 // Emit else block.
291 TheFunction->getBasicBlockList().push_back(ElseBB);
292 builder->SetInsertPoint(ElseBB);
293
294 Value *ElseV = Else->codegen(target);
295 if (!ElseV) {
296 return nullptr;
297 }
298
299 builder->CreateBr(MergeBB);
300 // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
301 ElseBB = builder->GetInsertBlock();
302
303 // Emit merge block.
304 TheFunction->getBasicBlockList().push_back(MergeBB);
305 builder->SetInsertPoint(MergeBB);
306 PHINode *PN = builder->CreatePHI(Type::getFloatTy(llvmContext), 2, "iftmp");
307
308 PN->addIncoming(ThenV, ThenBB);
309 PN->addIncoming(ElseV, ElseBB);
310 return PN;
311 }
312
313 // Output for-loop as:
314 // var = alloca double
315 // ...
316 // start = startexpr
317 // store start -> var
318 // goto loop
319 // loop:
320 // ...
321 // bodyexpr
322 // ...
323 // loopend:
324 // step = stepexpr
325 // endcond = endexpr
326 //
327 // curvar = load var
328 // nextvar = curvar + step
329 // store nextvar -> var
330 // br endcond, loop, endloop
331 // outloop:
codegen(LLVMTarget * target)332 Value *ForExprAST::codegen(LLVMTarget* target) {
333 auto builder = target->getBuilder();
334 auto& llvmContext = target->getContext();
335
336 Function *TheFunction = builder->GetInsertBlock()->getParent();
337
338 // Create an alloca for the variable in the entry block.
339 AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName, Type::getInt32Ty(llvmContext));
340
341 // Emit the start code first, without 'variable' in scope.
342 Value *StartVal = Start->codegen(target);
343 if (!StartVal) {
344 return nullptr;
345 }
346
347 // Store the value into the alloca.
348 builder->CreateStore(StartVal, Alloca);
349
350 // Make the new basic block for the loop header, inserting after current
351 // block.
352 BasicBlock *LoopBB = BasicBlock::Create(llvmContext, "loop", TheFunction);
353
354 // Insert an explicit fall through from the current block to the LoopBB.
355 builder->CreateBr(LoopBB);
356
357 // Start insertion in LoopBB.
358 builder->SetInsertPoint(LoopBB);
359
360 // Within the loop, the variable is defined equal to the PHI node. If it
361 // shadows an existing variable, we have to restore it, so save it now.
362 AllocaInst *OldVal = NamedValues[VarName];
363 NamedValues[VarName] = Alloca;
364
365 // Emit the body of the loop. This, like any other expr, can change the
366 // current BB. Note that we ignore the value computed by the body, but don't
367 // allow an error.
368 if (!Body->codegen(target)) {
369 return nullptr;
370 }
371
372 // Emit the step value.
373 Value *StepVal = nullptr;
374 if (Step) {
375 StepVal = Step->codegen(target);
376 if (!StepVal) {
377 return nullptr;
378 }
379 } else {
380 // If not specified, use 1.0.
381 StepVal = ConstantFP::get(llvmContext, APFloat(1.0));
382 }
383
384 // Compute the end condition.
385 Value *EndVar = End->codegen(target);
386 if (!EndVar) {
387 return nullptr;
388 }
389
390 // Reload, increment, and restore the alloca. This handles the case where
391 // the body of the loop mutates the variable.
392 Value *CurVar = builder->CreateLoad(Alloca, VarName.c_str());
393 Value *NextVar = builder->CreateAdd(CurVar, StepVal, "nextvar");
394
395 builder->CreateStore(NextVar, Alloca);
396
397 // Convert condition to a bool by comparing non-equal to 0.0.
398 Value *EndCond = builder->CreateICmpSLT(NextVar, EndVar, "loopcond");
399
400 // Create the "after loop" block and insert it.
401 BasicBlock *AfterBB = BasicBlock::Create(llvmContext, "afterloop", TheFunction);
402
403 // Insert the conditional branch into the end of LoopEndBB.
404 builder->CreateCondBr(EndCond, LoopBB, AfterBB);
405
406 // Any new code will be inserted in AfterBB.
407 builder->SetInsertPoint(AfterBB);
408
409 // Restore the unshadowed variable.
410 if (OldVal) {
411 NamedValues[VarName] = OldVal;
412 } else {
413 NamedValues.erase(VarName);
414 }
415
416 // for expr always returns 0.0.
417 return Constant::getNullValue(Type::getFloatTy(llvmContext));
418 }
codegen(LLVMTarget * target)419 Value *ListExprAST::codegen(LLVMTarget* target) {
420 for (auto& expr : exprs) {
421 expr->codegen(target);
422 }
423 // list exprs always returns 0.0.
424 return Constant::getNullValue(Type::getFloatTy(target->getContext()));
425 }
426
codegen(LLVMTarget * target)427 Value *VarExprAST::codegen(LLVMTarget* target) {
428 std::vector<AllocaInst *> OldBindings;
429
430 Function *TheFunction = target->getBuilder()->GetInsertBlock()->getParent();
431
432 // Register all variables and emit their initializer.
433 for (unsigned i = 0, e = VarNames.size(); i != e; ++i) {
434 const std::string &VarName = VarNames[i].first;
435 ExprAST *Init = VarNames[i].second.get();
436
437 // Emit the initializer before adding the variable to scope, this prevents
438 // the initializer from referencing the variable itself, and permits stuff
439 // like this:
440 // var a = 1 in
441 // var a = a in ... # refers to outer 'a'.
442 Value *InitVal;
443 if (Init) {
444 InitVal = Init->codegen(target);
445 if (!InitVal) {
446 return nullptr;
447 }
448 } else { // If not specified, use 0.0.
449 InitVal = ConstantFP::get(target->getContext(), APFloat(0.0));
450 }
451
452 AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName, Type::getFloatTy(target->getContext()));
453 target->getBuilder()->CreateStore(InitVal, Alloca);
454
455 // Remember the old variable binding so that we can restore the binding when
456 // we unrecurse.
457 OldBindings.push_back(NamedValues[VarName]);
458
459 // Remember this binding.
460 NamedValues[VarName] = Alloca;
461 }
462
463 // Codegen the body, now that all vars are in scope.
464 Value *BodyVal = Body->codegen(target);
465 if (!BodyVal) {
466 return nullptr;
467 }
468
469 // Pop all our variables from scope.
470 for (unsigned i = 0, e = VarNames.size(); i != e; ++i) {
471 NamedValues[VarNames[i].first] = OldBindings[i];
472 }
473
474 // Return the body computation.
475 return BodyVal;
476 }
477
codegen(LLVMTarget * target)478 Function *PrototypeAST::codegen(LLVMTarget* target) {
479 // Make the function type: double(double,double) etc.
480 std::vector<std::string> Args {"inputs", "outputs"};
481 std::vector<Type*> Types {PointerType::getUnqual(Type::getFloatPtrTy(target->getContext())), PointerType::getUnqual(Type::getFloatPtrTy(target->getContext()))};
482 FunctionType *FT = FunctionType::get(Type::getVoidTy(target->getContext()), Types, false);
483
484 Function *F = Function::Create(FT, Function::ExternalLinkage, Name, target->getModule());
485 // Set names for all arguments.
486 unsigned Idx = 0;
487 for (auto &Arg : F->args()) {
488 F->addParamAttr(Idx, Attribute::NoAlias);
489 Arg.setName(Args[Idx++]);
490 }
491
492 return F;
493 }
494
codegen(LLVMTarget * target)495 Function *FunctionAST::codegen(LLVMTarget* target) {
496 // Transfer ownership of the prototype to the FunctionProtos map, but keep a
497 // reference to it for use below.
498 auto &P = *Proto;
499 FunctionProtos[Proto->getName()] = std::move(Proto);
500 Function *TheFunction = getFunction(target, P.getName());
501 if (!TheFunction) {
502 return nullptr;
503 }
504
505 // Create a new basic block to start insertion into.
506 BasicBlock *BB = BasicBlock::Create(target->getContext(), "entry", TheFunction);
507 target->getBuilder()->SetInsertPoint(BB);
508
509 // Record the function arguments in the NamedValues map.
510 NamedValues.clear();
511
512 for (auto &Arg : TheFunction->args()) {
513 // Create an alloca for this variable.
514 AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, Arg.getName(), Arg.getType());
515
516 // Store the initial value into the alloca.
517 target->getBuilder()->CreateStore(&Arg, Alloca);
518
519 // Add arguments to variable symbol table.
520 NamedValues[std::string(Arg.getName())] = Alloca;
521 }
522
523 if (Value *RetVal = Body->codegen(target)) {
524 // ret
525 target->getBuilder()->CreateRetVoid();
526 // Validate the generated code, checking for consistency.
527 verifyFunction(*TheFunction);
528
529 return TheFunction;
530 }
531
532 // Error reading body, remove function.
533 TheFunction->eraseFromParent();
534
535 return nullptr;
536 }
537