1 //===---------------------- RemoteObjectLayerTest.cpp ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
10 #include "llvm/ExecutionEngine/Orc/NullResolver.h"
11 #include "llvm/ExecutionEngine/Orc/RemoteObjectLayer.h"
12 #include "OrcTestCommon.h"
13 #include "QueueChannel.h"
14 #include "gtest/gtest.h"
15 
16 using namespace llvm;
17 using namespace llvm::orc;
18 
19 namespace {
20 
21 class MockObjectLayer {
22 public:
23 
24   using ObjHandleT = uint64_t;
25 
26   using ObjectPtr = std::unique_ptr<MemoryBuffer>;
27 
28   using LookupFn = std::function<JITSymbol(StringRef, bool)>;
29   using SymbolLookupTable = std::map<ObjHandleT, LookupFn>;
30 
31   using AddObjectFtor =
32     std::function<Expected<ObjHandleT>(ObjectPtr, SymbolLookupTable&)>;
33 
34   class ObjectNotFound : public remote::ResourceNotFound<ObjHandleT> {
35   public:
ObjectNotFound(ObjHandleT H)36     ObjectNotFound(ObjHandleT H) : ResourceNotFound(H, "Object handle") {}
37   };
38 
MockObjectLayer(AddObjectFtor AddObject)39   MockObjectLayer(AddObjectFtor AddObject)
40     : AddObject(std::move(AddObject)) {}
41 
addObject(ObjectPtr Obj,std::shared_ptr<JITSymbolResolver> Resolver)42   Expected<ObjHandleT> addObject(ObjectPtr Obj,
43             std::shared_ptr<JITSymbolResolver> Resolver) {
44     return AddObject(std::move(Obj), SymTab);
45   }
46 
removeObject(ObjHandleT H)47   Error removeObject(ObjHandleT H) {
48     if (SymTab.count(H))
49       return Error::success();
50     else
51       return make_error<ObjectNotFound>(H);
52   }
53 
findSymbol(StringRef Name,bool ExportedSymbolsOnly)54   JITSymbol findSymbol(StringRef Name, bool ExportedSymbolsOnly) {
55     for (auto KV : SymTab) {
56       if (auto Sym = KV.second(Name, ExportedSymbolsOnly))
57         return Sym;
58       else if (auto Err = Sym.takeError())
59         return std::move(Err);
60     }
61     return JITSymbol(nullptr);
62   }
63 
findSymbolIn(ObjHandleT H,StringRef Name,bool ExportedSymbolsOnly)64   JITSymbol findSymbolIn(ObjHandleT H, StringRef Name,
65                          bool ExportedSymbolsOnly) {
66     auto LI = SymTab.find(H);
67     if (LI != SymTab.end())
68       return LI->second(Name, ExportedSymbolsOnly);
69     else
70       return make_error<ObjectNotFound>(H);
71   }
72 
emitAndFinalize(ObjHandleT H)73   Error emitAndFinalize(ObjHandleT H) {
74     if (SymTab.count(H))
75       return Error::success();
76     else
77       return make_error<ObjectNotFound>(H);
78   }
79 
80 private:
81   AddObjectFtor AddObject;
82   SymbolLookupTable SymTab;
83 };
84 
// RPC endpoint type used for both the client and the server side of every
// test below; each side wraps one half of a paired queue channel.
using RPCEndpoint = rpc::SingleThreadedRPCEndpoint<rpc::RawByteChannel>;
86 
createTestObject()87 MockObjectLayer::ObjectPtr createTestObject() {
88   OrcNativeTarget::initialize();
89   auto TM = std::unique_ptr<TargetMachine>(EngineBuilder().selectTarget());
90 
91   if (!TM)
92     return nullptr;
93 
94   LLVMContext Ctx;
95   ModuleBuilder MB(Ctx, TM->getTargetTriple().str(), "TestModule");
96   MB.getModule()->setDataLayout(TM->createDataLayout());
97   auto *Main = MB.createFunctionDecl(
98       FunctionType::get(Type::getInt32Ty(Ctx),
99                         {Type::getInt32Ty(Ctx),
100                          Type::getInt8PtrTy(Ctx)->getPointerTo()},
101                         false),
102       "main");
103   Main->getBasicBlockList().push_back(BasicBlock::Create(Ctx));
104   IRBuilder<> B(&Main->back());
105   B.CreateRet(ConstantInt::getSigned(Type::getInt32Ty(Ctx), 42));
106 
107   SimpleCompiler IRCompiler(*TM);
108   return cantFail(IRCompiler(*MB.getModule()));
109 }
110 
TEST(RemoteObjectLayer,AddObject)111 TEST(RemoteObjectLayer, AddObject) {
112   llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
113   auto TestObject = createTestObject();
114   if (!TestObject)
115     return;
116 
117   auto Channels = createPairedQueueChannels();
118 
119   auto ReportError = [](Error Err) {
120     logAllUnhandledErrors(std::move(Err), llvm::errs());
121   };
122 
123   // Copy the bytes out of the test object: the copy will be used to verify
124   // that the original is correctly transmitted over RPC to the mock layer.
125   StringRef ObjBytes = TestObject->getBuffer();
126   std::vector<char> ObjContents(ObjBytes.size());
127   std::copy(ObjBytes.begin(), ObjBytes.end(), ObjContents.begin());
128 
129   RPCEndpoint ClientEP(*Channels.first, true);
130   RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
131                                               ClientEP, ReportError);
132 
133   RPCEndpoint ServerEP(*Channels.second, true);
134   MockObjectLayer BaseLayer(
135     [&ObjContents](MockObjectLayer::ObjectPtr Obj,
136                    MockObjectLayer::SymbolLookupTable &SymTab) {
137 
138       // Check that the received object file content matches the original.
139       StringRef RPCObjContents = Obj->getBuffer();
140       EXPECT_EQ(RPCObjContents.size(), ObjContents.size())
141         << "RPC'd object file has incorrect size";
142       EXPECT_TRUE(std::equal(RPCObjContents.begin(), RPCObjContents.end(),
143                              ObjContents.begin()))
144         << "RPC'd object file content does not match original content";
145 
146       return 1;
147     });
148   RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
149       AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);
150 
151   bool Finished = false;
152   ServerEP.addHandler<remote::utils::TerminateSession>(
153     [&]() { Finished = true; }
154   );
155 
156   auto ServerThread =
157     std::thread([&]() {
158       while (!Finished)
159         cantFail(ServerEP.handleOne());
160     });
161 
162   cantFail(Client.addObject(std::move(TestObject),
163                             std::make_shared<NullLegacyResolver>()));
164   cantFail(ClientEP.callB<remote::utils::TerminateSession>());
165   ServerThread.join();
166 }
167 
TEST(RemoteObjectLayer,AddObjectFailure)168 TEST(RemoteObjectLayer, AddObjectFailure) {
169   llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
170   auto TestObject = createTestObject();
171   if (!TestObject)
172     return;
173 
174   auto Channels = createPairedQueueChannels();
175 
176   auto ReportError =
177     [](Error Err) {
178       auto ErrMsg = toString(std::move(Err));
179       EXPECT_EQ(ErrMsg, "AddObjectFailure - Test Message")
180         << "Expected error string to be \"AddObjectFailure - Test Message\"";
181     };
182 
183   RPCEndpoint ClientEP(*Channels.first, true);
184   RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
185                                               ClientEP, ReportError);
186 
187   RPCEndpoint ServerEP(*Channels.second, true);
188   MockObjectLayer BaseLayer(
189     [](MockObjectLayer::ObjectPtr Obj,
190        MockObjectLayer::SymbolLookupTable &SymTab)
191         -> Expected<MockObjectLayer::ObjHandleT> {
192       return make_error<StringError>("AddObjectFailure - Test Message",
193                                      inconvertibleErrorCode());
194     });
195   RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
196       AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);
197 
198   bool Finished = false;
199   ServerEP.addHandler<remote::utils::TerminateSession>(
200     [&]() { Finished = true; }
201   );
202 
203   auto ServerThread =
204     std::thread([&]() {
205       while (!Finished)
206         cantFail(ServerEP.handleOne());
207     });
208 
209   auto HandleOrErr = Client.addObject(std::move(TestObject),
210                                       std::make_shared<NullLegacyResolver>());
211 
212   EXPECT_FALSE(HandleOrErr) << "Expected error from addObject";
213 
214   auto ErrMsg = toString(HandleOrErr.takeError());
215   EXPECT_EQ(ErrMsg, "AddObjectFailure - Test Message")
216     << "Expected error string to be \"AddObjectFailure - Test Message\"";
217 
218   cantFail(ClientEP.callB<remote::utils::TerminateSession>());
219   ServerThread.join();
220 }
221 
222 
TEST(RemoteObjectLayer,RemoveObject)223 TEST(RemoteObjectLayer, RemoveObject) {
224   llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
225   auto TestObject = createTestObject();
226   if (!TestObject)
227     return;
228 
229   auto Channels = createPairedQueueChannels();
230 
231   auto ReportError = [](Error Err) {
232     logAllUnhandledErrors(std::move(Err), llvm::errs());
233   };
234 
235   RPCEndpoint ClientEP(*Channels.first, true);
236   RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
237                                               ClientEP, ReportError);
238 
239   RPCEndpoint ServerEP(*Channels.second, true);
240 
241   MockObjectLayer BaseLayer(
242     [](MockObjectLayer::ObjectPtr Obj,
243        MockObjectLayer::SymbolLookupTable &SymTab) {
244       SymTab[1] = MockObjectLayer::LookupFn();
245       return 1;
246     });
247   RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
248       AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);
249 
250   bool Finished = false;
251   ServerEP.addHandler<remote::utils::TerminateSession>(
252     [&]() { Finished = true; }
253   );
254 
255   auto ServerThread =
256     std::thread([&]() {
257       while (!Finished)
258         cantFail(ServerEP.handleOne());
259     });
260 
261   auto H = cantFail(Client.addObject(std::move(TestObject),
262                                      std::make_shared<NullLegacyResolver>()));
263 
264   cantFail(Client.removeObject(H));
265 
266   cantFail(ClientEP.callB<remote::utils::TerminateSession>());
267   ServerThread.join();
268 }
269 
TEST(RemoteObjectLayer,RemoveObjectFailure)270 TEST(RemoteObjectLayer, RemoveObjectFailure) {
271   llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
272   auto TestObject = createTestObject();
273   if (!TestObject)
274     return;
275 
276   auto Channels = createPairedQueueChannels();
277 
278   auto ReportError =
279     [](Error Err) {
280       auto ErrMsg = toString(std::move(Err));
281       EXPECT_EQ(ErrMsg, "Object handle 42 not found")
282         << "Expected error string to be \"Object handle 42 not found\"";
283     };
284 
285   RPCEndpoint ClientEP(*Channels.first, true);
286   RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
287                                               ClientEP, ReportError);
288 
289   RPCEndpoint ServerEP(*Channels.second, true);
290 
291   // AddObject lambda does not update symbol table, so removeObject will treat
292   // this as a bad object handle.
293   MockObjectLayer BaseLayer(
294     [](MockObjectLayer::ObjectPtr Obj,
295        MockObjectLayer::SymbolLookupTable &SymTab) {
296       return 42;
297     });
298   RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
299       AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);
300 
301   bool Finished = false;
302   ServerEP.addHandler<remote::utils::TerminateSession>(
303     [&]() { Finished = true; }
304   );
305 
306   auto ServerThread =
307     std::thread([&]() {
308       while (!Finished)
309         cantFail(ServerEP.handleOne());
310     });
311 
312   auto H = cantFail(Client.addObject(std::move(TestObject),
313                                      std::make_shared<NullLegacyResolver>()));
314 
315   auto Err = Client.removeObject(H);
316   EXPECT_TRUE(!!Err) << "Expected error from removeObject";
317 
318   auto ErrMsg = toString(std::move(Err));
319   EXPECT_EQ(ErrMsg, "Object handle 42 not found")
320     << "Expected error string to be \"Object handle 42 not found\"";
321 
322   cantFail(ClientEP.callB<remote::utils::TerminateSession>());
323   ServerThread.join();
324 }
325 
TEST(RemoteObjectLayer,FindSymbol)326 TEST(RemoteObjectLayer, FindSymbol) {
327   llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
328   auto TestObject = createTestObject();
329   if (!TestObject)
330     return;
331 
332   auto Channels = createPairedQueueChannels();
333 
334   auto ReportError =
335     [](Error Err) {
336       auto ErrMsg = toString(std::move(Err));
337       EXPECT_EQ(ErrMsg, "Could not find symbol 'badsymbol'")
338         << "Expected error string to be \"Object handle 42 not found\"";
339     };
340 
341   RPCEndpoint ClientEP(*Channels.first, true);
342   RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
343                                               ClientEP, ReportError);
344 
345   RPCEndpoint ServerEP(*Channels.second, true);
346 
347   // AddObject lambda does not update symbol table, so removeObject will treat
348   // this as a bad object handle.
349   MockObjectLayer BaseLayer(
350     [](MockObjectLayer::ObjectPtr Obj,
351        MockObjectLayer::SymbolLookupTable &SymTab) {
352       SymTab[42] =
353         [](StringRef Name, bool ExportedSymbolsOnly) -> JITSymbol {
354           if (Name == "foobar")
355             return JITSymbol(0x12348765, JITSymbolFlags::Exported);
356           if (Name == "badsymbol")
357             return make_error<JITSymbolNotFound>(Name);
358           return nullptr;
359         };
360       return 42;
361     });
362   RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
363       AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);
364 
365   bool Finished = false;
366   ServerEP.addHandler<remote::utils::TerminateSession>(
367     [&]() { Finished = true; }
368   );
369 
370   auto ServerThread =
371     std::thread([&]() {
372       while (!Finished)
373         cantFail(ServerEP.handleOne());
374     });
375 
376   cantFail(Client.addObject(std::move(TestObject),
377                             std::make_shared<NullLegacyResolver>()));
378 
379   // Check that we can find and materialize a valid symbol.
380   auto Sym1 = Client.findSymbol("foobar", true);
381   EXPECT_TRUE(!!Sym1) << "Symbol 'foobar' should be findable";
382   EXPECT_EQ(cantFail(Sym1.getAddress()), 0x12348765ULL)
383     << "Symbol 'foobar' does not return the correct address";
384 
385   {
386     // Check that we can return a symbol containing an error.
387     auto Sym2 = Client.findSymbol("badsymbol", true);
388     EXPECT_FALSE(!!Sym2) << "Symbol 'badsymbol' should not be findable";
389     auto Err = Sym2.takeError();
390     EXPECT_TRUE(!!Err) << "Sym2 should contain an error value";
391     auto ErrMsg = toString(std::move(Err));
392     EXPECT_EQ(ErrMsg, "Could not find symbol 'badsymbol'")
393       << "Expected symbol-not-found error for Sym2";
394   }
395 
396   {
397     // Check that we can return a 'null' symbol.
398     auto Sym3 = Client.findSymbol("baz", true);
399     EXPECT_FALSE(!!Sym3) << "Symbol 'baz' should convert to false";
400     auto Err = Sym3.takeError();
401     EXPECT_FALSE(!!Err) << "Symbol 'baz' should not contain an error";
402   }
403 
404   cantFail(ClientEP.callB<remote::utils::TerminateSession>());
405   ServerThread.join();
406 }
407 
TEST(RemoteObjectLayer,FindSymbolIn)408 TEST(RemoteObjectLayer, FindSymbolIn) {
409   llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
410   auto TestObject = createTestObject();
411   if (!TestObject)
412     return;
413 
414   auto Channels = createPairedQueueChannels();
415 
416   auto ReportError =
417     [](Error Err) {
418       auto ErrMsg = toString(std::move(Err));
419       EXPECT_EQ(ErrMsg, "Could not find symbol 'barbaz'")
420         << "Expected error string to be \"Object handle 42 not found\"";
421     };
422 
423   RPCEndpoint ClientEP(*Channels.first, true);
424   RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
425                                               ClientEP, ReportError);
426 
427   RPCEndpoint ServerEP(*Channels.second, true);
428 
429   // AddObject lambda does not update symbol table, so removeObject will treat
430   // this as a bad object handle.
431   MockObjectLayer BaseLayer(
432     [](MockObjectLayer::ObjectPtr Obj,
433        MockObjectLayer::SymbolLookupTable &SymTab) {
434       SymTab[42] =
435         [](StringRef Name, bool ExportedSymbolsOnly) -> JITSymbol {
436           if (Name == "foobar")
437             return JITSymbol(0x12348765, JITSymbolFlags::Exported);
438           return make_error<JITSymbolNotFound>(Name);
439         };
440       // Dummy symbol table entry - this should not be visible to
441       // findSymbolIn.
442       SymTab[43] =
443         [](StringRef Name, bool ExportedSymbolsOnly) -> JITSymbol {
444           if (Name == "barbaz")
445             return JITSymbol(0xdeadbeef, JITSymbolFlags::Exported);
446           return make_error<JITSymbolNotFound>(Name);
447         };
448 
449       return 42;
450     });
451   RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
452       AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);
453 
454   bool Finished = false;
455   ServerEP.addHandler<remote::utils::TerminateSession>(
456     [&]() { Finished = true; }
457   );
458 
459   auto ServerThread =
460     std::thread([&]() {
461       while (!Finished)
462         cantFail(ServerEP.handleOne());
463     });
464 
465   auto H = cantFail(Client.addObject(std::move(TestObject),
466                                      std::make_shared<NullLegacyResolver>()));
467 
468   auto Sym1 = Client.findSymbolIn(H, "foobar", true);
469 
470   EXPECT_TRUE(!!Sym1) << "Symbol 'foobar' should be findable";
471   EXPECT_EQ(cantFail(Sym1.getAddress()), 0x12348765ULL)
472     << "Symbol 'foobar' does not return the correct address";
473 
474   auto Sym2 = Client.findSymbolIn(H, "barbaz", true);
475   EXPECT_FALSE(!!Sym2) << "Symbol 'barbaz' should not be findable";
476   auto Err = Sym2.takeError();
477   EXPECT_TRUE(!!Err) << "Sym2 should contain an error value";
478   auto ErrMsg = toString(std::move(Err));
479   EXPECT_EQ(ErrMsg, "Could not find symbol 'barbaz'")
480     << "Expected symbol-not-found error for Sym2";
481 
482   cantFail(ClientEP.callB<remote::utils::TerminateSession>());
483   ServerThread.join();
484 }
485 
TEST(RemoteObjectLayer,EmitAndFinalize)486 TEST(RemoteObjectLayer, EmitAndFinalize) {
487   llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
488   auto TestObject = createTestObject();
489   if (!TestObject)
490     return;
491 
492   auto Channels = createPairedQueueChannels();
493 
494   auto ReportError = [](Error Err) {
495     logAllUnhandledErrors(std::move(Err), llvm::errs());
496   };
497 
498   RPCEndpoint ClientEP(*Channels.first, true);
499   RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
500                                               ClientEP, ReportError);
501 
502   RPCEndpoint ServerEP(*Channels.second, true);
503 
504   MockObjectLayer BaseLayer(
505     [](MockObjectLayer::ObjectPtr Obj,
506        MockObjectLayer::SymbolLookupTable &SymTab) {
507       SymTab[1] = MockObjectLayer::LookupFn();
508       return 1;
509     });
510   RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
511       AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);
512 
513   bool Finished = false;
514   ServerEP.addHandler<remote::utils::TerminateSession>(
515     [&]() { Finished = true; }
516   );
517 
518   auto ServerThread =
519     std::thread([&]() {
520       while (!Finished)
521         cantFail(ServerEP.handleOne());
522     });
523 
524   auto H = cantFail(Client.addObject(std::move(TestObject),
525                                      std::make_shared<NullLegacyResolver>()));
526 
527   auto Err = Client.emitAndFinalize(H);
528   EXPECT_FALSE(!!Err) << "emitAndFinalize should work";
529 
530   cantFail(ClientEP.callB<remote::utils::TerminateSession>());
531   ServerThread.join();
532 }
533 
TEST(RemoteObjectLayer,EmitAndFinalizeFailure)534 TEST(RemoteObjectLayer, EmitAndFinalizeFailure) {
535   llvm::orc::rpc::registerStringError<rpc::RawByteChannel>();
536   auto TestObject = createTestObject();
537   if (!TestObject)
538     return;
539 
540   auto Channels = createPairedQueueChannels();
541 
542   auto ReportError =
543     [](Error Err) {
544       auto ErrMsg = toString(std::move(Err));
545       EXPECT_EQ(ErrMsg, "Object handle 1 not found")
546         << "Expected bad handle error";
547     };
548 
549   RPCEndpoint ClientEP(*Channels.first, true);
550   RemoteObjectClientLayer<RPCEndpoint> Client(AcknowledgeORCv1Deprecation,
551                                               ClientEP, ReportError);
552 
553   RPCEndpoint ServerEP(*Channels.second, true);
554 
555   MockObjectLayer BaseLayer(
556     [](MockObjectLayer::ObjectPtr Obj,
557        MockObjectLayer::SymbolLookupTable &SymTab) {
558       return 1;
559     });
560   RemoteObjectServerLayer<MockObjectLayer, RPCEndpoint> Server(
561       AcknowledgeORCv1Deprecation, BaseLayer, ServerEP, ReportError);
562 
563   bool Finished = false;
564   ServerEP.addHandler<remote::utils::TerminateSession>(
565     [&]() { Finished = true; }
566   );
567 
568   auto ServerThread =
569     std::thread([&]() {
570       while (!Finished)
571         cantFail(ServerEP.handleOne());
572     });
573 
574   auto H = cantFail(Client.addObject(std::move(TestObject),
575                                      std::make_shared<NullLegacyResolver>()));
576 
577   auto Err = Client.emitAndFinalize(H);
578   EXPECT_TRUE(!!Err) << "emitAndFinalize should work";
579 
580   auto ErrMsg = toString(std::move(Err));
581   EXPECT_EQ(ErrMsg, "Object handle 1 not found")
582     << "emitAndFinalize returned incorrect error";
583 
584   cantFail(ClientEP.callB<remote::utils::TerminateSession>());
585   ServerThread.join();
586 }
587 
588 }
589