1 //===-- OrcRPCExecutorProcessControl.h - Remote target control --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Executor control via ORC RPC.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_EXECUTIONENGINE_ORC_ORCRPCEXECUTORPROCESSCONTROL_H
14 #define LLVM_EXECUTIONENGINE_ORC_ORCRPCEXECUTORPROCESSCONTROL_H
15 
16 #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
17 #include "llvm/ExecutionEngine/Orc/Shared/RPCUtils.h"
18 #include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h"
19 #include "llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h"
20 #include "llvm/Support/MSVCErrorWorkarounds.h"
21 
22 namespace llvm {
23 namespace orc {
24 
25 /// JITLinkMemoryManager implementation for a process connected via an ORC RPC
26 /// endpoint.
27 template <typename OrcRPCEPCImplT>
28 class OrcRPCEPCJITLinkMemoryManager : public jitlink::JITLinkMemoryManager {
29 private:
30   struct HostAlloc {
31     std::unique_ptr<char[]> Mem;
32     uint64_t Size;
33   };
34 
35   struct TargetAlloc {
36     JITTargetAddress Address = 0;
37     uint64_t AllocatedSize = 0;
38   };
39 
40   using HostAllocMap = DenseMap<int, HostAlloc>;
41   using TargetAllocMap = DenseMap<int, TargetAlloc>;
42 
43 public:
44   class OrcRPCAllocation : public Allocation {
45   public:
OrcRPCAllocation(OrcRPCEPCJITLinkMemoryManager<OrcRPCEPCImplT> & Parent,HostAllocMap HostAllocs,TargetAllocMap TargetAllocs)46     OrcRPCAllocation(OrcRPCEPCJITLinkMemoryManager<OrcRPCEPCImplT> &Parent,
47                      HostAllocMap HostAllocs, TargetAllocMap TargetAllocs)
48         : Parent(Parent), HostAllocs(std::move(HostAllocs)),
49           TargetAllocs(std::move(TargetAllocs)) {
50       assert(HostAllocs.size() == TargetAllocs.size() &&
51              "HostAllocs size should match TargetAllocs");
52     }
53 
~OrcRPCAllocation()54     ~OrcRPCAllocation() override {
55       assert(TargetAllocs.empty() && "failed to deallocate");
56     }
57 
getWorkingMemory(ProtectionFlags Seg)58     MutableArrayRef<char> getWorkingMemory(ProtectionFlags Seg) override {
59       auto I = HostAllocs.find(Seg);
60       assert(I != HostAllocs.end() && "No host allocation for segment");
61       auto &HA = I->second;
62       return {HA.Mem.get(), static_cast<size_t>(HA.Size)};
63     }
64 
getTargetMemory(ProtectionFlags Seg)65     JITTargetAddress getTargetMemory(ProtectionFlags Seg) override {
66       auto I = TargetAllocs.find(Seg);
67       assert(I != TargetAllocs.end() && "No target allocation for segment");
68       return I->second.Address;
69     }
70 
finalizeAsync(FinalizeContinuation OnFinalize)71     void finalizeAsync(FinalizeContinuation OnFinalize) override {
72 
73       std::vector<tpctypes::BufferWrite> BufferWrites;
74       orcrpctpc::ReleaseOrFinalizeMemRequest FMR;
75 
76       for (auto &KV : HostAllocs) {
77         assert(TargetAllocs.count(KV.first) &&
78                "No target allocation for buffer");
79         auto &HA = KV.second;
80         auto &TA = TargetAllocs[KV.first];
81         BufferWrites.push_back({TA.Address, StringRef(HA.Mem.get(), HA.Size)});
82         FMR.push_back({orcrpctpc::toWireProtectionFlags(
83                            static_cast<sys::Memory::ProtectionFlags>(KV.first)),
84                        TA.Address, TA.AllocatedSize});
85       }
86 
87       DEBUG_WITH_TYPE("orc", {
88         dbgs() << "finalizeAsync " << (void *)this << ":\n";
89         auto FMRI = FMR.begin();
90         for (auto &B : BufferWrites) {
91           auto Prot = FMRI->Prot;
92           ++FMRI;
93           dbgs() << "  Writing " << formatv("{0:x16}", B.Buffer.size())
94                  << " bytes to " << ((Prot & orcrpctpc::WPF_Read) ? 'R' : '-')
95                  << ((Prot & orcrpctpc::WPF_Write) ? 'W' : '-')
96                  << ((Prot & orcrpctpc::WPF_Exec) ? 'X' : '-')
97                  << " segment: local " << (const void *)B.Buffer.data()
98                  << " -> target " << formatv("{0:x16}", B.Address) << "\n";
99         }
100       });
101       if (auto Err =
102               Parent.Parent.getMemoryAccess().writeBuffers(BufferWrites)) {
103         OnFinalize(std::move(Err));
104         return;
105       }
106 
107       DEBUG_WITH_TYPE("orc", dbgs() << " Applying permissions...\n");
108       if (auto Err =
109               Parent.getEndpoint().template callAsync<orcrpctpc::FinalizeMem>(
110                   [OF = std::move(OnFinalize)](Error Err2) {
111                     // FIXME: Dispatch to work queue.
112                     std::thread([OF = std::move(OF),
113                                  Err3 = std::move(Err2)]() mutable {
114                       DEBUG_WITH_TYPE(
115                           "orc", { dbgs() << "  finalizeAsync complete\n"; });
116                       OF(std::move(Err3));
117                     }).detach();
118                     return Error::success();
119                   },
120                   FMR)) {
121         DEBUG_WITH_TYPE("orc", dbgs() << "    failed.\n");
122         Parent.getEndpoint().abandonPendingResponses();
123         Parent.reportError(std::move(Err));
124       }
125       DEBUG_WITH_TYPE("orc", {
126         dbgs() << "Leaving finalizeAsync (finalization may continue in "
127                   "background)\n";
128       });
129     }
130 
deallocate()131     Error deallocate() override {
132       orcrpctpc::ReleaseOrFinalizeMemRequest RMR;
133       for (auto &KV : TargetAllocs)
134         RMR.push_back({orcrpctpc::toWireProtectionFlags(
135                            static_cast<sys::Memory::ProtectionFlags>(KV.first)),
136                        KV.second.Address, KV.second.AllocatedSize});
137       TargetAllocs.clear();
138 
139       return Parent.getEndpoint().template callB<orcrpctpc::ReleaseMem>(RMR);
140     }
141 
142   private:
143     OrcRPCEPCJITLinkMemoryManager<OrcRPCEPCImplT> &Parent;
144     HostAllocMap HostAllocs;
145     TargetAllocMap TargetAllocs;
146   };
147 
OrcRPCEPCJITLinkMemoryManager(OrcRPCEPCImplT & Parent)148   OrcRPCEPCJITLinkMemoryManager(OrcRPCEPCImplT &Parent) : Parent(Parent) {}
149 
150   Expected<std::unique_ptr<Allocation>>
allocate(const jitlink::JITLinkDylib * JD,const SegmentsRequestMap & Request)151   allocate(const jitlink::JITLinkDylib *JD,
152            const SegmentsRequestMap &Request) override {
153     orcrpctpc::ReserveMemRequest RMR;
154     HostAllocMap HostAllocs;
155 
156     for (auto &KV : Request) {
157       assert(KV.second.getContentSize() <= std::numeric_limits<size_t>::max() &&
158              "Content size is out-of-range for host");
159 
160       RMR.push_back({orcrpctpc::toWireProtectionFlags(
161                          static_cast<sys::Memory::ProtectionFlags>(KV.first)),
162                      KV.second.getContentSize() + KV.second.getZeroFillSize(),
163                      KV.second.getAlignment()});
164       HostAllocs[KV.first] = {
165           std::make_unique<char[]>(KV.second.getContentSize()),
166           KV.second.getContentSize()};
167     }
168 
169     DEBUG_WITH_TYPE("orc", {
170       dbgs() << "Orc remote memmgr got request:\n";
171       for (auto &KV : Request)
172         dbgs() << "  permissions: "
173                << ((KV.first & sys::Memory::MF_READ) ? 'R' : '-')
174                << ((KV.first & sys::Memory::MF_WRITE) ? 'W' : '-')
175                << ((KV.first & sys::Memory::MF_EXEC) ? 'X' : '-')
176                << ", content size: "
177                << formatv("{0:x16}", KV.second.getContentSize())
178                << " + zero-fill-size: "
179                << formatv("{0:x16}", KV.second.getZeroFillSize())
180                << ", align: " << KV.second.getAlignment() << "\n";
181     });
182 
183     // FIXME: LLVM RPC needs to be fixed to support alt
184     // serialization/deserialization on return types. For now just
185     // translate from std::map to DenseMap manually.
186     auto TmpTargetAllocs =
187         Parent.getEndpoint().template callB<orcrpctpc::ReserveMem>(RMR);
188     if (!TmpTargetAllocs)
189       return TmpTargetAllocs.takeError();
190 
191     if (TmpTargetAllocs->size() != RMR.size())
192       return make_error<StringError>(
193           "Number of target allocations does not match request",
194           inconvertibleErrorCode());
195 
196     TargetAllocMap TargetAllocs;
197     for (auto &E : *TmpTargetAllocs)
198       TargetAllocs[orcrpctpc::fromWireProtectionFlags(E.Prot)] = {
199           E.Address, E.AllocatedSize};
200 
201     DEBUG_WITH_TYPE("orc", {
202       auto HAI = HostAllocs.begin();
203       for (auto &KV : TargetAllocs)
204         dbgs() << "  permissions: "
205                << ((KV.first & sys::Memory::MF_READ) ? 'R' : '-')
206                << ((KV.first & sys::Memory::MF_WRITE) ? 'W' : '-')
207                << ((KV.first & sys::Memory::MF_EXEC) ? 'X' : '-')
208                << " assigned local " << (void *)HAI->second.Mem.get()
209                << ", target " << formatv("{0:x16}", KV.second.Address) << "\n";
210     });
211 
212     return std::make_unique<OrcRPCAllocation>(*this, std::move(HostAllocs),
213                                               std::move(TargetAllocs));
214   }
215 
216 private:
reportError(Error Err)217   void reportError(Error Err) { Parent.reportError(std::move(Err)); }
218 
getEndpoint()219   decltype(std::declval<OrcRPCEPCImplT>().getEndpoint()) getEndpoint() {
220     return Parent.getEndpoint();
221   }
222 
223   OrcRPCEPCImplT &Parent;
224 };
225 
226 /// ExecutorProcessControl::MemoryAccess implementation for a process connected
227 /// via an ORC RPC endpoint.
228 template <typename OrcRPCEPCImplT>
229 class OrcRPCEPCMemoryAccess : public ExecutorProcessControl::MemoryAccess {
230 public:
OrcRPCEPCMemoryAccess(OrcRPCEPCImplT & Parent)231   OrcRPCEPCMemoryAccess(OrcRPCEPCImplT &Parent) : Parent(Parent) {}
232 
writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws,WriteResultFn OnWriteComplete)233   void writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws,
234                    WriteResultFn OnWriteComplete) override {
235     writeViaRPC<orcrpctpc::WriteUInt8s>(Ws, std::move(OnWriteComplete));
236   }
237 
writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws,WriteResultFn OnWriteComplete)238   void writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws,
239                     WriteResultFn OnWriteComplete) override {
240     writeViaRPC<orcrpctpc::WriteUInt16s>(Ws, std::move(OnWriteComplete));
241   }
242 
writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws,WriteResultFn OnWriteComplete)243   void writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws,
244                     WriteResultFn OnWriteComplete) override {
245     writeViaRPC<orcrpctpc::WriteUInt32s>(Ws, std::move(OnWriteComplete));
246   }
247 
writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws,WriteResultFn OnWriteComplete)248   void writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws,
249                     WriteResultFn OnWriteComplete) override {
250     writeViaRPC<orcrpctpc::WriteUInt64s>(Ws, std::move(OnWriteComplete));
251   }
252 
writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws,WriteResultFn OnWriteComplete)253   void writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws,
254                     WriteResultFn OnWriteComplete) override {
255     writeViaRPC<orcrpctpc::WriteBuffers>(Ws, std::move(OnWriteComplete));
256   }
257 
258 private:
259   template <typename WriteRPCFunction, typename WriteElementT>
writeViaRPC(ArrayRef<WriteElementT> Ws,WriteResultFn OnWriteComplete)260   void writeViaRPC(ArrayRef<WriteElementT> Ws, WriteResultFn OnWriteComplete) {
261     if (auto Err = Parent.getEndpoint().template callAsync<WriteRPCFunction>(
262             [OWC = std::move(OnWriteComplete)](Error Err2) mutable -> Error {
263               OWC(std::move(Err2));
264               return Error::success();
265             },
266             Ws)) {
267       Parent.reportError(std::move(Err));
268       Parent.getEndpoint().abandonPendingResponses();
269     }
270   }
271 
272   OrcRPCEPCImplT &Parent;
273 };
274 
275 // ExecutorProcessControl for a process connected via an ORC RPC Endpoint.
276 template <typename RPCEndpointT>
277 class OrcRPCExecutorProcessControlBase : public ExecutorProcessControl {
278 public:
279   using ErrorReporter = unique_function<void(Error)>;
280 
281   using OnCloseConnectionFunction = unique_function<Error(Error)>;
282 
OrcRPCExecutorProcessControlBase(std::shared_ptr<SymbolStringPool> SSP,RPCEndpointT & EP,ErrorReporter ReportError)283   OrcRPCExecutorProcessControlBase(std::shared_ptr<SymbolStringPool> SSP,
284                                    RPCEndpointT &EP, ErrorReporter ReportError)
285       : ExecutorProcessControl(std::move(SSP)),
286         ReportError(std::move(ReportError)), EP(EP) {
287     using ThisT = OrcRPCExecutorProcessControlBase<RPCEndpointT>;
288     EP.template addAsyncHandler<orcrpctpc::RunWrapper>(*this,
289                                                        &ThisT::runWrapperInJIT);
290   }
291 
reportError(Error Err)292   void reportError(Error Err) { ReportError(std::move(Err)); }
293 
getEndpoint()294   RPCEndpointT &getEndpoint() { return EP; }
295 
loadDylib(const char * DylibPath)296   Expected<tpctypes::DylibHandle> loadDylib(const char *DylibPath) override {
297     DEBUG_WITH_TYPE("orc", {
298       dbgs() << "Loading dylib \"" << (DylibPath ? DylibPath : "") << "\" ";
299       if (!DylibPath)
300         dbgs() << "(process symbols)";
301       dbgs() << "\n";
302     });
303     if (!DylibPath)
304       DylibPath = "";
305     auto H = EP.template callB<orcrpctpc::LoadDylib>(DylibPath);
306     DEBUG_WITH_TYPE("orc", {
307       if (H)
308         dbgs() << "  got handle " << formatv("{0:x16}", *H) << "\n";
309       else
310         dbgs() << "  error, unable to load\n";
311     });
312     return H;
313   }
314 
315   Expected<std::vector<tpctypes::LookupResult>>
lookupSymbols(ArrayRef<LookupRequest> Request)316   lookupSymbols(ArrayRef<LookupRequest> Request) override {
317     std::vector<orcrpctpc::RemoteLookupRequest> RR;
318     for (auto &E : Request) {
319       RR.push_back({});
320       RR.back().first = E.Handle;
321       for (auto &KV : E.Symbols)
322         RR.back().second.push_back(
323             {(*KV.first).str(),
324              KV.second == SymbolLookupFlags::WeaklyReferencedSymbol});
325     }
326     DEBUG_WITH_TYPE("orc", {
327       dbgs() << "Compound lookup:\n";
328       for (auto &R : Request) {
329         dbgs() << "  In " << formatv("{0:x16}", R.Handle) << ": {";
330         bool First = true;
331         for (auto &KV : R.Symbols) {
332           dbgs() << (First ? "" : ",") << " " << *KV.first;
333           First = false;
334         }
335         dbgs() << " }\n";
336       }
337     });
338     return EP.template callB<orcrpctpc::LookupSymbols>(RR);
339   }
340 
runAsMain(JITTargetAddress MainFnAddr,ArrayRef<std::string> Args)341   Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr,
342                               ArrayRef<std::string> Args) override {
343     DEBUG_WITH_TYPE("orc", {
344       dbgs() << "Running as main: " << formatv("{0:x16}", MainFnAddr)
345              << ", args = [";
346       for (unsigned I = 0; I != Args.size(); ++I)
347         dbgs() << (I ? "," : "") << " \"" << Args[I] << "\"";
348       dbgs() << "]\n";
349     });
350     auto Result = EP.template callB<orcrpctpc::RunMain>(MainFnAddr, Args);
351     DEBUG_WITH_TYPE("orc", {
352       dbgs() << "  call to " << formatv("{0:x16}", MainFnAddr);
353       if (Result)
354         dbgs() << " returned result " << *Result << "\n";
355       else
356         dbgs() << " failed\n";
357     });
358     return Result;
359   }
360 
callWrapperAsync(SendResultFunction OnComplete,JITTargetAddress WrapperFnAddr,ArrayRef<char> ArgBuffer)361   void callWrapperAsync(SendResultFunction OnComplete,
362                         JITTargetAddress WrapperFnAddr,
363                         ArrayRef<char> ArgBuffer) override {
364     DEBUG_WITH_TYPE("orc", {
365       dbgs() << "Running as wrapper function "
366              << formatv("{0:x16}", WrapperFnAddr) << " with "
367              << formatv("{0:x16}", ArgBuffer.size()) << " argument buffer\n";
368     });
369     auto Result = EP.template callB<orcrpctpc::RunWrapper>(
370         WrapperFnAddr,
371         ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(ArgBuffer.data()),
372                           ArgBuffer.size()));
373 
374     if (!Result)
375       OnComplete(shared::WrapperFunctionResult::createOutOfBandError(
376           toString(Result.takeError())));
377     OnComplete(std::move(*Result));
378   }
379 
closeConnection(OnCloseConnectionFunction OnCloseConnection)380   Error closeConnection(OnCloseConnectionFunction OnCloseConnection) {
381     DEBUG_WITH_TYPE("orc", dbgs() << "Closing connection to remote\n");
382     return EP.template callAsync<orcrpctpc::CloseConnection>(
383         std::move(OnCloseConnection));
384   }
385 
closeConnectionAndWait()386   Error closeConnectionAndWait() {
387     std::promise<MSVCPError> P;
388     auto F = P.get_future();
389     if (auto Err = closeConnection([&](Error Err2) -> Error {
390           P.set_value(std::move(Err2));
391           return Error::success();
392         })) {
393       EP.abandonAllPendingResponses();
394       return joinErrors(std::move(Err), F.get());
395     }
396     return F.get();
397   }
398 
399 protected:
400   /// Subclasses must call this during construction to initialize the
401   /// TargetTriple and PageSize members.
initializeORCRPCEPCBase()402   Error initializeORCRPCEPCBase() {
403     if (auto EPI = EP.template callB<orcrpctpc::GetExecutorProcessInfo>()) {
404       this->TargetTriple = Triple(EPI->Triple);
405       this->PageSize = PageSize;
406       this->JDI = {ExecutorAddress(EPI->DispatchFuncAddr),
407                    ExecutorAddress(EPI->DispatchCtxAddr)};
408       return Error::success();
409     } else
410       return EPI.takeError();
411   }
412 
413 private:
runWrapperInJIT(std::function<Error (Expected<shared::WrapperFunctionResult>)> SendResult,JITTargetAddress FunctionTag,std::vector<uint8_t> ArgBuffer)414   Error runWrapperInJIT(
415       std::function<Error(Expected<shared::WrapperFunctionResult>)> SendResult,
416       JITTargetAddress FunctionTag, std::vector<uint8_t> ArgBuffer) {
417 
418     getExecutionSession().runJITDispatchHandler(
419         [this, SendResult = std::move(SendResult)](
420             Expected<shared::WrapperFunctionResult> R) {
421           if (auto Err = SendResult(std::move(R)))
422             ReportError(std::move(Err));
423         },
424         FunctionTag,
425         {reinterpret_cast<const char *>(ArgBuffer.data()), ArgBuffer.size()});
426     return Error::success();
427   }
428 
429   ErrorReporter ReportError;
430   RPCEndpointT &EP;
431 };
432 
433 } // end namespace orc
434 } // end namespace llvm
435 
436 #endif // LLVM_EXECUTIONENGINE_ORC_ORCRPCEXECUTORPROCESSCONTROL_H
437