1 //===- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains a simple JIT definition for use in the kaleidoscope tutorials.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
14 #define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
15 
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ExecutionEngine/ExecutionEngine.h"
18 #include "llvm/ExecutionEngine/JITSymbol.h"
19 #include "llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h"
20 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
21 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
22 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
23 #include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
24 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
25 #include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
26 #include "llvm/ExecutionEngine/RuntimeDyld.h"
27 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
28 #include "llvm/IR/DataLayout.h"
29 #include "llvm/IR/LegacyPassManager.h"
30 #include "llvm/IR/Mangler.h"
31 #include "llvm/Support/DynamicLibrary.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "llvm/Target/TargetMachine.h"
34 #include "llvm/Transforms/InstCombine/InstCombine.h"
35 #include "llvm/Transforms/Scalar.h"
36 #include "llvm/Transforms/Scalar/GVN.h"
37 #include <algorithm>
38 #include <map>
39 #include <memory>
40 #include <set>
41 #include <string>
42 #include <vector>
43 
44 namespace llvm {
45 namespace orc {
46 
47 class KaleidoscopeJIT {
48 private:
49   ExecutionSession ES;
50   std::map<VModuleKey, std::shared_ptr<SymbolResolver>> Resolvers;
51   std::unique_ptr<TargetMachine> TM;
52   const DataLayout DL;
53   LegacyRTDyldObjectLinkingLayer ObjectLayer;
54   LegacyIRCompileLayer<decltype(ObjectLayer), SimpleCompiler> CompileLayer;
55 
56   using OptimizeFunction =
57       std::function<std::unique_ptr<Module>(std::unique_ptr<Module>)>;
58 
59   LegacyIRTransformLayer<decltype(CompileLayer), OptimizeFunction> OptimizeLayer;
60 
61   std::unique_ptr<JITCompileCallbackManager> CompileCallbackManager;
62   LegacyCompileOnDemandLayer<decltype(OptimizeLayer)> CODLayer;
63 
64 public:
KaleidoscopeJIT()65   KaleidoscopeJIT()
66       : TM(EngineBuilder().selectTarget()), DL(TM->createDataLayout()),
67         ObjectLayer(AcknowledgeORCv1Deprecation, ES,
68                     [this](VModuleKey K) {
69                       return LegacyRTDyldObjectLinkingLayer::Resources{
70                           std::make_shared<SectionMemoryManager>(),
71                           Resolvers[K]};
72                     }),
73         CompileLayer(AcknowledgeORCv1Deprecation, ObjectLayer,
74                      SimpleCompiler(*TM)),
75         OptimizeLayer(AcknowledgeORCv1Deprecation, CompileLayer,
76                       [this](std::unique_ptr<Module> M) {
77                         return optimizeModule(std::move(M));
78                       }),
79         CompileCallbackManager(cantFail(orc::createLocalCompileCallbackManager(
80             TM->getTargetTriple(), ES, 0))),
81         CODLayer(
82             AcknowledgeORCv1Deprecation, ES, OptimizeLayer,
83             [&](orc::VModuleKey K) { return Resolvers[K]; },
84             [&](orc::VModuleKey K, std::shared_ptr<SymbolResolver> R) {
85               Resolvers[K] = std::move(R);
86             },
87             [](Function &F) { return std::set<Function *>({&F}); },
88             *CompileCallbackManager,
89             orc::createLocalIndirectStubsManagerBuilder(
90                 TM->getTargetTriple())) {
91     llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
92   }
93 
getTargetMachine()94   TargetMachine &getTargetMachine() { return *TM; }
95 
addModule(std::unique_ptr<Module> M)96   VModuleKey addModule(std::unique_ptr<Module> M) {
97     // Create a new VModuleKey.
98     VModuleKey K = ES.allocateVModule();
99 
100     // Build a resolver and associate it with the new key.
101     Resolvers[K] = createLegacyLookupResolver(
102         ES,
103         [this](const std::string &Name) -> JITSymbol {
104           if (auto Sym = CompileLayer.findSymbol(Name, false))
105             return Sym;
106           else if (auto Err = Sym.takeError())
107             return std::move(Err);
108           if (auto SymAddr =
109                   RTDyldMemoryManager::getSymbolAddressInProcess(Name))
110             return JITSymbol(SymAddr, JITSymbolFlags::Exported);
111           return nullptr;
112         },
113         [](Error Err) { cantFail(std::move(Err), "lookupFlags failed"); });
114 
115     // Add the module to the JIT with the new key.
116     cantFail(CODLayer.addModule(K, std::move(M)));
117     return K;
118   }
119 
findSymbol(const std::string Name)120   JITSymbol findSymbol(const std::string Name) {
121     std::string MangledName;
122     raw_string_ostream MangledNameStream(MangledName);
123     Mangler::getNameWithPrefix(MangledNameStream, Name, DL);
124     return CODLayer.findSymbol(MangledNameStream.str(), true);
125   }
126 
removeModule(VModuleKey K)127   void removeModule(VModuleKey K) {
128     cantFail(CODLayer.removeModule(K));
129   }
130 
131 private:
optimizeModule(std::unique_ptr<Module> M)132   std::unique_ptr<Module> optimizeModule(std::unique_ptr<Module> M) {
133     // Create a function pass manager.
134     auto FPM = llvm::make_unique<legacy::FunctionPassManager>(M.get());
135 
136     // Add some optimizations.
137     FPM->add(createInstructionCombiningPass());
138     FPM->add(createReassociatePass());
139     FPM->add(createGVNPass());
140     FPM->add(createCFGSimplificationPass());
141     FPM->doInitialization();
142 
143     // Run the optimizations over all functions in the module being added to
144     // the JIT.
145     for (auto &F : *M)
146       FPM->run(F);
147 
148     return M;
149   }
150 };
151 
152 } // end namespace orc
153 } // end namespace llvm
154 
155 #endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
156