//
// Created by matthew on 1/26/19.
//

#ifndef PROJECTM_JITCONTEXT_H
#define PROJECTM_JITCONTEXT_H

#if HAVE_LLVM
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/GenericValue.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"


llvm::LLVMContext& getGlobalContext();

// Per-parameter cache entry stored in JitContext::symbols.
struct Symbol
{
    // IR value currently representing the parameter; reused by
    // JitContext::getSymbolValue(). nullptr until first read or assigned.
    llvm::Value *value = nullptr;
    // Value last written via JitContext::assignSymbolValue(); nullptr if the
    // parameter was never assigned. NOTE(review): not read in this header —
    // presumably consumed by callers (e.g. to emit the final store); confirm.
    llvm::Value *assigned_value = nullptr;
};

// Wrapper for one module which corresponds to one jit'd Expr.
// TODO consider associating one JitContext with one Preset
struct JitContext
{
    // NOTE: the module is either "owned" by module_ptr OR executionEngine
    llvm::LLVMContext &context;
    std::unique_ptr<llvm::Module> module_ptr;
    std::unique_ptr<llvm::legacy::FunctionPassManager> fpm;
    llvm::Module *module = nullptr;
    llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> builder;
    llvm::Type *floatType = nullptr;
    llvm::Value *mesh_i = nullptr;
    llvm::Value *mesh_j = nullptr;
    // Cache of IR values per parameter. The Symbol objects are heap-allocated
    // here and released in the destructor.
    std::map<Param *, Symbol *> symbols;

    // Creates a fresh module named `name` in the global LLVM context, enables
    // all fast-math flags on the builder and sets up a per-function
    // optimization pipeline (instcombine, reassociate, GVN, CFG simplify).
    JitContext(std::string name = "LLVMModule") :
        context(getGlobalContext()), builder(getGlobalContext())
    {
        floatType = llvm::Type::getFloatTy(context);
        module_ptr = llvm::make_unique<llvm::Module>(name, context);
        module = module_ptr.get();

        llvm::FastMathFlags fmf;
        fmf.set();
        builder.setFastMathFlags(fmf);

        // module->setDataLayout(getTargetMachine().createDataLayout());

        // Create a new pass manager attached to it.
        fpm = llvm::make_unique<llvm::legacy::FunctionPassManager>(module);
        fpm->add(llvm::createInstructionCombiningPass());
        fpm->add(llvm::createReassociatePass());
        fpm->add(llvm::createGVNPass());
        fpm->add(llvm::createCFGSimplificationPass());
        fpm->doInitialization();
    }

    // Non-copyable: this object owns the module/pass manager through
    // unique_ptrs and releases the Symbol objects in `symbols` on destruction.
    JitContext(const JitContext &) = delete;
    JitContext &operator=(const JitContext &) = delete;

    ~JitContext()
    {
        traverse<TraverseFunctors::Delete<Symbol> >(symbols);
    }

    // helpers

    // Runs the function pass pipeline over every function in the module.
    void OptimizePass()
    {
        for (llvm::Function &function : *module)
            fpm->run(function);
    }

    // Returns a float constant with value x.
    llvm::Value *CreateConstant(float x)
    {
        return llvm::ConstantFP::get(floatType, x);
    }

    // Returns an i32 constant with value i32 (negative values keep their
    // two's-complement bit pattern).
    llvm::Value *CreateConstant(int i32)
    {
        return llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
                                      static_cast<uint64_t>(static_cast<int64_t>(i32)));
    }

    // Embeds the host address `p` as a constant float pointer (inttoptr of an
    // i64 constant), so generated code reads/writes that memory directly.
    // NOTE(review): the pointer is created in address space 1, not the default
    // address space 0 — confirm this is intentional.
    llvm::Constant *CreateFloatPtr(float *p)
    {
        const uint64_t address = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(p));
        llvm::ConstantInt *pv = llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), address);
        return llvm::ConstantExpr::getIntToPtr(pv, llvm::PointerType::get(floatType, 1));
    }

    // Emits a call to a float-typed LLVM intrinsic whose operands and result are
    // all `float` (e.g. llvm.sin, llvm.pow, llvm.fma). Such intrinsics have
    // exactly ONE overloaded type, so exactly one type is passed regardless of
    // the argument count: Intrinsic::getDeclaration appends one mangling suffix
    // per listed type, and listing one per argument yields invalid names such as
    // "llvm.pow.f32.f32" that the IR verifier rejects.
    llvm::Value *CallFloatIntrinsic(llvm::Intrinsic::ID id, const std::vector<llvm::Value *> &args)
    {
        std::vector<llvm::Type *> overloaded_types(1, floatType);
        llvm::Function *function = llvm::Intrinsic::getDeclaration(module, id, overloaded_types);
        return builder.CreateCall(function, args);
    }

    // Calls the unary float intrinsic `id` on `value`.
    llvm::Value *CallIntrinsic(llvm::Intrinsic::ID id, llvm::Value *value)
    {
        return CallFloatIntrinsic(id, std::vector<llvm::Value *>{value});
    }

    // Calls the binary float intrinsic `id` (e.g. llvm.pow) on x, y.
    llvm::Value *CallIntrinsic2(llvm::Intrinsic::ID id, llvm::Value *x, llvm::Value *y)
    {
        return CallFloatIntrinsic(id, std::vector<llvm::Value *>{x, y});
    }

    // Calls the ternary float intrinsic `id` (e.g. llvm.fma) on x, y, z.
    llvm::Value *CallIntrinsic3(llvm::Intrinsic::ID id, llvm::Value *x, llvm::Value *y, llvm::Value *z)
    {
        return CallFloatIntrinsic(id, std::vector<llvm::Value *>{x, y, z});
    }

    // Stacks of in-progress ternaries (one entry per nesting level), pushed by
    // StartTernary and popped by FinishTernary.
    std::vector<llvm::Function *> parent;
    std::vector<llvm::BasicBlock *> then_block;
    std::vector<llvm::BasicBlock *> else_block;
    std::vector<llvm::BasicBlock *> merge_block;

    // Opens a ternary: branches on `condition` to a new "then" block (added to
    // the current function now) or an "else" block (added in withElse()).
    // Expected call order: StartTernary, withThen, <then code>, withElse,
    // <else code>, FinishTernary.
    void StartTernary(llvm::Value *condition)
    {
        parent.push_back(builder.GetInsertBlock()->getParent());
        then_block.push_back(llvm::BasicBlock::Create(context, "then", parent.back()));
        else_block.push_back(llvm::BasicBlock::Create(context, "else"));
        merge_block.push_back(llvm::BasicBlock::Create(context, "fi"));
        builder.CreateCondBr(condition, then_block.back(), else_block.back());
    }

    // Starts emitting the THEN arm of the innermost ternary.
    void withThen()
    {
        assert(!then_block.empty() && "withThen() without StartTernary()");
        builder.SetInsertPoint(then_block.back());
    }

    // Closes the THEN arm and starts emitting the ELSE arm.
    void withElse()
    {
        assert(!merge_block.empty() && "withElse() without StartTernary()");
        // The THEN arm may have created more blocks (e.g. a nested ternary), so
        // the PHI's incoming edge must come from the block the builder is in
        // now — that is where the branch to the merge block is emitted.
        then_block.back() = builder.GetInsertBlock();
        builder.CreateBr(merge_block.back());
        parent.back()->getBasicBlockList().push_back(else_block.back());
        builder.SetInsertPoint(else_block.back());
    }

    // Closes the ELSE arm, emits the merge block with a float PHI selecting
    // thenValue/elseValue, pops the ternary state and returns the PHI.
    llvm::Value *FinishTernary(llvm::Value *thenValue, llvm::Value *elseValue)
    {
        assert(!merge_block.empty() && "FinishTernary() without StartTernary()");
        // Same reasoning as withElse(): the ELSE arm ends in the current block.
        else_block.back() = builder.GetInsertBlock();
        builder.CreateBr(merge_block.back());
        parent.back()->getBasicBlockList().push_back(merge_block.back());
        builder.SetInsertPoint(merge_block.back());
        llvm::PHINode *mergeValue = builder.CreatePHI(floatType, 2, "iftmp");
        mergeValue->addIncoming(thenValue, then_block.back());
        mergeValue->addIncoming(elseValue, else_block.back());
        parent.pop_back();
        then_block.pop_back();
        else_block.pop_back();
        merge_block.pop_back();
        return mergeValue;
    }

    // Returns the Symbol for `p`, creating an empty one on first use.
    Symbol *getOrCreateSymbol(Param *p)
    {
        Symbol *&slot = symbols[p];
        if (slot == nullptr)
            slot = new Symbol();
        return slot;
    }

    // Returns the IR value for parameter `p`. A cached value is reused; otherwise
    // the value is generated via p->_llvm(*this). The generated value is cached
    // only outside conditionals: a value created inside a ternary arm does not
    // dominate code emitted after the merge.
    llvm::Value *getSymbolValue(Param *p)
    {
        auto it = symbols.find(p);
        if (it != symbols.end() && it->second && it->second->value)
            return it->second->value;
        llvm::Value *v = p->_llvm(*this);
        // don't remember this value if we are in a conditional
        if (!merge_block.empty())
            return v;
        // Look the symbol up again: code generation above may have inserted it.
        getOrCreateSymbol(p)->value = v;
        return v;
    }

    // TODO: we optimize READ of parameters, but not WRITE
    // It would help to delay writing to parameters until the last write (or end of function)
    // NOTE: be careful of the fact that milkdrop allows assignments inside conditionals
    // I don't think the projectM parser handles this yet (?) but just be aware
    // (unlike getSymbolValue, this caches `v` even inside a ternary arm).

    // Records `v` as both the current and the assigned value of `p`, and names
    // the IR value after the parameter (a no-op for constants).
    void assignSymbolValue(Param *p, llvm::Value *v)
    {
        v->setName(p->name);
        Symbol *sym = getOrCreateSymbol(p);
        sym->value = v;
        sym->assigned_value = v;
    }
};
#endif

#endif //PROJECTM_JITCONTEXT_H