1 //
2 // Created by matthew on 1/26/19.
3 //
4 
5 #ifndef PROJECTM_JITCONTEXT_H
6 #define PROJECTM_JITCONTEXT_H
7 
8 #if HAVE_LLVM
9 #include "llvm/ADT/APFloat.h"
10 #include "llvm/ADT/STLExtras.h"
11 #include "llvm/ExecutionEngine/ExecutionEngine.h"
12 #include "llvm/ExecutionEngine/GenericValue.h"
13 #include "llvm/IR/BasicBlock.h"
14 #include "llvm/IR/Constants.h"
15 #include "llvm/IR/DerivedTypes.h"
16 #include "llvm/IR/Function.h"
17 #include "llvm/IR/IRBuilder.h"
18 #include "llvm/IR/LegacyPassManager.h"
19 #include "llvm/IR/LLVMContext.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/IR/Type.h"
22 #include "llvm/IR/Verifier.h"
23 #include "llvm/Support/TargetSelect.h"
24 #include "llvm/Target/TargetMachine.h"
25 #include "llvm/Transforms/InstCombine/InstCombine.h"
26 #include "llvm/Transforms/Scalar.h"
27 #include "llvm/Transforms/Scalar/GVN.h"
28 
29 
// Accessor for the llvm::LLVMContext used by every JitContext: the JitContext
// constructor binds both its `context` reference and its IRBuilder to the
// object returned here. Only declared in this header; the definition lives
// elsewhere.
llvm::LLVMContext& getGlobalContext();
31 
// NOTE: the next two remarks describe JitContext (declared after Symbol):
// JitContext is a wrapper for one module, which corresponds to one jit'd Expr.
// TODO: consider associating one JitContext with one Preset
34 struct Symbol
35 {
36     llvm::Value *value = nullptr;
37     llvm::Value *assigned_value = nullptr;
38 };
39 
40 struct JitContext
41 {
42     // NOTE: the module is either "owned" by module_ptr OR executionEngine
43     llvm::LLVMContext &context;
44     std::unique_ptr<llvm::Module> module_ptr;
45     std::unique_ptr<llvm::legacy::FunctionPassManager> fpm;
46     llvm::Module *module;
47     llvm::IRBuilder<llvm::ConstantFolder,llvm::IRBuilderDefaultInserter> builder;
48     llvm::Type *floatType;
49     llvm::Value *mesh_i;
50     llvm::Value *mesh_j;
51     std::map<Param *,Symbol *> symbols;
52 
53 
JitContextJitContext54     JitContext(std::string name="LLVMModule") :
55             context(getGlobalContext()), builder(getGlobalContext())
56     {
57         floatType = llvm::Type::getFloatTy(context);
58         module_ptr = llvm::make_unique<llvm::Module>(name, context);
59         module = module_ptr.get();
60 
61         llvm::FastMathFlags fmf;
62         fmf.set();
63         builder.setFastMathFlags(fmf);
64 
65 //        module->setDataLayout(getTargetMachine().createDataLayout());
66 
67         // Create a new pass manager attached to it.
68         fpm = llvm::make_unique<llvm::legacy::FunctionPassManager>(module);
69         fpm->add(llvm::createInstructionCombiningPass());
70         fpm->add(llvm::createReassociatePass());
71         fpm->add(llvm::createGVNPass());
72         fpm->add(llvm::createCFGSimplificationPass());
73         fpm->doInitialization();
74     }
75 
~JitContextJitContext76     ~JitContext()
77     {
78         traverse<TraverseFunctors::Delete<Symbol> >(symbols);
79     }
80 
81     // helpers
82 
OptimizePassJitContext83     void OptimizePass()
84     {
85         auto end = module->end();
86         for (auto it = module->begin(); it != end; ++it)
87             fpm->run(*it);
88     }
89 
CreateConstantJitContext90     llvm::Value *CreateConstant(float x)
91     {
92         return llvm::ConstantFP::get(floatType, x);
93     }
CreateConstantJitContext94     llvm::Value *CreateConstant(int i32)
95     {
96         return llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), (uint64_t)(int64_t)i32);
97     }
CreateFloatPtrJitContext98     llvm::Constant *CreateFloatPtr(float *p)
99     {
100         llvm::ConstantInt *pv = llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), (uint64_t)p);
101         return llvm::ConstantExpr::getIntToPtr(pv , llvm::PointerType::get(floatType, 1));
102     }
CallIntrinsicJitContext103     llvm::Value *CallIntrinsic(llvm::Intrinsic::ID id, llvm::Value *value)
104     {
105         std::vector<llvm::Type *> arg_type;
106         arg_type.push_back(floatType);
107         llvm::Function *function = llvm::Intrinsic::getDeclaration(module, id, arg_type);
108         std::vector<llvm::Value *> arg;
109         arg.push_back(value);
110         return builder.CreateCall(function, arg);
111     }
CallIntrinsic2JitContext112     llvm::Value *CallIntrinsic2(llvm::Intrinsic::ID id, llvm::Value *x, llvm::Value *y)
113     {
114         std::vector<llvm::Type *> arg_type;
115         arg_type.push_back(floatType);
116         arg_type.push_back(floatType);
117         llvm::Function *function = llvm::Intrinsic::getDeclaration(module, id, arg_type);
118         std::vector<llvm::Value *> arg;
119         arg.push_back(x);
120         arg.push_back(y);
121         return builder.CreateCall(function, arg);
122     }
CallIntrinsic3JitContext123     llvm::Value *CallIntrinsic3(llvm::Intrinsic::ID id, llvm::Value *x, llvm::Value *y, llvm::Value *z)
124     {
125         std::vector<llvm::Type *> arg_type;
126         arg_type.push_back(floatType);
127         arg_type.push_back(floatType);
128         arg_type.push_back(floatType);
129         llvm::Function *function = llvm::Intrinsic::getDeclaration(module, id, arg_type);
130         std::vector<llvm::Value *> arg;
131         arg.push_back(x);
132         arg.push_back(y);
133         arg.push_back(z);
134         return builder.CreateCall(function, arg);
135     }
136 
137     std::vector<llvm::Function *> parent;
138     std::vector<llvm::BasicBlock *> then_block;
139     std::vector<llvm::BasicBlock *> else_block;
140     std::vector<llvm::BasicBlock *> merge_block;
141 
StartTernaryJitContext142     void StartTernary(llvm::Value *condition)
143     {
144         parent.push_back(builder.GetInsertBlock()->getParent());
145         then_block.push_back(llvm::BasicBlock::Create(context, "then", parent.back()));
146         else_block.push_back(llvm::BasicBlock::Create(context, "else"));
147         merge_block.push_back(llvm::BasicBlock::Create(context, "fi"));
148         builder.CreateCondBr(condition, then_block.back(), else_block.back());
149     }
withThenJitContext150     void withThen()
151     {
152         builder.SetInsertPoint(then_block.back());
153     }
withElseJitContext154     void withElse()
155     {
156         // finish the withThen block, remember the last block of the THEN (may have changed)
157         llvm::BasicBlock *lastBlockOfThen = &parent.back()->getBasicBlockList().back();
158         then_block.back() = lastBlockOfThen;
159         builder.CreateBr(merge_block.back());
160         parent.back()->getBasicBlockList().push_back(else_block.back());
161         builder.SetInsertPoint(else_block.back());
162     }
FinishTernaryJitContext163     llvm::Value *FinishTernary(llvm::Value *thenValue, llvm::Value *elseValue)
164     {
165         // finish the withElse block, remember the last block of the ELSE
166         llvm::BasicBlock *lastBlockOfElse = &parent.back()->getBasicBlockList().back();
167         else_block.back() = lastBlockOfElse;
168         builder.CreateBr(merge_block.back());
169         parent.back()->getBasicBlockList().push_back(merge_block.back());
170         builder.SetInsertPoint(merge_block.back());
171         llvm::PHINode *mergeValue = builder.CreatePHI(floatType, 2, "iftmp");
172         mergeValue->addIncoming(thenValue, then_block.back());
173         mergeValue->addIncoming(elseValue, else_block.back());
174         parent.pop_back();
175         then_block.pop_back();
176         else_block.pop_back();
177         merge_block.pop_back();
178         return mergeValue;
179     }
180 
getSymbolValueJitContext181     llvm::Value *getSymbolValue(Param *p)
182     {
183         auto it = symbols.find(p);
184         Symbol *sym = (it == symbols.end()) ? nullptr : it->second;
185         if (sym && sym->value)
186             return sym->value;
187         llvm::Value *v = p->_llvm(*this);
188         // don't remember this value if we are in a conditional
189         if (!merge_block.empty())
190             return v;
191         if (nullptr == sym)
192         {
193             sym = new Symbol();
194             symbols.insert(std::make_pair(p, sym));
195         }
196         sym->value = v;
197         return v;
198     }
199     // TODO: we optimize READ of parameters, but not WRITE
200     // It would help to delay writing to parameters until the last write (or end of function)
201     // NOTE: be careful of the fact that milkdrop allows assignments inside conditionals
202     // I don't think the projectM parser handles this yet (?) but just be aware
assignSymbolValueJitContext203     void assignSymbolValue(Param *p, llvm::Value *v)
204     {
205         v->setName(p->name);
206         auto it = symbols.find(p);
207         Symbol *sym = (it == symbols.end()) ? nullptr : it->second;
208         if (nullptr == sym)
209         {
210             sym = new Symbol();
211             symbols.insert(std::make_pair(p, sym));
212         }
213         sym->value = v;
214         sym->assigned_value = v;
215     }
216 };
217 #endif
218 
219 #endif //PROJECTM_JITCONTEXT_H
220