1 //===- WrapperFunctionUtils.h - Utilities for wrapper functions -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // A buffer for serialized results. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_EXECUTIONENGINE_ORC_WRAPPERFUNCTIONUTILS_H 14 #define LLVM_EXECUTIONENGINE_ORC_WRAPPERFUNCTIONUTILS_H 15 16 #include "llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h" 17 #include "llvm/Support/Error.h" 18 19 #include <type_traits> 20 21 namespace llvm { 22 namespace orc { 23 namespace shared { 24 25 namespace detail { 26 27 // DO NOT USE DIRECTLY. 28 // Must be kept in-sync with compiler-rt/lib/orc/c-api.h. 29 union CWrapperFunctionResultDataUnion { 30 char *ValuePtr; 31 char Value[sizeof(ValuePtr)]; 32 }; 33 34 // DO NOT USE DIRECTLY. 35 // Must be kept in-sync with compiler-rt/lib/orc/c-api.h. 36 typedef struct { 37 CWrapperFunctionResultDataUnion Data; 38 size_t Size; 39 } CWrapperFunctionResult; 40 41 } // end namespace detail 42 43 /// C++ wrapper function result: Same as CWrapperFunctionResult but 44 /// auto-releases memory. 45 class WrapperFunctionResult { 46 public: 47 /// Create a default WrapperFunctionResult. 48 WrapperFunctionResult() { init(R); } 49 50 /// Create a WrapperFunctionResult by taking ownership of a 51 /// detail::CWrapperFunctionResult. 52 /// 53 /// Warning: This should only be used by clients writing wrapper-function 54 /// caller utilities (like TargetProcessControl). 55 WrapperFunctionResult(detail::CWrapperFunctionResult R) : R(R) { 56 // Reset R. 57 init(R); 58 } 59 60 WrapperFunctionResult(const WrapperFunctionResult &) = delete; 61 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; 62 63 WrapperFunctionResult(WrapperFunctionResult &&Other) { 64 init(R); 65 std::swap(R, Other.R); 66 } 67 68 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { 69 WrapperFunctionResult Tmp(std::move(Other)); 70 std::swap(R, Tmp.R); 71 return *this; 72 } 73 74 ~WrapperFunctionResult() { 75 if ((R.Size > sizeof(R.Data.Value)) || 76 (R.Size == 0 && R.Data.ValuePtr != nullptr)) 77 free(R.Data.ValuePtr); 78 } 79 80 /// Release ownership of the contained detail::CWrapperFunctionResult. 81 /// Warning: Do not use -- this method will be removed in the future. It only 82 /// exists to temporarily support some code that will eventually be moved to 83 /// the ORC runtime. 84 detail::CWrapperFunctionResult release() { 85 detail::CWrapperFunctionResult Tmp; 86 init(Tmp); 87 std::swap(R, Tmp); 88 return Tmp; 89 } 90 91 /// Get a pointer to the data contained in this instance. 92 const char *data() const { 93 assert((R.Size != 0 || R.Data.ValuePtr == nullptr) && 94 "Cannot get data for out-of-band error value"); 95 return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value; 96 } 97 98 /// Returns the size of the data contained in this instance. 99 size_t size() const { 100 assert((R.Size != 0 || R.Data.ValuePtr == nullptr) && 101 "Cannot get data for out-of-band error value"); 102 return R.Size; 103 } 104 105 /// Returns true if this value is equivalent to a default-constructed 106 /// WrapperFunctionResult. 107 bool empty() const { return R.Size == 0 && R.Data.ValuePtr == nullptr; } 108 109 /// Create a WrapperFunctionResult with the given size and return a pointer 110 /// to the underlying memory. 111 static char *allocate(WrapperFunctionResult &WFR, size_t Size) { 112 // Reset. 113 WFR = WrapperFunctionResult(); 114 WFR.R.Size = Size; 115 char *DataPtr; 116 if (WFR.R.Size > sizeof(WFR.R.Data.Value)) { 117 DataPtr = (char *)malloc(WFR.R.Size); 118 WFR.R.Data.ValuePtr = DataPtr; 119 } else 120 DataPtr = WFR.R.Data.Value; 121 return DataPtr; 122 } 123 124 /// Copy from the given char range. 125 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { 126 WrapperFunctionResult WFR; 127 char *DataPtr = allocate(WFR, Size); 128 memcpy(DataPtr, Source, Size); 129 return WFR; 130 } 131 132 /// Copy from the given null-terminated string (includes the null-terminator). 133 static WrapperFunctionResult copyFrom(const char *Source) { 134 return copyFrom(Source, strlen(Source) + 1); 135 } 136 137 /// Copy from the given std::string (includes the null terminator). 138 static WrapperFunctionResult copyFrom(const std::string &Source) { 139 return copyFrom(Source.c_str()); 140 } 141 142 /// Create an out-of-band error by copying the given string. 143 static WrapperFunctionResult createOutOfBandError(const char *Msg) { 144 // Reset. 145 WrapperFunctionResult WFR; 146 char *Tmp = (char *)malloc(strlen(Msg) + 1); 147 strcpy(Tmp, Msg); 148 WFR.R.Data.ValuePtr = Tmp; 149 return WFR; 150 } 151 152 /// Create an out-of-band error by copying the given string. 153 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { 154 return createOutOfBandError(Msg.c_str()); 155 } 156 157 /// If this value is an out-of-band error then this returns the error message, 158 /// otherwise returns nullptr. 159 const char *getOutOfBandError() const { 160 return R.Size == 0 ? R.Data.ValuePtr : nullptr; 161 } 162 163 private: 164 static void init(detail::CWrapperFunctionResult &R) { 165 R.Data.ValuePtr = nullptr; 166 R.Size = 0; 167 } 168 169 detail::CWrapperFunctionResult R; 170 }; 171 172 namespace detail { 173 174 template <typename SPSArgListT, typename... ArgTs> 175 WrapperFunctionResult 176 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { 177 WrapperFunctionResult Result; 178 char *DataPtr = 179 WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...)); 180 SPSOutputBuffer OB(DataPtr, Result.size()); 181 if (!SPSArgListT::serialize(OB, Args...)) 182 return WrapperFunctionResult::createOutOfBandError( 183 "Error serializing arguments to blob in call"); 184 return Result; 185 } 186 187 template <typename RetT> class WrapperFunctionHandlerCaller { 188 public: 189 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 190 static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, 191 std::index_sequence<I...>) { 192 return std::forward<HandlerT>(H)(std::get<I>(Args)...); 193 } 194 }; 195 196 template <> class WrapperFunctionHandlerCaller<void> { 197 public: 198 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 199 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, 200 std::index_sequence<I...>) { 201 std::forward<HandlerT>(H)(std::get<I>(Args)...); 202 return SPSEmpty(); 203 } 204 }; 205 206 template <typename WrapperFunctionImplT, 207 template <typename> class ResultSerializer, typename... SPSTagTs> 208 class WrapperFunctionHandlerHelper 209 : public WrapperFunctionHandlerHelper< 210 decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), 211 ResultSerializer, SPSTagTs...> {}; 212 213 template <typename RetT, typename... ArgTs, 214 template <typename> class ResultSerializer, typename... SPSTagTs> 215 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 216 SPSTagTs...> { 217 public: 218 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; 219 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; 220 221 template <typename HandlerT> 222 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, 223 size_t ArgSize) { 224 ArgTuple Args; 225 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) 226 return WrapperFunctionResult::createOutOfBandError( 227 "Could not deserialize arguments for wrapper function call"); 228 229 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call( 230 std::forward<HandlerT>(H), Args, ArgIndices{}); 231 232 return ResultSerializer<decltype(HandlerResult)>::serialize( 233 std::move(HandlerResult)); 234 } 235 236 private: 237 template <std::size_t... I> 238 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, 239 std::index_sequence<I...>) { 240 SPSInputBuffer IB(ArgData, ArgSize); 241 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); 242 } 243 }; 244 245 // Map function pointers to function types. 246 template <typename RetT, typename... ArgTs, 247 template <typename> class ResultSerializer, typename... SPSTagTs> 248 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, 249 SPSTagTs...> 250 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 251 SPSTagTs...> {}; 252 253 // Map non-const member function types to function types. 254 template <typename ClassT, typename RetT, typename... ArgTs, 255 template <typename> class ResultSerializer, typename... SPSTagTs> 256 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer, 257 SPSTagTs...> 258 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 259 SPSTagTs...> {}; 260 261 // Map const member function types to function types. 262 template <typename ClassT, typename RetT, typename... ArgTs, 263 template <typename> class ResultSerializer, typename... SPSTagTs> 264 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, 265 ResultSerializer, SPSTagTs...> 266 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 267 SPSTagTs...> {}; 268 269 template <typename WrapperFunctionImplT, 270 template <typename> class ResultSerializer, typename... SPSTagTs> 271 class WrapperFunctionAsyncHandlerHelper 272 : public WrapperFunctionAsyncHandlerHelper< 273 decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), 274 ResultSerializer, SPSTagTs...> {}; 275 276 template <typename RetT, typename SendResultT, typename... ArgTs, 277 template <typename> class ResultSerializer, typename... SPSTagTs> 278 class WrapperFunctionAsyncHandlerHelper<RetT(SendResultT, ArgTs...), 279 ResultSerializer, SPSTagTs...> { 280 public: 281 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; 282 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; 283 284 template <typename HandlerT, typename SendWrapperFunctionResultT> 285 static void applyAsync(HandlerT &&H, 286 SendWrapperFunctionResultT &&SendWrapperFunctionResult, 287 const char *ArgData, size_t ArgSize) { 288 ArgTuple Args; 289 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) { 290 SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError( 291 "Could not deserialize arguments for wrapper function call")); 292 return; 293 } 294 295 auto SendResult = 296 [SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable { 297 using ResultT = decltype(Result); 298 SendWFR(ResultSerializer<ResultT>::serialize(std::move(Result))); 299 }; 300 301 callAsync(std::forward<HandlerT>(H), std::move(SendResult), std::move(Args), 302 ArgIndices{}); 303 } 304 305 private: 306 template <std::size_t... I> 307 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, 308 std::index_sequence<I...>) { 309 SPSInputBuffer IB(ArgData, ArgSize); 310 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); 311 } 312 313 template <typename HandlerT, typename SerializeAndSendResultT, 314 typename ArgTupleT, std::size_t... I> 315 static void callAsync(HandlerT &&H, 316 SerializeAndSendResultT &&SerializeAndSendResult, 317 ArgTupleT Args, std::index_sequence<I...>) { 318 return std::forward<HandlerT>(H)(std::move(SerializeAndSendResult), 319 std::move(std::get<I>(Args))...); 320 } 321 }; 322 323 // Map function pointers to function types. 324 template <typename RetT, typename... ArgTs, 325 template <typename> class ResultSerializer, typename... SPSTagTs> 326 class WrapperFunctionAsyncHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, 327 SPSTagTs...> 328 : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer, 329 SPSTagTs...> {}; 330 331 // Map non-const member function types to function types. 332 template <typename ClassT, typename RetT, typename... ArgTs, 333 template <typename> class ResultSerializer, typename... SPSTagTs> 334 class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...), 335 ResultSerializer, SPSTagTs...> 336 : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer, 337 SPSTagTs...> {}; 338 339 // Map const member function types to function types. 340 template <typename ClassT, typename RetT, typename... ArgTs, 341 template <typename> class ResultSerializer, typename... SPSTagTs> 342 class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...) const, 343 ResultSerializer, SPSTagTs...> 344 : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer, 345 SPSTagTs...> {}; 346 347 template <typename SPSRetTagT, typename RetT> class ResultSerializer { 348 public: 349 static WrapperFunctionResult serialize(RetT Result) { 350 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 351 Result); 352 } 353 }; 354 355 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { 356 public: 357 static WrapperFunctionResult serialize(Error Err) { 358 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 359 toSPSSerializable(std::move(Err))); 360 } 361 }; 362 363 template <typename SPSRetTagT, typename T> 364 class ResultSerializer<SPSRetTagT, Expected<T>> { 365 public: 366 static WrapperFunctionResult serialize(Expected<T> E) { 367 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 368 toSPSSerializable(std::move(E))); 369 } 370 }; 371 372 template <typename SPSRetTagT, typename RetT> class ResultDeserializer { 373 public: 374 static RetT makeValue() { return RetT(); } 375 static void makeSafe(RetT &Result) {} 376 377 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { 378 SPSInputBuffer IB(ArgData, ArgSize); 379 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result)) 380 return make_error<StringError>( 381 "Error deserializing return value from blob in call", 382 inconvertibleErrorCode()); 383 return Error::success(); 384 } 385 }; 386 387 template <> class ResultDeserializer<SPSError, Error> { 388 public: 389 static Error makeValue() { return Error::success(); } 390 static void makeSafe(Error &Err) { cantFail(std::move(Err)); } 391 392 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { 393 SPSInputBuffer IB(ArgData, ArgSize); 394 SPSSerializableError BSE; 395 if (!SPSArgList<SPSError>::deserialize(IB, BSE)) 396 return make_error<StringError>( 397 "Error deserializing return value from blob in call", 398 inconvertibleErrorCode()); 399 Err = fromSPSSerializable(std::move(BSE)); 400 return Error::success(); 401 } 402 }; 403 404 template <typename SPSTagT, typename T> 405 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> { 406 public: 407 static Expected<T> makeValue() { return T(); } 408 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); } 409 410 static Error deserialize(Expected<T> &E, const char *ArgData, 411 size_t ArgSize) { 412 SPSInputBuffer IB(ArgData, ArgSize); 413 SPSSerializableExpected<T> BSE; 414 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE)) 415 return make_error<StringError>( 416 "Error deserializing return value from blob in call", 417 inconvertibleErrorCode()); 418 E = fromSPSSerializable(std::move(BSE)); 419 return Error::success(); 420 } 421 }; 422 423 template <typename SPSRetTagT, typename RetT> class AsyncCallResultHelper { 424 // Did you forget to use Error / Expected in your handler? 425 }; 426 427 } // end namespace detail 428 429 template <typename SPSSignature> class WrapperFunction; 430 431 template <typename SPSRetTagT, typename... SPSTagTs> 432 class WrapperFunction<SPSRetTagT(SPSTagTs...)> { 433 private: 434 template <typename RetT> 435 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>; 436 437 public: 438 /// Call a wrapper function. Caller should be callable as 439 /// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize); 440 template <typename CallerFn, typename RetT, typename... ArgTs> 441 static Error call(const CallerFn &Caller, RetT &Result, 442 const ArgTs &...Args) { 443 444 // RetT might be an Error or Expected value. Set the checked flag now: 445 // we don't want the user to have to check the unused result if this 446 // operation fails. 447 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result); 448 449 auto ArgBuffer = 450 detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>( 451 Args...); 452 if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) 453 return make_error<StringError>(ErrMsg, inconvertibleErrorCode()); 454 455 WrapperFunctionResult ResultBuffer = 456 Caller(ArgBuffer.data(), ArgBuffer.size()); 457 if (auto ErrMsg = ResultBuffer.getOutOfBandError()) 458 return make_error<StringError>(ErrMsg, inconvertibleErrorCode()); 459 460 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( 461 Result, ResultBuffer.data(), ResultBuffer.size()); 462 } 463 464 /// Call an async wrapper function. 465 /// Caller should be callable as 466 /// void Fn(unique_function<void(WrapperFunctionResult)> SendResult, 467 /// WrapperFunctionResult ArgBuffer); 468 template <typename AsyncCallerFn, typename SendDeserializedResultFn, 469 typename... ArgTs> 470 static void callAsync(AsyncCallerFn &&Caller, 471 SendDeserializedResultFn &&SendDeserializedResult, 472 const ArgTs &...Args) { 473 using RetT = typename std::tuple_element< 474 1, typename detail::WrapperFunctionHandlerHelper< 475 std::remove_reference_t<SendDeserializedResultFn>, 476 ResultSerializer, SPSRetTagT>::ArgTuple>::type; 477 478 auto ArgBuffer = 479 detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>( 480 Args...); 481 if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) { 482 SendDeserializedResult( 483 make_error<StringError>(ErrMsg, inconvertibleErrorCode()), 484 detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue()); 485 return; 486 } 487 488 auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)]( 489 WrapperFunctionResult R) { 490 RetT RetVal = detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue(); 491 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(RetVal); 492 493 SPSInputBuffer IB(R.data(), R.size()); 494 if (auto Err = detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( 495 RetVal, R.data(), R.size())) 496 SDR(std::move(Err), std::move(RetVal)); 497 498 SDR(Error::success(), std::move(RetVal)); 499 }; 500 501 Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size()); 502 } 503 504 /// Handle a call to a wrapper function. 505 template <typename HandlerT> 506 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, 507 HandlerT &&Handler) { 508 using WFHH = 509 detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>, 510 ResultSerializer, SPSTagTs...>; 511 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); 512 } 513 514 /// Handle a call to an async wrapper function. 515 template <typename HandlerT, typename SendResultT> 516 static void handleAsync(const char *ArgData, size_t ArgSize, 517 HandlerT &&Handler, SendResultT &&SendResult) { 518 using WFAHH = detail::WrapperFunctionAsyncHandlerHelper< 519 std::remove_reference_t<HandlerT>, ResultSerializer, SPSTagTs...>; 520 WFAHH::applyAsync(std::forward<HandlerT>(Handler), 521 std::forward<SendResultT>(SendResult), ArgData, ArgSize); 522 } 523 524 private: 525 template <typename T> static const T &makeSerializable(const T &Value) { 526 return Value; 527 } 528 529 static detail::SPSSerializableError makeSerializable(Error Err) { 530 return detail::toSPSSerializable(std::move(Err)); 531 } 532 533 template <typename T> 534 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) { 535 return detail::toSPSSerializable(std::move(E)); 536 } 537 }; 538 539 template <typename... SPSTagTs> 540 class WrapperFunction<void(SPSTagTs...)> 541 : private WrapperFunction<SPSEmpty(SPSTagTs...)> { 542 543 public: 544 template <typename CallerFn, typename... ArgTs> 545 static Error call(const CallerFn &Caller, const ArgTs &...Args) { 546 SPSEmpty BE; 547 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(Caller, BE, Args...); 548 } 549 550 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; 551 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handleAsync; 552 }; 553 554 } // end namespace shared 555 } // end namespace orc 556 } // end namespace llvm 557 558 #endif // LLVM_EXECUTIONENGINE_ORC_WRAPPERFUNCTIONUTILS_H 559