1 //===- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Contains a simple JIT definition for use in the kaleidoscope tutorials. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H 14 #define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H 15 16 #include "RemoteJITUtils.h" 17 #include "llvm/ADT/STLExtras.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/ADT/Triple.h" 20 #include "llvm/ExecutionEngine/ExecutionEngine.h" 21 #include "llvm/ExecutionEngine/JITSymbol.h" 22 #include "llvm/ExecutionEngine/Orc/CompileUtils.h" 23 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" 24 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" 25 #include "llvm/ExecutionEngine/Orc/IndirectionUtils.h" 26 #include "llvm/ExecutionEngine/Orc/LambdaResolver.h" 27 #include "llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h" 28 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" 29 #include "llvm/IR/DataLayout.h" 30 #include "llvm/IR/LegacyPassManager.h" 31 #include "llvm/IR/Mangler.h" 32 #include "llvm/Support/DynamicLibrary.h" 33 #include "llvm/Support/Error.h" 34 #include "llvm/Support/raw_ostream.h" 35 #include "llvm/Target/TargetMachine.h" 36 #include "llvm/Transforms/InstCombine/InstCombine.h" 37 #include "llvm/Transforms/Scalar.h" 38 #include "llvm/Transforms/Scalar/GVN.h" 39 #include <algorithm> 40 #include <cassert> 41 #include <cstdlib> 42 #include <map> 43 #include <memory> 44 #include <string> 45 #include <vector> 46 47 class PrototypeAST; 48 class ExprAST; 49 50 /// FunctionAST - This class represents a function definition itself. 51 class FunctionAST { 52 std::unique_ptr<PrototypeAST> Proto; 53 std::unique_ptr<ExprAST> Body; 54 55 public: FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)56 FunctionAST(std::unique_ptr<PrototypeAST> Proto, 57 std::unique_ptr<ExprAST> Body) 58 : Proto(std::move(Proto)), Body(std::move(Body)) {} 59 60 const PrototypeAST& getProto() const; 61 const std::string& getName() const; 62 llvm::Function *codegen(); 63 }; 64 65 /// This will compile FnAST to IR, rename the function to add the given 66 /// suffix (needed to prevent a name-clash with the function's stub), 67 /// and then take ownership of the module that the function was compiled 68 /// into. 69 std::unique_ptr<llvm::Module> 70 irgenAndTakeOwnership(FunctionAST &FnAST, const std::string &Suffix); 71 72 namespace llvm { 73 namespace orc { 74 75 // Typedef the remote-client API. 76 using MyRemote = remote::OrcRemoteTargetClient; 77 78 class KaleidoscopeJIT { 79 private: 80 ExecutionSession &ES; 81 std::shared_ptr<SymbolResolver> Resolver; 82 std::unique_ptr<TargetMachine> TM; 83 const DataLayout DL; 84 LegacyRTDyldObjectLinkingLayer ObjectLayer; 85 LegacyIRCompileLayer<decltype(ObjectLayer), SimpleCompiler> CompileLayer; 86 87 using OptimizeFunction = 88 std::function<std::unique_ptr<Module>(std::unique_ptr<Module>)>; 89 90 LegacyIRTransformLayer<decltype(CompileLayer), OptimizeFunction> OptimizeLayer; 91 92 JITCompileCallbackManager *CompileCallbackMgr; 93 std::unique_ptr<IndirectStubsManager> IndirectStubsMgr; 94 MyRemote &Remote; 95 96 public: KaleidoscopeJIT(ExecutionSession & ES,MyRemote & Remote)97 KaleidoscopeJIT(ExecutionSession &ES, MyRemote &Remote) 98 : ES(ES), 99 Resolver(createLegacyLookupResolver( 100 ES, 101 [this](const std::string &Name) -> JITSymbol { 102 if (auto Sym = IndirectStubsMgr->findStub(Name, false)) 103 return Sym; 104 if (auto Sym = OptimizeLayer.findSymbol(Name, false)) 105 return Sym; 106 else if (auto Err = Sym.takeError()) 107 return std::move(Err); 108 if (auto Addr = cantFail(this->Remote.getSymbolAddress(Name))) 109 return JITSymbol(Addr, JITSymbolFlags::Exported); 110 return nullptr; 111 }, 112 [](Error Err) { cantFail(std::move(Err), "lookupFlags failed"); })), 113 TM(EngineBuilder().selectTarget(Triple(Remote.getTargetTriple()), "", 114 "", SmallVector<std::string, 0>())), 115 DL(TM->createDataLayout()), 116 ObjectLayer(AcknowledgeORCv1Deprecation, ES, 117 [this](VModuleKey K) { 118 return LegacyRTDyldObjectLinkingLayer::Resources{ 119 cantFail(this->Remote.createRemoteMemoryManager()), 120 Resolver}; 121 }), 122 CompileLayer(AcknowledgeORCv1Deprecation, ObjectLayer, 123 SimpleCompiler(*TM)), 124 OptimizeLayer(AcknowledgeORCv1Deprecation, CompileLayer, 125 [this](std::unique_ptr<Module> M) { 126 return optimizeModule(std::move(M)); 127 }), 128 Remote(Remote) { 129 auto CCMgrOrErr = Remote.enableCompileCallbacks(0); 130 if (!CCMgrOrErr) { 131 logAllUnhandledErrors(CCMgrOrErr.takeError(), errs(), 132 "Error enabling remote compile callbacks:"); 133 exit(1); 134 } 135 CompileCallbackMgr = &*CCMgrOrErr; 136 IndirectStubsMgr = cantFail(Remote.createIndirectStubsManager()); 137 llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr); 138 } 139 getTargetMachine()140 TargetMachine &getTargetMachine() { return *TM; } 141 addModule(std::unique_ptr<Module> M)142 VModuleKey addModule(std::unique_ptr<Module> M) { 143 // Add the module with a new VModuleKey. 144 auto K = ES.allocateVModule(); 145 cantFail(OptimizeLayer.addModule(K, std::move(M))); 146 return K; 147 } 148 addFunctionAST(std::unique_ptr<FunctionAST> FnAST)149 Error addFunctionAST(std::unique_ptr<FunctionAST> FnAST) { 150 // Move ownership of FnAST to a shared pointer - C++11 lambdas don't support 151 // capture-by-move, which is be required for unique_ptr. 152 auto SharedFnAST = std::shared_ptr<FunctionAST>(std::move(FnAST)); 153 154 // Set the action to compile our AST. This lambda will be run if/when 155 // execution hits the compile callback (via the stub). 156 // 157 // The steps to compile are: 158 // (1) IRGen the function. 159 // (2) Add the IR module to the JIT to make it executable like any other 160 // module. 161 // (3) Use findSymbol to get the address of the compiled function. 162 // (4) Update the stub pointer to point at the implementation so that 163 /// subsequent calls go directly to it and bypass the compiler. 164 // (5) Return the address of the implementation: this lambda will actually 165 // be run inside an attempted call to the function, and we need to 166 // continue on to the implementation to complete the attempted call. 167 // The JIT runtime (the resolver block) will use the return address of 168 // this function as the address to continue at once it has reset the 169 // CPU state to what it was immediately before the call. 170 auto CompileAction = [this, SharedFnAST]() { 171 auto M = irgenAndTakeOwnership(*SharedFnAST, "$impl"); 172 addModule(std::move(M)); 173 auto Sym = findSymbol(SharedFnAST->getName() + "$impl"); 174 assert(Sym && "Couldn't find compiled function?"); 175 JITTargetAddress SymAddr = cantFail(Sym.getAddress()); 176 if (auto Err = IndirectStubsMgr->updatePointer( 177 mangle(SharedFnAST->getName()), SymAddr)) { 178 logAllUnhandledErrors(std::move(Err), errs(), 179 "Error updating function pointer: "); 180 exit(1); 181 } 182 183 return SymAddr; 184 }; 185 186 // Create a CompileCallback suing the CompileAction - this is the re-entry 187 // point into the compiler for functions that haven't been compiled yet. 188 auto CCAddr = cantFail( 189 CompileCallbackMgr->getCompileCallback(std::move(CompileAction))); 190 191 // Create an indirect stub. This serves as the functions "canonical 192 // definition" - an unchanging (constant address) entry point to the 193 // function implementation. 194 // Initially we point the stub's function-pointer at the compile callback 195 // that we just created. In the compile action for the callback we will 196 // update the stub's function pointer to point at the function 197 // implementation that we just implemented. 198 if (auto Err = IndirectStubsMgr->createStub( 199 mangle(SharedFnAST->getName()), CCAddr, JITSymbolFlags::Exported)) 200 return Err; 201 202 return Error::success(); 203 } 204 executeRemoteExpr(JITTargetAddress ExprAddr)205 Error executeRemoteExpr(JITTargetAddress ExprAddr) { 206 return Remote.callVoidVoid(ExprAddr); 207 } 208 findSymbol(const std::string Name)209 JITSymbol findSymbol(const std::string Name) { 210 return OptimizeLayer.findSymbol(mangle(Name), true); 211 } 212 removeModule(VModuleKey K)213 void removeModule(VModuleKey K) { 214 cantFail(OptimizeLayer.removeModule(K)); 215 } 216 217 private: mangle(const std::string & Name)218 std::string mangle(const std::string &Name) { 219 std::string MangledName; 220 raw_string_ostream MangledNameStream(MangledName); 221 Mangler::getNameWithPrefix(MangledNameStream, Name, DL); 222 return MangledNameStream.str(); 223 } 224 optimizeModule(std::unique_ptr<Module> M)225 std::unique_ptr<Module> optimizeModule(std::unique_ptr<Module> M) { 226 // Create a function pass manager. 227 auto FPM = std::make_unique<legacy::FunctionPassManager>(M.get()); 228 229 // Add some optimizations. 230 FPM->add(createInstructionCombiningPass()); 231 FPM->add(createReassociatePass()); 232 FPM->add(createGVNPass()); 233 FPM->add(createCFGSimplificationPass()); 234 FPM->doInitialization(); 235 236 // Run the optimizations over all functions in the module being added to 237 // the JIT. 238 for (auto &F : *M) 239 FPM->run(F); 240 241 return M; 242 } 243 }; 244 245 } // end namespace orc 246 } // end namespace llvm 247 248 #endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H 249