1 /*
2  Copyright Disney Enterprises, Inc.  All rights reserved.
3 
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License
6  and the following modification to it: Section 6 Trademarks.
7  deleted and replaced with:
8 
9  6. Trademarks. This License does not grant permission to use the
10  trade names, trademarks, service marks, or product names of the
11  Licensor and its affiliates, except as required for reproducing
12  the content of the NOTICE file.
13 
14  You may obtain a copy of the License at
15  http://www.apache.org/licenses/LICENSE-2.0
16 */
17 
18 #include "ExprConfig.h"
19 #include "ExprLLVMAll.h"
20 #include "VarBlock.h"
21 
22 #if defined(SEEXPR_ENABLE_LLVM)
23 #include <llvm/Config/llvm-config.h>
24 #include <llvm/Support/Compiler.h>
25 #endif
26 
27 extern "C" void SeExpr2LLVMEvalFPVarRef(SeExpr2::ExprVarRef *seVR, double *result);
28 extern "C" void SeExpr2LLVMEvalStrVarRef(SeExpr2::ExprVarRef *seVR, double *result);
29 extern "C" void SeExpr2LLVMEvalCustomFunction(int *opDataArg,
30                                               double *fpArg,
31                                               char **strArg,
32                                               void **funcdata,
33                                               const SeExpr2::ExprFuncNode *node);
34 
35 namespace SeExpr2 {
36 #if defined(SEEXPR_ENABLE_LLVM)
37 
38 LLVM_VALUE promoteToDim(LLVM_VALUE val, unsigned dim, llvm::IRBuilder<> &Builder);
39 
40 class LLVMEvaluator {
41     // TODO: this seems needlessly complex, let's fix it
42     // TODO: let the dev code allocate memory?
43     // FP is the native function for this expression.
44     template <class T>
45     class LLVMEvaluationContext {
46       private:
47         typedef void (*FunctionPtr)(T *, char **, uint32_t);
48         typedef void (*FunctionPtrMultiple)(char **, uint32_t, uint32_t, uint32_t);
49         FunctionPtr functionPtr;
50         FunctionPtrMultiple functionPtrMultiple;
51         T *resultData;
52 
53       public:
54         LLVMEvaluationContext(const LLVMEvaluationContext &) = delete;
55         LLVMEvaluationContext &operator=(const LLVMEvaluationContext &) = delete;
~LLVMEvaluationContext()56         ~LLVMEvaluationContext() { delete[] resultData; }
LLVMEvaluationContext()57         LLVMEvaluationContext() : functionPtr(nullptr), resultData(nullptr) {}
init(void * fp,void * fpLoop,int dim)58         void init(void *fp, void *fpLoop, int dim) {
59             reset();
60             functionPtr = reinterpret_cast<FunctionPtr>(fp);
61             functionPtrMultiple = reinterpret_cast<FunctionPtrMultiple>(fpLoop);
62             resultData = new T[dim];
63         }
reset()64         void reset() {
65             if (resultData) delete[] resultData;
66             functionPtr = nullptr;
67             resultData = nullptr;
68         }
operator()69         const T *operator()(VarBlock *varBlock) {
70             assert(functionPtr && resultData);
71             functionPtr(resultData, varBlock ? varBlock->data() : nullptr, varBlock ? varBlock->indirectIndex : 0);
72             return resultData;
73         }
operator()74         void operator()(VarBlock *varBlock, size_t outputVarBlockOffset, size_t rangeStart, size_t rangeEnd) {
75             assert(functionPtr && resultData);
76             functionPtrMultiple(varBlock ? varBlock->data() : nullptr, outputVarBlockOffset, rangeStart, rangeEnd);
77         }
78     };
79     std::unique_ptr<LLVMEvaluationContext<double>> _llvmEvalFP;
80     std::unique_ptr<LLVMEvaluationContext<char *>> _llvmEvalStr;
81 
82     std::unique_ptr<llvm::LLVMContext> _llvmContext;
83     std::unique_ptr<llvm::ExecutionEngine> TheExecutionEngine;
84 
85   public:
LLVMEvaluator()86     LLVMEvaluator() {}
87 
evalStr(VarBlock * varBlock)88     const char *evalStr(VarBlock *varBlock) { return *(*_llvmEvalStr)(varBlock); }
evalFP(VarBlock * varBlock)89     const double *evalFP(VarBlock *varBlock) { return (*_llvmEvalFP)(varBlock); }
90 
evalMultiple(VarBlock * varBlock,uint32_t outputVarBlockOffset,uint32_t rangeStart,uint32_t rangeEnd)91     void evalMultiple(VarBlock *varBlock, uint32_t outputVarBlockOffset, uint32_t rangeStart, uint32_t rangeEnd) {
92         return (*_llvmEvalFP)(varBlock, outputVarBlockOffset, rangeStart, rangeEnd);
93     }
94 
debugPrint()95     void debugPrint() {
96         // TheModule->print(llvm::errs(), nullptr);
97     }
98 
prepLLVM(ExprNode * parseTree,ExprType desiredReturnType)99     bool prepLLVM(ExprNode *parseTree, ExprType desiredReturnType) {
100         using namespace llvm;
101         InitializeNativeTarget();
102         InitializeNativeTargetAsmPrinter();
103         InitializeNativeTargetAsmParser();
104 
105         std::string uniqueName = getUniqueName();
106 
107         // create Module
108         _llvmContext.reset(new LLVMContext());
109 
110         std::unique_ptr<Module> TheModule(new Module(uniqueName + "_module", *_llvmContext));
111 
112         // create all needed types
113         Type        *i8PtrTy        = Type::getInt8PtrTy(*_llvmContext);        // char *
114         PointerType *i8PtrPtrTy     = PointerType::getUnqual(i8PtrTy);          // char **
115         PointerType *i8PtrPtrPtrTy  = PointerType::getUnqual(i8PtrPtrTy);       // char ***
116         Type        *i32Ty          = Type::getInt32Ty(*_llvmContext);          // int
117         Type        *i32PtrTy       = Type::getInt32PtrTy(*_llvmContext);       // int *
118         Type        *i64Ty          = Type::getInt64Ty(*_llvmContext);          // int64 *
119         Type        *doublePtrTy    = Type::getDoublePtrTy(*_llvmContext);      // double *
120         PointerType *doublePtrPtrTy = PointerType::getUnqual(doublePtrTy);      // double **
121         Type        *voidTy         = Type::getVoidTy(*_llvmContext);           // void
122 
123         // create bindings to helper functions for variables and fucntions
124         Function *SeExpr2LLVMEvalCustomFunctionFunc = nullptr;
125         Function *SeExpr2LLVMEvalFPVarRefFunc = nullptr;
126         Function *SeExpr2LLVMEvalStrVarRefFunc = nullptr;
127         Function *SeExpr2LLVMEvalstrlenFunc = nullptr;
128         Function *SeExpr2LLVMEvalmallocFunc = nullptr;
129         Function *SeExpr2LLVMEvalfreeFunc = nullptr;
130         Function *SeExpr2LLVMEvalmemsetFunc = nullptr;
131         Function *SeExpr2LLVMEvalstrcatFunc = nullptr;
132         {
133             {
134                 FunctionType *FT = FunctionType::get(voidTy, {i32PtrTy, doublePtrTy, i8PtrPtrTy, i8PtrPtrTy, i64Ty}, false);
135                 SeExpr2LLVMEvalCustomFunctionFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "SeExpr2LLVMEvalCustomFunction", TheModule.get());
136             }
137             {
138                 FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, doublePtrTy}, false);
139                 SeExpr2LLVMEvalFPVarRefFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "SeExpr2LLVMEvalFPVarRef", TheModule.get());
140             }
141             {
142                 FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, i8PtrPtrTy}, false);
143                 SeExpr2LLVMEvalStrVarRefFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "SeExpr2LLVMEvalStrVarRef", TheModule.get());
144             }
145             {
146                 FunctionType *FT = FunctionType::get(i32Ty, { i8PtrTy }, false);
147                 SeExpr2LLVMEvalstrlenFunc = Function::Create(FT, Function::ExternalLinkage, "strlen", TheModule.get());
148             }
149             {
150                 FunctionType *FT = FunctionType::get(i8PtrTy, { i32Ty }, false);
151                 SeExpr2LLVMEvalmallocFunc = Function::Create(FT, Function::ExternalLinkage, "malloc", TheModule.get());
152             }
153             {
154                 FunctionType *FT = FunctionType::get(voidTy, { i8PtrTy }, false);
155                 SeExpr2LLVMEvalfreeFunc = Function::Create(FT, Function::ExternalLinkage, "free", TheModule.get());
156             }
157             {
158                 FunctionType *FT = FunctionType::get(voidTy, { i8PtrTy, i32Ty, i32Ty }, false);
159                 SeExpr2LLVMEvalmemsetFunc = Function::Create(FT, Function::ExternalLinkage, "memset", TheModule.get());
160             }
161             {
162                 FunctionType *FT = FunctionType::get(i8PtrTy, { i8PtrTy, i8PtrTy }, false);
163                 SeExpr2LLVMEvalstrcatFunc = Function::Create(FT, Function::ExternalLinkage, "strcat", TheModule.get());
164             }
165         }
166 
167         // create function and entry BB
168         bool desireFP = desiredReturnType.isFP();
169         Type *ParamTys[] = {
170             desireFP ? doublePtrTy : i8PtrPtrTy,
171             doublePtrPtrTy,
172             i32Ty
173         };
174         FunctionType *FT = FunctionType::get(voidTy, ParamTys, false);
175         Function *F = Function::Create(FT, Function::ExternalLinkage, uniqueName + "_func", TheModule.get());
176 #if LLVM_VERSION_MAJOR > 4
177         F->addAttribute(llvm::AttributeList::FunctionIndex, llvm::Attribute::AlwaysInline);
178 #else
179         F->addAttribute(llvm::AttributeSet::FunctionIndex, llvm::Attribute::AlwaysInline);
180 #endif
181         {
182             // label the function with names
183             const char *names[] = {"outputPointer", "dataBlock", "indirectIndex"};
184             int idx = 0;
185             for (auto &arg : F->args()) arg.setName(names[idx++]);
186         }
187 
188         unsigned int dimDesired = (unsigned)desiredReturnType.dim();
189         unsigned int dimGenerated = parseTree->type().dim();
190         {
191             BasicBlock *BB = BasicBlock::Create(*_llvmContext, "entry", F);
192             IRBuilder<> Builder(BB);
193 
194             // codegen
195             Value *lastVal = parseTree->codegen(Builder);
196 
197             // return values through parameter.
198             Value *firstArg = &*F->arg_begin();
199             if (desireFP) {
200                 if (dimGenerated > 1) {
201                     Value *newLastVal = promoteToDim(lastVal, dimDesired, Builder);
202                     assert(newLastVal->getType()->getVectorNumElements() >= dimDesired);
203                     for (unsigned i = 0; i < dimDesired; ++i) {
204                         Value *idx = ConstantInt::get(Type::getInt64Ty(*_llvmContext), i);
205                         Value *val = Builder.CreateExtractElement(newLastVal, idx);
206                         Value *ptr = Builder.CreateInBoundsGEP(firstArg, idx);
207                         Builder.CreateStore(val, ptr);
208                     }
209                 } else if (dimGenerated == 1) {
210                     for (unsigned i = 0; i < dimDesired; ++i) {
211                         Value *ptr = Builder.CreateConstInBoundsGEP1_32(nullptr, firstArg, i);
212                         Builder.CreateStore(lastVal, ptr);
213                     }
214                 } else {
215                     assert(false && "error. dim of FP is less than 1.");
216                 }
217             } else {
218                 Builder.CreateStore(lastVal, firstArg);
219             }
220 
221             Builder.CreateRetVoid();
222         }
223 
224         // write a new function
225         FunctionType *FTLOOP = FunctionType::get(voidTy, {i8PtrTy, i32Ty, i32Ty, i32Ty}, false);
226         Function *FLOOP = Function::Create(FTLOOP, Function::ExternalLinkage, uniqueName + "_loopfunc", TheModule.get());
227         {
228             // label the function with names
229             const char *names[] = {"dataBlock", "outputVarBlockOffset", "rangeStart", "rangeEnd"};
230             int idx = 0;
231             for (auto &arg : FLOOP->args()) {
232                 arg.setName(names[idx++]);
233             }
234         }
235         {
236             // Local variables
237             Value *dimValue = ConstantInt::get(i32Ty, dimDesired);
238             Value *oneValue = ConstantInt::get(i32Ty, 1);
239 
240             // Basic blocks
241             BasicBlock *entryBlock = BasicBlock::Create(*_llvmContext, "entry", FLOOP);
242             BasicBlock *loopCmpBlock = BasicBlock::Create(*_llvmContext, "loopCmp", FLOOP);
243             BasicBlock *loopRepeatBlock = BasicBlock::Create(*_llvmContext, "loopRepeat", FLOOP);
244             BasicBlock *loopIncBlock = BasicBlock::Create(*_llvmContext, "loopInc", FLOOP);
245             BasicBlock *loopEndBlock = BasicBlock::Create(*_llvmContext, "loopEnd", FLOOP);
246             IRBuilder<> Builder(entryBlock);
247             Builder.SetInsertPoint(entryBlock);
248 
249             // Get arguments
250             Function::arg_iterator argIterator = FLOOP->arg_begin();
251             Value *varBlockCharPtrPtrArg = &*argIterator;       ++argIterator;
252             Value *outputVarBlockOffsetArg = &*argIterator;     ++argIterator;
253             Value *rangeStartArg = &*argIterator;                ++argIterator;
254             Value *rangeEndArg = &*argIterator;                    ++argIterator;
255 
256             // Allocate Variables
257             Value *rangeStartVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "rangeStartVar");
258             Value *rangeEndVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "rangeEndVar");
259             Value *indexVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "indexVar");
260             Value *outputVarBlockOffsetVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "outputVarBlockOffsetVar");
261             Value *varBlockDoublePtrPtrVar = Builder.CreateAlloca(doublePtrPtrTy, oneValue, "varBlockDoublePtrPtrVar");
262             Value *varBlockTPtrPtrVar = Builder.CreateAlloca(desireFP == true ? doublePtrPtrTy : i8PtrPtrPtrTy, oneValue, "varBlockTPtrPtrVar");
263 
264             // Copy variables from args
265             Builder.CreateStore(Builder.CreatePointerCast(varBlockCharPtrPtrArg, doublePtrPtrTy, "varBlockAsDoublePtrPtr"), varBlockDoublePtrPtrVar);
266             Builder.CreateStore(Builder.CreatePointerCast(varBlockCharPtrPtrArg, desireFP ? doublePtrPtrTy : i8PtrPtrPtrTy, "varBlockAsTPtrPtr"), varBlockTPtrPtrVar);
267             Builder.CreateStore(rangeStartArg, rangeStartVar);
268             Builder.CreateStore(rangeEndArg, rangeEndVar);
269             Builder.CreateStore(outputVarBlockOffsetArg, outputVarBlockOffsetVar);
270 
271             // Set output pointer
272             Value *outputBasePtrPtr = Builder.CreateGEP(nullptr, Builder.CreateLoad(varBlockTPtrPtrVar), outputVarBlockOffsetArg, "outputBasePtrPtr");
273             Value *outputBasePtr = Builder.CreateLoad(outputBasePtrPtr, "outputBasePtr");
274             Builder.CreateStore(Builder.CreateLoad(rangeStartVar), indexVar);
275 
276             Builder.CreateBr(loopCmpBlock);
277             Builder.SetInsertPoint(loopCmpBlock);
278             Value *cond = Builder.CreateICmpULT(Builder.CreateLoad(indexVar), Builder.CreateLoad(rangeEndVar));
279             Builder.CreateCondBr(cond, loopRepeatBlock, loopEndBlock);
280 
281             Builder.SetInsertPoint(loopRepeatBlock);
282             Value *myOutputPtr = Builder.CreateGEP(nullptr, outputBasePtr, Builder.CreateMul(dimValue, Builder.CreateLoad(indexVar)));
283             Builder.CreateCall(F, {myOutputPtr, Builder.CreateLoad(varBlockDoublePtrPtrVar), Builder.CreateLoad(indexVar)});
284 
285             Builder.CreateBr(loopIncBlock);
286 
287             Builder.SetInsertPoint(loopIncBlock);
288             Builder.CreateStore(Builder.CreateAdd(Builder.CreateLoad(indexVar), oneValue), indexVar);
289             Builder.CreateBr(loopCmpBlock);
290 
291             Builder.SetInsertPoint(loopEndBlock);
292             Builder.CreateRetVoid();
293         }
294 
295         if (Expression::debugging) {
296             #ifdef DEBUG
297             std::cerr << "Pre verified LLVM byte code " << std::endl;
298             TheModule->print(llvm::errs(), nullptr);
299             #endif
300         }
301 
302         // TODO: Find out if there is a new way to veirfy
303         // if (verifyModule(*TheModule)) {
304         //     std::cerr << "Logic error in code generation of LLVM alert developers" << std::endl;
305         //     TheModule->print(llvm::errs(), nullptr);
306         // }
307         Module *altModule = TheModule.get();
308         std::string ErrStr;
309         TheExecutionEngine.reset(EngineBuilder(std::move(TheModule))
310                                      .setErrorStr(&ErrStr)
311                                  //     .setUseMCJIT(true)
312                                      .setOptLevel(CodeGenOpt::Aggressive)
313                                      .create());
314 
315         altModule->setDataLayout(TheExecutionEngine->getDataLayout());
316 
317         // Add bindings to C linkage helper functions
318         TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalFPVarRefFunc, (void *)SeExpr2LLVMEvalFPVarRef);
319         TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalStrVarRefFunc, (void *)SeExpr2LLVMEvalStrVarRef);
320         TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalCustomFunctionFunc, (void *)SeExpr2LLVMEvalCustomFunction);
321         TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalstrlenFunc, (void *)strlen);
322         TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalstrcatFunc, (void *)strcat);
323         TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalmemsetFunc, (void *)memset);
324         TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalmallocFunc, (void *)malloc);
325         TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalfreeFunc, (void *)free);
326 
327         // [verify]
328         std::string errorStr;
329         llvm::raw_string_ostream raw(errorStr);
330         if (llvm::verifyModule(*altModule, &raw)) {
331             parseTree->addError(ErrorCode::Unknown, { raw.str() });
332             return false;
333         }
334 
335         // Setup optimization
336         llvm::PassManagerBuilder builder;
337         std::unique_ptr<llvm::legacy::PassManager> pm(new llvm::legacy::PassManager);
338         std::unique_ptr<llvm::legacy::FunctionPassManager> fpm(new llvm::legacy::FunctionPassManager(altModule));
339         builder.OptLevel = 3;
340 #if (LLVM_VERSION_MAJOR >= 4)
341         builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
342 #else
343         builder.Inliner = llvm::createAlwaysInlinerPass();
344 #endif
345         builder.populateModulePassManager(*pm);
346         // fpm->add(new llvm::DataLayoutPass());
347         builder.populateFunctionPassManager(*fpm);
348         fpm->run(*F);
349         fpm->run(*FLOOP);
350         pm->run(*altModule);
351 
352         // Create the JIT.  This takes ownership of the module.
353 
354         if (!TheExecutionEngine) {
355             fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
356             exit(1);
357         }
358 
359         TheExecutionEngine->finalizeObject();
360         void *fp = TheExecutionEngine->getPointerToFunction(F);
361         void *fpLoop = TheExecutionEngine->getPointerToFunction(FLOOP);
362         if (desireFP) {
363             _llvmEvalFP.reset(new LLVMEvaluationContext<double>);
364             _llvmEvalFP->init(fp, fpLoop, dimDesired);
365         } else {
366             _llvmEvalStr.reset(new LLVMEvaluationContext<char *>);
367             _llvmEvalStr->init(fp, fpLoop, dimDesired);
368         }
369 
370         if (Expression::debugging) {
371             #ifdef DEBUG
372             std::cerr << "Pre verified LLVM byte code " << std::endl;
373             altModule->print(llvm::errs(), nullptr);
374             #endif
375         }
376 
377         return true;
378     }
379 
getUniqueName()380     std::string getUniqueName() const {
381         std::ostringstream o;
382         o << std::setbase(16) << (uint64_t)(this);
383         return ("_" + o.str());
384     }
385 };
386 
387 #else  // no LLVM support
388 class LLVMEvaluator {
389   public:
390     void unsupported() { throw std::runtime_error("LLVM is not enabled in build"); }
391     const char *evalStr(VarBlock *varBlock) {
392         unsupported();
393         return "";
394     }
395     const double *evalFP(VarBlock *varBlock) {
396         unsupported();
397         return 0;
398     }
399     bool prepLLVM(ExprNode *parseTree, ExprType desiredReturnType) {
400         unsupported();
401         return false;
402     }
403     void evalMultiple(VarBlock *varBlock, int outputVarBlockOffset, size_t rangeStart, size_t rangeEnd) {
404         unsupported();
405     }
406     void debugPrint() {}
407 };
408 #endif
409 
410 }  // end namespace SeExpr2
411