1 //===-- OrcRPCExecutorProcessControl.h - Remote target control --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Executor control via ORC RPC. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_EXECUTIONENGINE_ORC_ORCRPCEXECUTORPROCESSCONTROL_H 14 #define LLVM_EXECUTIONENGINE_ORC_ORCRPCEXECUTORPROCESSCONTROL_H 15 16 #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" 17 #include "llvm/ExecutionEngine/Orc/Shared/RPCUtils.h" 18 #include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h" 19 #include "llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h" 20 #include "llvm/Support/MSVCErrorWorkarounds.h" 21 22 namespace llvm { 23 namespace orc { 24 25 /// JITLinkMemoryManager implementation for a process connected via an ORC RPC 26 /// endpoint. 27 template <typename OrcRPCEPCImplT> 28 class OrcRPCEPCJITLinkMemoryManager : public jitlink::JITLinkMemoryManager { 29 private: 30 struct HostAlloc { 31 std::unique_ptr<char[]> Mem; 32 uint64_t Size; 33 }; 34 35 struct TargetAlloc { 36 JITTargetAddress Address = 0; 37 uint64_t AllocatedSize = 0; 38 }; 39 40 using HostAllocMap = DenseMap<int, HostAlloc>; 41 using TargetAllocMap = DenseMap<int, TargetAlloc>; 42 43 public: 44 class OrcRPCAllocation : public Allocation { 45 public: OrcRPCAllocation(OrcRPCEPCJITLinkMemoryManager<OrcRPCEPCImplT> & Parent,HostAllocMap HostAllocs,TargetAllocMap TargetAllocs)46 OrcRPCAllocation(OrcRPCEPCJITLinkMemoryManager<OrcRPCEPCImplT> &Parent, 47 HostAllocMap HostAllocs, TargetAllocMap TargetAllocs) 48 : Parent(Parent), HostAllocs(std::move(HostAllocs)), 49 TargetAllocs(std::move(TargetAllocs)) { 50 assert(HostAllocs.size() == TargetAllocs.size() && 51 "HostAllocs size should match TargetAllocs"); 52 } 53 ~OrcRPCAllocation()54 ~OrcRPCAllocation() override { 55 assert(TargetAllocs.empty() && "failed to deallocate"); 56 } 57 getWorkingMemory(ProtectionFlags Seg)58 MutableArrayRef<char> getWorkingMemory(ProtectionFlags Seg) override { 59 auto I = HostAllocs.find(Seg); 60 assert(I != HostAllocs.end() && "No host allocation for segment"); 61 auto &HA = I->second; 62 return {HA.Mem.get(), static_cast<size_t>(HA.Size)}; 63 } 64 getTargetMemory(ProtectionFlags Seg)65 JITTargetAddress getTargetMemory(ProtectionFlags Seg) override { 66 auto I = TargetAllocs.find(Seg); 67 assert(I != TargetAllocs.end() && "No target allocation for segment"); 68 return I->second.Address; 69 } 70 finalizeAsync(FinalizeContinuation OnFinalize)71 void finalizeAsync(FinalizeContinuation OnFinalize) override { 72 73 std::vector<tpctypes::BufferWrite> BufferWrites; 74 orcrpctpc::ReleaseOrFinalizeMemRequest FMR; 75 76 for (auto &KV : HostAllocs) { 77 assert(TargetAllocs.count(KV.first) && 78 "No target allocation for buffer"); 79 auto &HA = KV.second; 80 auto &TA = TargetAllocs[KV.first]; 81 BufferWrites.push_back({TA.Address, StringRef(HA.Mem.get(), HA.Size)}); 82 FMR.push_back({orcrpctpc::toWireProtectionFlags( 83 static_cast<sys::Memory::ProtectionFlags>(KV.first)), 84 TA.Address, TA.AllocatedSize}); 85 } 86 87 DEBUG_WITH_TYPE("orc", { 88 dbgs() << "finalizeAsync " << (void *)this << ":\n"; 89 auto FMRI = FMR.begin(); 90 for (auto &B : BufferWrites) { 91 auto Prot = FMRI->Prot; 92 ++FMRI; 93 dbgs() << " Writing " << formatv("{0:x16}", B.Buffer.size()) 94 << " bytes to " << ((Prot & orcrpctpc::WPF_Read) ? 'R' : '-') 95 << ((Prot & orcrpctpc::WPF_Write) ? 'W' : '-') 96 << ((Prot & orcrpctpc::WPF_Exec) ? 'X' : '-') 97 << " segment: local " << (const void *)B.Buffer.data() 98 << " -> target " << formatv("{0:x16}", B.Address) << "\n"; 99 } 100 }); 101 if (auto Err = 102 Parent.Parent.getMemoryAccess().writeBuffers(BufferWrites)) { 103 OnFinalize(std::move(Err)); 104 return; 105 } 106 107 DEBUG_WITH_TYPE("orc", dbgs() << " Applying permissions...\n"); 108 if (auto Err = 109 Parent.getEndpoint().template callAsync<orcrpctpc::FinalizeMem>( 110 [OF = std::move(OnFinalize)](Error Err2) { 111 // FIXME: Dispatch to work queue. 112 std::thread([OF = std::move(OF), 113 Err3 = std::move(Err2)]() mutable { 114 DEBUG_WITH_TYPE( 115 "orc", { dbgs() << " finalizeAsync complete\n"; }); 116 OF(std::move(Err3)); 117 }).detach(); 118 return Error::success(); 119 }, 120 FMR)) { 121 DEBUG_WITH_TYPE("orc", dbgs() << " failed.\n"); 122 Parent.getEndpoint().abandonPendingResponses(); 123 Parent.reportError(std::move(Err)); 124 } 125 DEBUG_WITH_TYPE("orc", { 126 dbgs() << "Leaving finalizeAsync (finalization may continue in " 127 "background)\n"; 128 }); 129 } 130 deallocate()131 Error deallocate() override { 132 orcrpctpc::ReleaseOrFinalizeMemRequest RMR; 133 for (auto &KV : TargetAllocs) 134 RMR.push_back({orcrpctpc::toWireProtectionFlags( 135 static_cast<sys::Memory::ProtectionFlags>(KV.first)), 136 KV.second.Address, KV.second.AllocatedSize}); 137 TargetAllocs.clear(); 138 139 return Parent.getEndpoint().template callB<orcrpctpc::ReleaseMem>(RMR); 140 } 141 142 private: 143 OrcRPCEPCJITLinkMemoryManager<OrcRPCEPCImplT> &Parent; 144 HostAllocMap HostAllocs; 145 TargetAllocMap TargetAllocs; 146 }; 147 OrcRPCEPCJITLinkMemoryManager(OrcRPCEPCImplT & Parent)148 OrcRPCEPCJITLinkMemoryManager(OrcRPCEPCImplT &Parent) : Parent(Parent) {} 149 150 Expected<std::unique_ptr<Allocation>> allocate(const jitlink::JITLinkDylib * JD,const SegmentsRequestMap & Request)151 allocate(const jitlink::JITLinkDylib *JD, 152 const SegmentsRequestMap &Request) override { 153 orcrpctpc::ReserveMemRequest RMR; 154 HostAllocMap HostAllocs; 155 156 for (auto &KV : Request) { 157 assert(KV.second.getContentSize() <= std::numeric_limits<size_t>::max() && 158 "Content size is out-of-range for host"); 159 160 RMR.push_back({orcrpctpc::toWireProtectionFlags( 161 static_cast<sys::Memory::ProtectionFlags>(KV.first)), 162 KV.second.getContentSize() + KV.second.getZeroFillSize(), 163 KV.second.getAlignment()}); 164 HostAllocs[KV.first] = { 165 std::make_unique<char[]>(KV.second.getContentSize()), 166 KV.second.getContentSize()}; 167 } 168 169 DEBUG_WITH_TYPE("orc", { 170 dbgs() << "Orc remote memmgr got request:\n"; 171 for (auto &KV : Request) 172 dbgs() << " permissions: " 173 << ((KV.first & sys::Memory::MF_READ) ? 'R' : '-') 174 << ((KV.first & sys::Memory::MF_WRITE) ? 'W' : '-') 175 << ((KV.first & sys::Memory::MF_EXEC) ? 'X' : '-') 176 << ", content size: " 177 << formatv("{0:x16}", KV.second.getContentSize()) 178 << " + zero-fill-size: " 179 << formatv("{0:x16}", KV.second.getZeroFillSize()) 180 << ", align: " << KV.second.getAlignment() << "\n"; 181 }); 182 183 // FIXME: LLVM RPC needs to be fixed to support alt 184 // serialization/deserialization on return types. For now just 185 // translate from std::map to DenseMap manually. 186 auto TmpTargetAllocs = 187 Parent.getEndpoint().template callB<orcrpctpc::ReserveMem>(RMR); 188 if (!TmpTargetAllocs) 189 return TmpTargetAllocs.takeError(); 190 191 if (TmpTargetAllocs->size() != RMR.size()) 192 return make_error<StringError>( 193 "Number of target allocations does not match request", 194 inconvertibleErrorCode()); 195 196 TargetAllocMap TargetAllocs; 197 for (auto &E : *TmpTargetAllocs) 198 TargetAllocs[orcrpctpc::fromWireProtectionFlags(E.Prot)] = { 199 E.Address, E.AllocatedSize}; 200 201 DEBUG_WITH_TYPE("orc", { 202 auto HAI = HostAllocs.begin(); 203 for (auto &KV : TargetAllocs) 204 dbgs() << " permissions: " 205 << ((KV.first & sys::Memory::MF_READ) ? 'R' : '-') 206 << ((KV.first & sys::Memory::MF_WRITE) ? 'W' : '-') 207 << ((KV.first & sys::Memory::MF_EXEC) ? 'X' : '-') 208 << " assigned local " << (void *)HAI->second.Mem.get() 209 << ", target " << formatv("{0:x16}", KV.second.Address) << "\n"; 210 }); 211 212 return std::make_unique<OrcRPCAllocation>(*this, std::move(HostAllocs), 213 std::move(TargetAllocs)); 214 } 215 216 private: reportError(Error Err)217 void reportError(Error Err) { Parent.reportError(std::move(Err)); } 218 getEndpoint()219 decltype(std::declval<OrcRPCEPCImplT>().getEndpoint()) getEndpoint() { 220 return Parent.getEndpoint(); 221 } 222 223 OrcRPCEPCImplT &Parent; 224 }; 225 226 /// ExecutorProcessControl::MemoryAccess implementation for a process connected 227 /// via an ORC RPC endpoint. 228 template <typename OrcRPCEPCImplT> 229 class OrcRPCEPCMemoryAccess : public ExecutorProcessControl::MemoryAccess { 230 public: OrcRPCEPCMemoryAccess(OrcRPCEPCImplT & Parent)231 OrcRPCEPCMemoryAccess(OrcRPCEPCImplT &Parent) : Parent(Parent) {} 232 writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws,WriteResultFn OnWriteComplete)233 void writeUInt8s(ArrayRef<tpctypes::UInt8Write> Ws, 234 WriteResultFn OnWriteComplete) override { 235 writeViaRPC<orcrpctpc::WriteUInt8s>(Ws, std::move(OnWriteComplete)); 236 } 237 writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws,WriteResultFn OnWriteComplete)238 void writeUInt16s(ArrayRef<tpctypes::UInt16Write> Ws, 239 WriteResultFn OnWriteComplete) override { 240 writeViaRPC<orcrpctpc::WriteUInt16s>(Ws, std::move(OnWriteComplete)); 241 } 242 writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws,WriteResultFn OnWriteComplete)243 void writeUInt32s(ArrayRef<tpctypes::UInt32Write> Ws, 244 WriteResultFn OnWriteComplete) override { 245 writeViaRPC<orcrpctpc::WriteUInt32s>(Ws, std::move(OnWriteComplete)); 246 } 247 writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws,WriteResultFn OnWriteComplete)248 void writeUInt64s(ArrayRef<tpctypes::UInt64Write> Ws, 249 WriteResultFn OnWriteComplete) override { 250 writeViaRPC<orcrpctpc::WriteUInt64s>(Ws, std::move(OnWriteComplete)); 251 } 252 writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws,WriteResultFn OnWriteComplete)253 void writeBuffers(ArrayRef<tpctypes::BufferWrite> Ws, 254 WriteResultFn OnWriteComplete) override { 255 writeViaRPC<orcrpctpc::WriteBuffers>(Ws, std::move(OnWriteComplete)); 256 } 257 258 private: 259 template <typename WriteRPCFunction, typename WriteElementT> writeViaRPC(ArrayRef<WriteElementT> Ws,WriteResultFn OnWriteComplete)260 void writeViaRPC(ArrayRef<WriteElementT> Ws, WriteResultFn OnWriteComplete) { 261 if (auto Err = Parent.getEndpoint().template callAsync<WriteRPCFunction>( 262 [OWC = std::move(OnWriteComplete)](Error Err2) mutable -> Error { 263 OWC(std::move(Err2)); 264 return Error::success(); 265 }, 266 Ws)) { 267 Parent.reportError(std::move(Err)); 268 Parent.getEndpoint().abandonPendingResponses(); 269 } 270 } 271 272 OrcRPCEPCImplT &Parent; 273 }; 274 275 // ExecutorProcessControl for a process connected via an ORC RPC Endpoint. 276 template <typename RPCEndpointT> 277 class OrcRPCExecutorProcessControlBase : public ExecutorProcessControl { 278 public: 279 using ErrorReporter = unique_function<void(Error)>; 280 281 using OnCloseConnectionFunction = unique_function<Error(Error)>; 282 OrcRPCExecutorProcessControlBase(std::shared_ptr<SymbolStringPool> SSP,RPCEndpointT & EP,ErrorReporter ReportError)283 OrcRPCExecutorProcessControlBase(std::shared_ptr<SymbolStringPool> SSP, 284 RPCEndpointT &EP, ErrorReporter ReportError) 285 : ExecutorProcessControl(std::move(SSP)), 286 ReportError(std::move(ReportError)), EP(EP) { 287 using ThisT = OrcRPCExecutorProcessControlBase<RPCEndpointT>; 288 EP.template addAsyncHandler<orcrpctpc::RunWrapper>(*this, 289 &ThisT::runWrapperInJIT); 290 } 291 reportError(Error Err)292 void reportError(Error Err) { ReportError(std::move(Err)); } 293 getEndpoint()294 RPCEndpointT &getEndpoint() { return EP; } 295 loadDylib(const char * DylibPath)296 Expected<tpctypes::DylibHandle> loadDylib(const char *DylibPath) override { 297 DEBUG_WITH_TYPE("orc", { 298 dbgs() << "Loading dylib \"" << (DylibPath ? DylibPath : "") << "\" "; 299 if (!DylibPath) 300 dbgs() << "(process symbols)"; 301 dbgs() << "\n"; 302 }); 303 if (!DylibPath) 304 DylibPath = ""; 305 auto H = EP.template callB<orcrpctpc::LoadDylib>(DylibPath); 306 DEBUG_WITH_TYPE("orc", { 307 if (H) 308 dbgs() << " got handle " << formatv("{0:x16}", *H) << "\n"; 309 else 310 dbgs() << " error, unable to load\n"; 311 }); 312 return H; 313 } 314 315 Expected<std::vector<tpctypes::LookupResult>> lookupSymbols(ArrayRef<LookupRequest> Request)316 lookupSymbols(ArrayRef<LookupRequest> Request) override { 317 std::vector<orcrpctpc::RemoteLookupRequest> RR; 318 for (auto &E : Request) { 319 RR.push_back({}); 320 RR.back().first = E.Handle; 321 for (auto &KV : E.Symbols) 322 RR.back().second.push_back( 323 {(*KV.first).str(), 324 KV.second == SymbolLookupFlags::WeaklyReferencedSymbol}); 325 } 326 DEBUG_WITH_TYPE("orc", { 327 dbgs() << "Compound lookup:\n"; 328 for (auto &R : Request) { 329 dbgs() << " In " << formatv("{0:x16}", R.Handle) << ": {"; 330 bool First = true; 331 for (auto &KV : R.Symbols) { 332 dbgs() << (First ? "" : ",") << " " << *KV.first; 333 First = false; 334 } 335 dbgs() << " }\n"; 336 } 337 }); 338 return EP.template callB<orcrpctpc::LookupSymbols>(RR); 339 } 340 runAsMain(JITTargetAddress MainFnAddr,ArrayRef<std::string> Args)341 Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr, 342 ArrayRef<std::string> Args) override { 343 DEBUG_WITH_TYPE("orc", { 344 dbgs() << "Running as main: " << formatv("{0:x16}", MainFnAddr) 345 << ", args = ["; 346 for (unsigned I = 0; I != Args.size(); ++I) 347 dbgs() << (I ? "," : "") << " \"" << Args[I] << "\""; 348 dbgs() << "]\n"; 349 }); 350 auto Result = EP.template callB<orcrpctpc::RunMain>(MainFnAddr, Args); 351 DEBUG_WITH_TYPE("orc", { 352 dbgs() << " call to " << formatv("{0:x16}", MainFnAddr); 353 if (Result) 354 dbgs() << " returned result " << *Result << "\n"; 355 else 356 dbgs() << " failed\n"; 357 }); 358 return Result; 359 } 360 callWrapperAsync(SendResultFunction OnComplete,JITTargetAddress WrapperFnAddr,ArrayRef<char> ArgBuffer)361 void callWrapperAsync(SendResultFunction OnComplete, 362 JITTargetAddress WrapperFnAddr, 363 ArrayRef<char> ArgBuffer) override { 364 DEBUG_WITH_TYPE("orc", { 365 dbgs() << "Running as wrapper function " 366 << formatv("{0:x16}", WrapperFnAddr) << " with " 367 << formatv("{0:x16}", ArgBuffer.size()) << " argument buffer\n"; 368 }); 369 auto Result = EP.template callB<orcrpctpc::RunWrapper>( 370 WrapperFnAddr, 371 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(ArgBuffer.data()), 372 ArgBuffer.size())); 373 374 if (!Result) 375 OnComplete(shared::WrapperFunctionResult::createOutOfBandError( 376 toString(Result.takeError()))); 377 OnComplete(std::move(*Result)); 378 } 379 closeConnection(OnCloseConnectionFunction OnCloseConnection)380 Error closeConnection(OnCloseConnectionFunction OnCloseConnection) { 381 DEBUG_WITH_TYPE("orc", dbgs() << "Closing connection to remote\n"); 382 return EP.template callAsync<orcrpctpc::CloseConnection>( 383 std::move(OnCloseConnection)); 384 } 385 closeConnectionAndWait()386 Error closeConnectionAndWait() { 387 std::promise<MSVCPError> P; 388 auto F = P.get_future(); 389 if (auto Err = closeConnection([&](Error Err2) -> Error { 390 P.set_value(std::move(Err2)); 391 return Error::success(); 392 })) { 393 EP.abandonAllPendingResponses(); 394 return joinErrors(std::move(Err), F.get()); 395 } 396 return F.get(); 397 } 398 399 protected: 400 /// Subclasses must call this during construction to initialize the 401 /// TargetTriple and PageSize members. initializeORCRPCEPCBase()402 Error initializeORCRPCEPCBase() { 403 if (auto EPI = EP.template callB<orcrpctpc::GetExecutorProcessInfo>()) { 404 this->TargetTriple = Triple(EPI->Triple); 405 this->PageSize = PageSize; 406 this->JDI = {ExecutorAddress(EPI->DispatchFuncAddr), 407 ExecutorAddress(EPI->DispatchCtxAddr)}; 408 return Error::success(); 409 } else 410 return EPI.takeError(); 411 } 412 413 private: runWrapperInJIT(std::function<Error (Expected<shared::WrapperFunctionResult>)> SendResult,JITTargetAddress FunctionTag,std::vector<uint8_t> ArgBuffer)414 Error runWrapperInJIT( 415 std::function<Error(Expected<shared::WrapperFunctionResult>)> SendResult, 416 JITTargetAddress FunctionTag, std::vector<uint8_t> ArgBuffer) { 417 418 getExecutionSession().runJITDispatchHandler( 419 [this, SendResult = std::move(SendResult)]( 420 Expected<shared::WrapperFunctionResult> R) { 421 if (auto Err = SendResult(std::move(R))) 422 ReportError(std::move(Err)); 423 }, 424 FunctionTag, 425 {reinterpret_cast<const char *>(ArgBuffer.data()), ArgBuffer.size()}); 426 return Error::success(); 427 } 428 429 ErrorReporter ReportError; 430 RPCEndpointT &EP; 431 }; 432 433 } // end namespace orc 434 } // end namespace llvm 435 436 #endif // LLVM_EXECUTIONENGINE_ORC_ORCRPCEXECUTORPROCESSCONTROL_H 437