1 //===------ OrcTestCommon.h - Utilities for Orc Unit Tests ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Common utilities for the Orc unit tests.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 
14 #ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_ORCTESTCOMMON_H
15 #define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_ORCTESTCOMMON_H
16 
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/IndirectionUtils.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "gtest/gtest.h"

#include <memory>
30 
31 namespace llvm {
32 
33 namespace orc {
34 // CoreAPIsStandardTest that saves a bunch of boilerplate by providing the
35 // following:
36 //
37 // (1) ES -- An ExecutionSession
38 // (2) Foo, Bar, Baz, Qux -- SymbolStringPtrs for strings "foo", "bar", "baz",
39 //     and "qux" respectively.
40 // (3) FooAddr, BarAddr, BazAddr, QuxAddr -- Dummy addresses. Guaranteed
41 //     distinct and non-null.
42 // (4) FooSym, BarSym, BazSym, QuxSym -- JITEvaluatedSymbols with FooAddr,
43 //     BarAddr, BazAddr, and QuxAddr respectively. All with default strong,
44 //     linkage and non-hidden visibility.
45 // (5) V -- A JITDylib associated with ES.
46 class CoreAPIsBasedStandardTest : public testing::Test {
47 public:
~CoreAPIsBasedStandardTest()48   ~CoreAPIsBasedStandardTest() {
49     if (auto Err = ES.endSession())
50       ES.reportError(std::move(Err));
51   }
52 
53 protected:
54   std::shared_ptr<SymbolStringPool> SSP = std::make_shared<SymbolStringPool>();
55   ExecutionSession ES{SSP};
56   JITDylib &JD = ES.createBareJITDylib("JD");
57   SymbolStringPtr Foo = ES.intern("foo");
58   SymbolStringPtr Bar = ES.intern("bar");
59   SymbolStringPtr Baz = ES.intern("baz");
60   SymbolStringPtr Qux = ES.intern("qux");
61   static const JITTargetAddress FooAddr = 1U;
62   static const JITTargetAddress BarAddr = 2U;
63   static const JITTargetAddress BazAddr = 3U;
64   static const JITTargetAddress QuxAddr = 4U;
65   JITEvaluatedSymbol FooSym =
66       JITEvaluatedSymbol(FooAddr, JITSymbolFlags::Exported);
67   JITEvaluatedSymbol BarSym =
68       JITEvaluatedSymbol(BarAddr, JITSymbolFlags::Exported);
69   JITEvaluatedSymbol BazSym =
70       JITEvaluatedSymbol(BazAddr, JITSymbolFlags::Exported);
71   JITEvaluatedSymbol QuxSym =
72       JITEvaluatedSymbol(QuxAddr, JITSymbolFlags::Exported);
73 };
74 
75 } // end namespace orc
76 
77 class OrcNativeTarget {
78 public:
initialize()79   static void initialize() {
80     if (!NativeTargetInitialized) {
81       InitializeNativeTarget();
82       InitializeNativeTargetAsmParser();
83       InitializeNativeTargetAsmPrinter();
84       NativeTargetInitialized = true;
85     }
86   }
87 
88 private:
89   static bool NativeTargetInitialized;
90 };
91 
92 class SimpleMaterializationUnit : public orc::MaterializationUnit {
93 public:
94   using MaterializeFunction =
95       std::function<void(std::unique_ptr<orc::MaterializationResponsibility>)>;
96   using DiscardFunction =
97       std::function<void(const orc::JITDylib &, orc::SymbolStringPtr)>;
98   using DestructorFunction = std::function<void()>;
99 
100   SimpleMaterializationUnit(
101       orc::SymbolFlagsMap SymbolFlags, MaterializeFunction Materialize,
102       orc::SymbolStringPtr InitSym = nullptr,
103       DiscardFunction Discard = DiscardFunction(),
104       DestructorFunction Destructor = DestructorFunction())
MaterializationUnit(std::move (SymbolFlags),std::move (InitSym))105       : MaterializationUnit(std::move(SymbolFlags), std::move(InitSym)),
106         Materialize(std::move(Materialize)), Discard(std::move(Discard)),
107         Destructor(std::move(Destructor)) {}
108 
~SimpleMaterializationUnit()109   ~SimpleMaterializationUnit() override {
110     if (Destructor)
111       Destructor();
112   }
113 
getName()114   StringRef getName() const override { return "<Simple>"; }
115 
116   void
materialize(std::unique_ptr<orc::MaterializationResponsibility> R)117   materialize(std::unique_ptr<orc::MaterializationResponsibility> R) override {
118     Materialize(std::move(R));
119   }
120 
discard(const orc::JITDylib & JD,const orc::SymbolStringPtr & Name)121   void discard(const orc::JITDylib &JD,
122                const orc::SymbolStringPtr &Name) override {
123     if (Discard)
124       Discard(JD, std::move(Name));
125     else
126       llvm_unreachable("Discard not supported");
127   }
128 
129 private:
130   MaterializeFunction Materialize;
131   DiscardFunction Discard;
132   DestructorFunction Destructor;
133 };
134 
135 // Base class for Orc tests that will execute code.
136 class OrcExecutionTest {
137 public:
138 
OrcExecutionTest()139   OrcExecutionTest() {
140 
141     // Initialize the native target if it hasn't been done already.
142     OrcNativeTarget::initialize();
143 
144     // Try to select a TargetMachine for the host.
145     TM.reset(EngineBuilder().selectTarget());
146 
147     if (TM) {
148       // If we found a TargetMachine, check that it's one that Orc supports.
149       const Triple& TT = TM->getTargetTriple();
150 
151       // Bail out for windows platforms. We do not support these yet.
152       if ((TT.getArch() != Triple::x86_64 && TT.getArch() != Triple::x86) ||
153            TT.isOSWindows())
154         return;
155 
156       // Target can JIT?
157       SupportsJIT = TM->getTarget().hasJIT();
158       // Use ability to create callback manager to detect whether Orc
159       // has indirection support on this platform. This way the test
160       // and Orc code do not get out of sync.
161       SupportsIndirection = !!orc::createLocalCompileCallbackManager(TT, ES, 0);
162     }
163   };
164 
~OrcExecutionTest()165   ~OrcExecutionTest() {
166     if (auto Err = ES.endSession())
167       ES.reportError(std::move(Err));
168   }
169 
170 protected:
171   orc::ExecutionSession ES;
172   LLVMContext Context;
173   std::unique_ptr<TargetMachine> TM;
174   bool SupportsJIT = false;
175   bool SupportsIndirection = false;
176 };
177 
178 class ModuleBuilder {
179 public:
180   ModuleBuilder(LLVMContext &Context, StringRef Triple,
181                 StringRef Name);
182 
createFunctionDecl(FunctionType * FTy,StringRef Name)183   Function *createFunctionDecl(FunctionType *FTy, StringRef Name) {
184     return Function::Create(FTy, GlobalValue::ExternalLinkage, Name, M.get());
185   }
186 
getModule()187   Module* getModule() { return M.get(); }
getModule()188   const Module* getModule() const { return M.get(); }
takeModule()189   std::unique_ptr<Module> takeModule() { return std::move(M); }
190 
191 private:
192   std::unique_ptr<Module> M;
193 };
194 
195 // Dummy struct type.
196 struct DummyStruct {
197   int X[256];
198 };
199 
getDummyStructTy(LLVMContext & Context)200 inline StructType *getDummyStructTy(LLVMContext &Context) {
201   return StructType::get(ArrayType::get(Type::getInt32Ty(Context), 256));
202 }
203 
204 template <typename HandleT, typename ModuleT>
205 class MockBaseLayer {
206 public:
207 
208   using ModuleHandleT = HandleT;
209 
210   using AddModuleSignature =
211     Expected<ModuleHandleT>(ModuleT M,
212                             std::shared_ptr<JITSymbolResolver> R);
213 
214   using RemoveModuleSignature = Error(ModuleHandleT H);
215   using FindSymbolSignature = JITSymbol(const std::string &Name,
216                                         bool ExportedSymbolsOnly);
217   using FindSymbolInSignature = JITSymbol(ModuleHandleT H,
218                                           const std::string &Name,
219                                           bool ExportedSymbolsONly);
220   using EmitAndFinalizeSignature = Error(ModuleHandleT H);
221 
222   std::function<AddModuleSignature> addModuleImpl;
223   std::function<RemoveModuleSignature> removeModuleImpl;
224   std::function<FindSymbolSignature> findSymbolImpl;
225   std::function<FindSymbolInSignature> findSymbolInImpl;
226   std::function<EmitAndFinalizeSignature> emitAndFinalizeImpl;
227 
addModule(ModuleT M,std::shared_ptr<JITSymbolResolver> R)228   Expected<ModuleHandleT> addModule(ModuleT M,
229                                     std::shared_ptr<JITSymbolResolver> R) {
230     assert(addModuleImpl &&
231            "addModule called, but no mock implementation was provided");
232     return addModuleImpl(std::move(M), std::move(R));
233   }
234 
removeModule(ModuleHandleT H)235   Error removeModule(ModuleHandleT H) {
236     assert(removeModuleImpl &&
237            "removeModule called, but no mock implementation was provided");
238     return removeModuleImpl(H);
239   }
240 
findSymbol(const std::string & Name,bool ExportedSymbolsOnly)241   JITSymbol findSymbol(const std::string &Name, bool ExportedSymbolsOnly) {
242     assert(findSymbolImpl &&
243            "findSymbol called, but no mock implementation was provided");
244     return findSymbolImpl(Name, ExportedSymbolsOnly);
245   }
246 
findSymbolIn(ModuleHandleT H,const std::string & Name,bool ExportedSymbolsOnly)247   JITSymbol findSymbolIn(ModuleHandleT H, const std::string &Name,
248                          bool ExportedSymbolsOnly) {
249     assert(findSymbolInImpl &&
250            "findSymbolIn called, but no mock implementation was provided");
251     return findSymbolInImpl(H, Name, ExportedSymbolsOnly);
252   }
253 
emitAndFinaliez(ModuleHandleT H)254   Error emitAndFinaliez(ModuleHandleT H) {
255     assert(emitAndFinalizeImpl &&
256            "emitAndFinalize called, but no mock implementation was provided");
257     return emitAndFinalizeImpl(H);
258   }
259 };
260 
261 class ReturnNullJITSymbol {
262 public:
263   template <typename... Args>
operator()264   JITSymbol operator()(Args...) const {
265     return nullptr;
266   }
267 };
268 
/// Callable that accepts any arguments (taken by value and ignored) and
/// returns a copy of the value it was constructed with.
template <typename ReturnT>
class DoNothingAndReturn {
public:
  DoNothingAndReturn(ReturnT Ret) : Ret(std::move(Ret)) {}

  // The return type must be ReturnT: it was previously declared void while
  // returning Ret, so any non-void instantiation failed to compile.
  template <typename... Args>
  ReturnT operator()(Args...) const { return Ret; }

private:
  ReturnT Ret;
};

/// Specialization for void: accepts and ignores any arguments, returns
/// nothing.
template <>
class DoNothingAndReturn<void> {
public:
  template <typename... Args>
  void operator()(Args...) const { }
};
286 
287 } // namespace llvm
288 
289 #endif
290