1 //===- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains a simple JIT definition for use in the kaleidoscope tutorials.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
14 #define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
15 
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ExecutionEngine/ExecutionEngine.h"
18 #include "llvm/ExecutionEngine/JITSymbol.h"
19 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
20 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
21 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
22 #include "llvm/ExecutionEngine/Orc/IndirectionUtils.h"
23 #include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
24 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
25 #include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
26 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
27 #include "llvm/IR/DataLayout.h"
28 #include "llvm/IR/LegacyPassManager.h"
29 #include "llvm/IR/Mangler.h"
30 #include "llvm/Support/DynamicLibrary.h"
31 #include "llvm/Support/Error.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "llvm/Target/TargetMachine.h"
34 #include "llvm/Transforms/InstCombine/InstCombine.h"
35 #include "llvm/Transforms/Scalar.h"
36 #include "llvm/Transforms/Scalar/GVN.h"
37 #include <algorithm>
38 #include <cassert>
39 #include <cstdlib>
40 #include <map>
41 #include <memory>
42 #include <string>
43 #include <vector>
44 
45 class PrototypeAST;
46 class ExprAST;
47 
48 /// FunctionAST - This class represents a function definition itself.
49 class FunctionAST {
50   std::unique_ptr<PrototypeAST> Proto;
51   std::unique_ptr<ExprAST> Body;
52 
53 public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)54   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
55               std::unique_ptr<ExprAST> Body)
56       : Proto(std::move(Proto)), Body(std::move(Body)) {}
57 
58   const PrototypeAST& getProto() const;
59   const std::string& getName() const;
60   llvm::Function *codegen();
61 };
62 
63 /// This will compile FnAST to IR, rename the function to add the given
64 /// suffix (needed to prevent a name-clash with the function's stub),
65 /// and then take ownership of the module that the function was compiled
66 /// into.
67 std::unique_ptr<llvm::Module>
68 irgenAndTakeOwnership(FunctionAST &FnAST, const std::string &Suffix);
69 
70 namespace llvm {
71 namespace orc {
72 
73 class KaleidoscopeJIT {
74 private:
75   ExecutionSession ES;
76   std::shared_ptr<SymbolResolver> Resolver;
77   std::unique_ptr<TargetMachine> TM;
78   const DataLayout DL;
79   LegacyRTDyldObjectLinkingLayer ObjectLayer;
80   LegacyIRCompileLayer<decltype(ObjectLayer), SimpleCompiler> CompileLayer;
81 
82   using OptimizeFunction =
83       std::function<std::unique_ptr<Module>(std::unique_ptr<Module>)>;
84 
85   LegacyIRTransformLayer<decltype(CompileLayer), OptimizeFunction> OptimizeLayer;
86 
87   std::unique_ptr<JITCompileCallbackManager> CompileCallbackMgr;
88   std::unique_ptr<IndirectStubsManager> IndirectStubsMgr;
89 
90 public:
KaleidoscopeJIT()91   KaleidoscopeJIT()
92       : Resolver(createLegacyLookupResolver(
93             ES,
94             [this](StringRef Name) -> JITSymbol {
95               if (auto Sym = IndirectStubsMgr->findStub(Name, false))
96                 return Sym;
97               if (auto Sym = OptimizeLayer.findSymbol(std::string(Name), false))
98                 return Sym;
99               else if (auto Err = Sym.takeError())
100                 return std::move(Err);
101               if (auto SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(
102                       std::string(Name)))
103                 return JITSymbol(SymAddr, JITSymbolFlags::Exported);
104               return nullptr;
105             },
106             [](Error Err) { cantFail(std::move(Err), "lookupFlags failed"); })),
107         TM(EngineBuilder().selectTarget()), DL(TM->createDataLayout()),
108         ObjectLayer(AcknowledgeORCv1Deprecation, ES,
109                     [this](VModuleKey K) {
110                       return LegacyRTDyldObjectLinkingLayer::Resources{
111                           std::make_shared<SectionMemoryManager>(), Resolver};
112                     }),
113         CompileLayer(AcknowledgeORCv1Deprecation, ObjectLayer,
114                      SimpleCompiler(*TM)),
115         OptimizeLayer(AcknowledgeORCv1Deprecation, CompileLayer,
116                       [this](std::unique_ptr<Module> M) {
117                         return optimizeModule(std::move(M));
118                       }),
119         CompileCallbackMgr(cantFail(orc::createLocalCompileCallbackManager(
120             TM->getTargetTriple(), ES, 0))) {
121     auto IndirectStubsMgrBuilder =
122       orc::createLocalIndirectStubsManagerBuilder(TM->getTargetTriple());
123     IndirectStubsMgr = IndirectStubsMgrBuilder();
124     llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
125   }
126 
getTargetMachine()127   TargetMachine &getTargetMachine() { return *TM; }
128 
addModule(std::unique_ptr<Module> M)129   VModuleKey addModule(std::unique_ptr<Module> M) {
130     // Add the module to the JIT with a new VModuleKey.
131     auto K = ES.allocateVModule();
132     cantFail(OptimizeLayer.addModule(K, std::move(M)));
133     return K;
134   }
135 
addFunctionAST(std::unique_ptr<FunctionAST> FnAST)136   Error addFunctionAST(std::unique_ptr<FunctionAST> FnAST) {
137     // Move ownership of FnAST to a shared pointer - C++11 lambdas don't support
138     // capture-by-move, which is be required for unique_ptr.
139     auto SharedFnAST = std::shared_ptr<FunctionAST>(std::move(FnAST));
140 
141     // Set the action to compile our AST. This lambda will be run if/when
142     // execution hits the compile callback (via the stub).
143     //
144     // The steps to compile are:
145     // (1) IRGen the function.
146     // (2) Add the IR module to the JIT to make it executable like any other
147     //     module.
148     // (3) Use findSymbol to get the address of the compiled function.
149     // (4) Update the stub pointer to point at the implementation so that
150     ///    subsequent calls go directly to it and bypass the compiler.
151     // (5) Return the address of the implementation: this lambda will actually
152     //     be run inside an attempted call to the function, and we need to
153     //     continue on to the implementation to complete the attempted call.
154     //     The JIT runtime (the resolver block) will use the return address of
155     //     this function as the address to continue at once it has reset the
156     //     CPU state to what it was immediately before the call.
157     auto CompileAction = [this, SharedFnAST]() {
158       auto M = irgenAndTakeOwnership(*SharedFnAST, "$impl");
159       addModule(std::move(M));
160       auto Sym = findSymbol(SharedFnAST->getName() + "$impl");
161       assert(Sym && "Couldn't find compiled function?");
162       JITTargetAddress SymAddr = cantFail(Sym.getAddress());
163       if (auto Err = IndirectStubsMgr->updatePointer(
164               mangle(SharedFnAST->getName()), SymAddr)) {
165         logAllUnhandledErrors(std::move(Err), errs(),
166                               "Error updating function pointer: ");
167         exit(1);
168       }
169 
170       return SymAddr;
171     };
172 
173     // Create a CompileCallback using the CompileAction - this is the re-entry
174     // point into the compiler for functions that haven't been compiled yet.
175     auto CCAddr = cantFail(
176         CompileCallbackMgr->getCompileCallback(std::move(CompileAction)));
177 
178     // Create an indirect stub. This serves as the functions "canonical
179     // definition" - an unchanging (constant address) entry point to the
180     // function implementation.
181     // Initially we point the stub's function-pointer at the compile callback
182     // that we just created. When the compile action for the callback is run we
183     // will update the stub's function pointer to point at the function
184     // implementation that we just implemented.
185     if (auto Err = IndirectStubsMgr->createStub(
186             mangle(SharedFnAST->getName()), CCAddr, JITSymbolFlags::Exported))
187       return Err;
188 
189     return Error::success();
190   }
191 
findSymbol(const std::string Name)192   JITSymbol findSymbol(const std::string Name) {
193     return OptimizeLayer.findSymbol(mangle(Name), true);
194   }
195 
removeModule(VModuleKey K)196   void removeModule(VModuleKey K) {
197     cantFail(OptimizeLayer.removeModule(K));
198   }
199 
200 private:
mangle(const std::string & Name)201   std::string mangle(const std::string &Name) {
202     std::string MangledName;
203     raw_string_ostream MangledNameStream(MangledName);
204     Mangler::getNameWithPrefix(MangledNameStream, Name, DL);
205     return MangledNameStream.str();
206   }
207 
optimizeModule(std::unique_ptr<Module> M)208   std::unique_ptr<Module> optimizeModule(std::unique_ptr<Module> M) {
209     // Create a function pass manager.
210     auto FPM = std::make_unique<legacy::FunctionPassManager>(M.get());
211 
212     // Add some optimizations.
213     FPM->add(createInstructionCombiningPass());
214     FPM->add(createReassociatePass());
215     FPM->add(createGVNPass());
216     FPM->add(createCFGSimplificationPass());
217     FPM->doInitialization();
218 
219     // Run the optimizations over all functions in the module being added to
220     // the JIT.
221     for (auto &F : *M)
222       FPM->run(F);
223 
224     return M;
225   }
226 };
227 
228 } // end namespace orc
229 } // end namespace llvm
230 
231 #endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
232