1 //===---- SimpleRemoteEPCServer.h - EPC over abstract channel ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // EPC over simple abstract channel.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_SIMPLEREMOTEEPCSERVER_H
14 #define LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_SIMPLEREMOTEEPCSERVER_H
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/FunctionExtras.h"
18 #include "llvm/Config/llvm-config.h"
19 #include "llvm/ExecutionEngine/Orc/Shared/SimpleRemoteEPCUtils.h"
20 #include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h"
21 #include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h"
22 #include "llvm/ExecutionEngine/Orc/TargetProcess/ExecutorBootstrapService.h"
23 #include "llvm/ExecutionEngine/Orc/TargetProcess/SimpleExecutorDylibManager.h"
24 #include "llvm/Support/DynamicLibrary.h"
25 #include "llvm/Support/Error.h"
26 
27 #include <condition_variable>
28 #include <future>
29 #include <memory>
30 #include <mutex>
31 
32 namespace llvm {
33 namespace orc {
34 
35 /// A simple EPC server implementation.
36 class SimpleRemoteEPCServer : public SimpleRemoteEPCTransportClient {
37 public:
38   using ReportErrorFunction = unique_function<void(Error)>;
39 
40   /// Dispatches calls to runWrapper.
41   class Dispatcher {
42   public:
43     virtual ~Dispatcher();
44     virtual void dispatch(unique_function<void()> Work) = 0;
45     virtual void shutdown() = 0;
46   };
47 
48 #if LLVM_ENABLE_THREADS
49   class ThreadDispatcher : public Dispatcher {
50   public:
51     void dispatch(unique_function<void()> Work) override;
52     void shutdown() override;
53 
54   private:
55     std::mutex DispatchMutex;
56     bool Running = true;
57     size_t Outstanding = 0;
58     std::condition_variable OutstandingCV;
59   };
60 #endif
61 
62   class Setup {
63     friend class SimpleRemoteEPCServer;
64 
65   public:
66     SimpleRemoteEPCServer &server() { return S; }
67     StringMap<std::vector<char>> &bootstrapMap() { return BootstrapMap; }
68     template <typename T, typename SPSTagT>
69     void setBootstrapMapValue(std::string Key, const T &Value) {
70       std::vector<char> Buffer;
71       Buffer.resize(shared::SPSArgList<SPSTagT>::size(Value));
72       shared::SPSOutputBuffer OB(Buffer.data(), Buffer.size());
73       bool Success = shared::SPSArgList<SPSTagT>::serialize(OB, Value);
74       (void)Success;
75       assert(Success && "Bootstrap map value serialization failed");
76       BootstrapMap[std::move(Key)] = std::move(Buffer);
77     }
78     StringMap<ExecutorAddr> &bootstrapSymbols() { return BootstrapSymbols; }
79     std::vector<std::unique_ptr<ExecutorBootstrapService>> &services() {
80       return Services;
81     }
82     void setDispatcher(std::unique_ptr<Dispatcher> D) { S.D = std::move(D); }
83     void setErrorReporter(unique_function<void(Error)> ReportError) {
84       S.ReportError = std::move(ReportError);
85     }
86 
87   private:
88     Setup(SimpleRemoteEPCServer &S) : S(S) {}
89     SimpleRemoteEPCServer &S;
90     StringMap<std::vector<char>> BootstrapMap;
91     StringMap<ExecutorAddr> BootstrapSymbols;
92     std::vector<std::unique_ptr<ExecutorBootstrapService>> Services;
93   };
94 
95   static StringMap<ExecutorAddr> defaultBootstrapSymbols();
96 
97   template <typename TransportT, typename... TransportTCtorArgTs>
98   static Expected<std::unique_ptr<SimpleRemoteEPCServer>>
99   Create(unique_function<Error(Setup &S)> SetupFunction,
100          TransportTCtorArgTs &&...TransportTCtorArgs) {
101     auto Server = std::make_unique<SimpleRemoteEPCServer>();
102     Setup S(*Server);
103     if (auto Err = SetupFunction(S))
104       return std::move(Err);
105 
106     // Set ReportError up-front so that it can be used if construction
107     // process fails.
108     if (!Server->ReportError)
109       Server->ReportError = [](Error Err) {
110         logAllUnhandledErrors(std::move(Err), errs(), "SimpleRemoteEPCServer ");
111       };
112 
113     // Attempt to create transport.
114     auto T = TransportT::Create(
115         *Server, std::forward<TransportTCtorArgTs>(TransportTCtorArgs)...);
116     if (!T)
117       return T.takeError();
118     Server->T = std::move(*T);
119     if (auto Err = Server->T->start())
120       return std::move(Err);
121 
122     // If transport creation succeeds then start up services.
123     Server->Services = std::move(S.services());
124     Server->Services.push_back(
125         std::make_unique<rt_bootstrap::SimpleExecutorDylibManager>());
126     for (auto &Service : Server->Services)
127       Service->addBootstrapSymbols(S.bootstrapSymbols());
128 
129     if (auto Err = Server->sendSetupMessage(std::move(S.BootstrapMap),
130                                             std::move(S.BootstrapSymbols)))
131       return std::move(Err);
132     return std::move(Server);
133   }
134 
135   /// Set an error reporter for this server.
136   void setErrorReporter(ReportErrorFunction ReportError) {
137     this->ReportError = std::move(ReportError);
138   }
139 
140   /// Call to handle an incoming message.
141   ///
142   /// Returns 'Disconnect' if the message is a 'detach' message from the remote
143   /// otherwise returns 'Continue'. If the server has moved to an error state,
144   /// returns an error, which should be reported and treated as a 'Disconnect'.
145   Expected<HandleMessageAction>
146   handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo, ExecutorAddr TagAddr,
147                 SimpleRemoteEPCArgBytesVector ArgBytes) override;
148 
149   Error waitForDisconnect();
150 
151   void handleDisconnect(Error Err) override;
152 
153 private:
154   Error sendMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
155                     ExecutorAddr TagAddr, ArrayRef<char> ArgBytes);
156 
157   Error sendSetupMessage(StringMap<std::vector<char>> BootstrapMap,
158                          StringMap<ExecutorAddr> BootstrapSymbols);
159 
160   Error handleResult(uint64_t SeqNo, ExecutorAddr TagAddr,
161                      SimpleRemoteEPCArgBytesVector ArgBytes);
162   void handleCallWrapper(uint64_t RemoteSeqNo, ExecutorAddr TagAddr,
163                          SimpleRemoteEPCArgBytesVector ArgBytes);
164 
165   shared::WrapperFunctionResult
166   doJITDispatch(const void *FnTag, const char *ArgData, size_t ArgSize);
167 
168   static shared::CWrapperFunctionResult jitDispatchEntry(void *DispatchCtx,
169                                                          const void *FnTag,
170                                                          const char *ArgData,
171                                                          size_t ArgSize);
172 
173   uint64_t getNextSeqNo() { return NextSeqNo++; }
174   void releaseSeqNo(uint64_t) {}
175 
176   using PendingJITDispatchResultsMap =
177       DenseMap<uint64_t, std::promise<shared::WrapperFunctionResult> *>;
178 
179   std::mutex ServerStateMutex;
180   std::condition_variable ShutdownCV;
181   enum { ServerRunning, ServerShuttingDown, ServerShutDown } RunState;
182   Error ShutdownErr = Error::success();
183   std::unique_ptr<SimpleRemoteEPCTransport> T;
184   std::unique_ptr<Dispatcher> D;
185   std::vector<std::unique_ptr<ExecutorBootstrapService>> Services;
186   ReportErrorFunction ReportError;
187 
188   uint64_t NextSeqNo = 0;
189   PendingJITDispatchResultsMap PendingJITDispatchResults;
190   std::vector<sys::DynamicLibrary> Dylibs;
191 };
192 
193 } // end namespace orc
194 } // end namespace llvm
195 
196 #endif // LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_SIMPLEREMOTEEPCSERVER_H
197