1 //===- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains a simple JIT definition for use in the kaleidoscope tutorials.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
14 #define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
15 
16 #include "RemoteJITUtils.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/Triple.h"
20 #include "llvm/ExecutionEngine/ExecutionEngine.h"
21 #include "llvm/ExecutionEngine/JITSymbol.h"
22 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
23 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
24 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
25 #include "llvm/ExecutionEngine/Orc/IndirectionUtils.h"
26 #include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
27 #include "llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h"
28 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
29 #include "llvm/IR/DataLayout.h"
30 #include "llvm/IR/LegacyPassManager.h"
31 #include "llvm/IR/Mangler.h"
32 #include "llvm/Support/DynamicLibrary.h"
33 #include "llvm/Support/Error.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include "llvm/Target/TargetMachine.h"
36 #include "llvm/Transforms/InstCombine/InstCombine.h"
37 #include "llvm/Transforms/Scalar.h"
38 #include "llvm/Transforms/Scalar/GVN.h"
39 #include <algorithm>
40 #include <cassert>
41 #include <cstdlib>
42 #include <map>
43 #include <memory>
44 #include <string>
45 #include <vector>
46 
47 class PrototypeAST;
48 class ExprAST;
49 
50 /// FunctionAST - This class represents a function definition itself.
51 class FunctionAST {
52   std::unique_ptr<PrototypeAST> Proto;
53   std::unique_ptr<ExprAST> Body;
54 
55 public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)56   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
57               std::unique_ptr<ExprAST> Body)
58       : Proto(std::move(Proto)), Body(std::move(Body)) {}
59 
60   const PrototypeAST& getProto() const;
61   const std::string& getName() const;
62   llvm::Function *codegen();
63 };
64 
65 /// This will compile FnAST to IR, rename the function to add the given
66 /// suffix (needed to prevent a name-clash with the function's stub),
67 /// and then take ownership of the module that the function was compiled
68 /// into.
69 std::unique_ptr<llvm::Module>
70 irgenAndTakeOwnership(FunctionAST &FnAST, const std::string &Suffix);
71 
72 namespace llvm {
73 namespace orc {
74 
75 // Typedef the remote-client API.
76 using MyRemote = remote::OrcRemoteTargetClient;
77 
78 class KaleidoscopeJIT {
79 private:
80   ExecutionSession &ES;
81   std::shared_ptr<SymbolResolver> Resolver;
82   std::unique_ptr<TargetMachine> TM;
83   const DataLayout DL;
84   LegacyRTDyldObjectLinkingLayer ObjectLayer;
85   LegacyIRCompileLayer<decltype(ObjectLayer), SimpleCompiler> CompileLayer;
86 
87   using OptimizeFunction =
88       std::function<std::unique_ptr<Module>(std::unique_ptr<Module>)>;
89 
90   LegacyIRTransformLayer<decltype(CompileLayer), OptimizeFunction> OptimizeLayer;
91 
92   JITCompileCallbackManager *CompileCallbackMgr;
93   std::unique_ptr<IndirectStubsManager> IndirectStubsMgr;
94   MyRemote &Remote;
95 
96 public:
KaleidoscopeJIT(ExecutionSession & ES,MyRemote & Remote)97   KaleidoscopeJIT(ExecutionSession &ES, MyRemote &Remote)
98       : ES(ES),
99         Resolver(createLegacyLookupResolver(
100             ES,
101             [this](const std::string &Name) -> JITSymbol {
102               if (auto Sym = IndirectStubsMgr->findStub(Name, false))
103                 return Sym;
104               if (auto Sym = OptimizeLayer.findSymbol(Name, false))
105                 return Sym;
106               else if (auto Err = Sym.takeError())
107                 return std::move(Err);
108               if (auto Addr = cantFail(this->Remote.getSymbolAddress(Name)))
109                 return JITSymbol(Addr, JITSymbolFlags::Exported);
110               return nullptr;
111             },
112             [](Error Err) { cantFail(std::move(Err), "lookupFlags failed"); })),
113         TM(EngineBuilder().selectTarget(Triple(Remote.getTargetTriple()), "",
114                                         "", SmallVector<std::string, 0>())),
115         DL(TM->createDataLayout()),
116         ObjectLayer(AcknowledgeORCv1Deprecation, ES,
117                     [this](VModuleKey K) {
118                       return LegacyRTDyldObjectLinkingLayer::Resources{
119                           cantFail(this->Remote.createRemoteMemoryManager()),
120                           Resolver};
121                     }),
122         CompileLayer(AcknowledgeORCv1Deprecation, ObjectLayer,
123                      SimpleCompiler(*TM)),
124         OptimizeLayer(AcknowledgeORCv1Deprecation, CompileLayer,
125                       [this](std::unique_ptr<Module> M) {
126                         return optimizeModule(std::move(M));
127                       }),
128         Remote(Remote) {
129     auto CCMgrOrErr = Remote.enableCompileCallbacks(0);
130     if (!CCMgrOrErr) {
131       logAllUnhandledErrors(CCMgrOrErr.takeError(), errs(),
132                             "Error enabling remote compile callbacks:");
133       exit(1);
134     }
135     CompileCallbackMgr = &*CCMgrOrErr;
136     IndirectStubsMgr = cantFail(Remote.createIndirectStubsManager());
137     llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
138   }
139 
getTargetMachine()140   TargetMachine &getTargetMachine() { return *TM; }
141 
addModule(std::unique_ptr<Module> M)142   VModuleKey addModule(std::unique_ptr<Module> M) {
143     // Add the module with a new VModuleKey.
144     auto K = ES.allocateVModule();
145     cantFail(OptimizeLayer.addModule(K, std::move(M)));
146     return K;
147   }
148 
addFunctionAST(std::unique_ptr<FunctionAST> FnAST)149   Error addFunctionAST(std::unique_ptr<FunctionAST> FnAST) {
150     // Move ownership of FnAST to a shared pointer - C++11 lambdas don't support
151     // capture-by-move, which is be required for unique_ptr.
152     auto SharedFnAST = std::shared_ptr<FunctionAST>(std::move(FnAST));
153 
154     // Set the action to compile our AST. This lambda will be run if/when
155     // execution hits the compile callback (via the stub).
156     //
157     // The steps to compile are:
158     // (1) IRGen the function.
159     // (2) Add the IR module to the JIT to make it executable like any other
160     //     module.
161     // (3) Use findSymbol to get the address of the compiled function.
162     // (4) Update the stub pointer to point at the implementation so that
163     ///    subsequent calls go directly to it and bypass the compiler.
164     // (5) Return the address of the implementation: this lambda will actually
165     //     be run inside an attempted call to the function, and we need to
166     //     continue on to the implementation to complete the attempted call.
167     //     The JIT runtime (the resolver block) will use the return address of
168     //     this function as the address to continue at once it has reset the
169     //     CPU state to what it was immediately before the call.
170     auto CompileAction = [this, SharedFnAST]() {
171       auto M = irgenAndTakeOwnership(*SharedFnAST, "$impl");
172       addModule(std::move(M));
173       auto Sym = findSymbol(SharedFnAST->getName() + "$impl");
174       assert(Sym && "Couldn't find compiled function?");
175       JITTargetAddress SymAddr = cantFail(Sym.getAddress());
176       if (auto Err = IndirectStubsMgr->updatePointer(
177               mangle(SharedFnAST->getName()), SymAddr)) {
178         logAllUnhandledErrors(std::move(Err), errs(),
179                               "Error updating function pointer: ");
180         exit(1);
181       }
182 
183       return SymAddr;
184     };
185 
186     // Create a CompileCallback suing the CompileAction - this is the re-entry
187     // point into the compiler for functions that haven't been compiled yet.
188     auto CCAddr = cantFail(
189         CompileCallbackMgr->getCompileCallback(std::move(CompileAction)));
190 
191     // Create an indirect stub. This serves as the functions "canonical
192     // definition" - an unchanging (constant address) entry point to the
193     // function implementation.
194     // Initially we point the stub's function-pointer at the compile callback
195     // that we just created. In the compile action for the callback we will
196     // update the stub's function pointer to point at the function
197     // implementation that we just implemented.
198     if (auto Err = IndirectStubsMgr->createStub(
199             mangle(SharedFnAST->getName()), CCAddr, JITSymbolFlags::Exported))
200       return Err;
201 
202     return Error::success();
203   }
204 
executeRemoteExpr(JITTargetAddress ExprAddr)205   Error executeRemoteExpr(JITTargetAddress ExprAddr) {
206     return Remote.callVoidVoid(ExprAddr);
207   }
208 
findSymbol(const std::string Name)209   JITSymbol findSymbol(const std::string Name) {
210     return OptimizeLayer.findSymbol(mangle(Name), true);
211   }
212 
removeModule(VModuleKey K)213   void removeModule(VModuleKey K) {
214     cantFail(OptimizeLayer.removeModule(K));
215   }
216 
217 private:
mangle(const std::string & Name)218   std::string mangle(const std::string &Name) {
219     std::string MangledName;
220     raw_string_ostream MangledNameStream(MangledName);
221     Mangler::getNameWithPrefix(MangledNameStream, Name, DL);
222     return MangledNameStream.str();
223   }
224 
optimizeModule(std::unique_ptr<Module> M)225   std::unique_ptr<Module> optimizeModule(std::unique_ptr<Module> M) {
226     // Create a function pass manager.
227     auto FPM = std::make_unique<legacy::FunctionPassManager>(M.get());
228 
229     // Add some optimizations.
230     FPM->add(createInstructionCombiningPass());
231     FPM->add(createReassociatePass());
232     FPM->add(createGVNPass());
233     FPM->add(createCFGSimplificationPass());
234     FPM->doInitialization();
235 
236     // Run the optimizations over all functions in the module being added to
237     // the JIT.
238     for (auto &F : *M)
239       FPM->run(F);
240 
241     return M;
242   }
243 };
244 
245 } // end namespace orc
246 } // end namespace llvm
247 
248 #endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
249