1 //===- WrapperFunctionUtils.h - Utilities for wrapper functions -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Utilities for calling and handling SPS-serialized wrapper functions,
// including a buffer type for serialized results.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H
14 #define LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H
15
16 #include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h"
17 #include "llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h"
18 #include "llvm/Support/Error.h"
19
20 #include <type_traits>
21
22 namespace llvm {
23 namespace orc {
24 namespace shared {
25
// Payload storage for a wrapper function result: either a pointer to
// out-of-line bytes, or a pointer-sized inline byte buffer.
// Must be kept in-sync with compiler-rt/lib/orc/c-api.h.
union CWrapperFunctionResultDataUnion {
  char *ValuePtr;
  char Value[sizeof(char *)];
};
31
// C-compatible wrapper function result: a payload union plus a byte count.
// WrapperFunctionResult (below) interprets Size == 0 together with a non-null
// Data.ValuePtr as an out-of-band error message.
// Must be kept in-sync with compiler-rt/lib/orc/c-api.h.
typedef struct {
  CWrapperFunctionResultDataUnion Data; // Inline bytes or out-of-line pointer.
  size_t Size;                          // Number of payload bytes.
} CWrapperFunctionResult;
37
38 /// C++ wrapper function result: Same as CWrapperFunctionResult but
39 /// auto-releases memory.
40 class WrapperFunctionResult {
41 public:
42 /// Create a default WrapperFunctionResult.
WrapperFunctionResult()43 WrapperFunctionResult() { init(R); }
44
45 /// Create a WrapperFunctionResult by taking ownership of a
46 /// CWrapperFunctionResult.
47 ///
48 /// Warning: This should only be used by clients writing wrapper-function
49 /// caller utilities (like TargetProcessControl).
WrapperFunctionResult(CWrapperFunctionResult R)50 WrapperFunctionResult(CWrapperFunctionResult R) : R(R) {
51 // Reset R.
52 init(R);
53 }
54
55 WrapperFunctionResult(const WrapperFunctionResult &) = delete;
56 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
57
WrapperFunctionResult(WrapperFunctionResult && Other)58 WrapperFunctionResult(WrapperFunctionResult &&Other) {
59 init(R);
60 std::swap(R, Other.R);
61 }
62
63 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
64 WrapperFunctionResult Tmp(std::move(Other));
65 std::swap(R, Tmp.R);
66 return *this;
67 }
68
~WrapperFunctionResult()69 ~WrapperFunctionResult() {
70 if ((R.Size > sizeof(R.Data.Value)) ||
71 (R.Size == 0 && R.Data.ValuePtr != nullptr))
72 free(R.Data.ValuePtr);
73 }
74
75 /// Release ownership of the contained CWrapperFunctionResult.
76 /// Warning: Do not use -- this method will be removed in the future. It only
77 /// exists to temporarily support some code that will eventually be moved to
78 /// the ORC runtime.
release()79 CWrapperFunctionResult release() {
80 CWrapperFunctionResult Tmp;
81 init(Tmp);
82 std::swap(R, Tmp);
83 return Tmp;
84 }
85
86 /// Get a pointer to the data contained in this instance.
data()87 char *data() {
88 assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
89 "Cannot get data for out-of-band error value");
90 return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value;
91 }
92
93 /// Get a const pointer to the data contained in this instance.
data()94 const char *data() const {
95 assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
96 "Cannot get data for out-of-band error value");
97 return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value;
98 }
99
100 /// Returns the size of the data contained in this instance.
size()101 size_t size() const {
102 assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
103 "Cannot get data for out-of-band error value");
104 return R.Size;
105 }
106
107 /// Returns true if this value is equivalent to a default-constructed
108 /// WrapperFunctionResult.
empty()109 bool empty() const { return R.Size == 0 && R.Data.ValuePtr == nullptr; }
110
111 /// Create a WrapperFunctionResult with the given size and return a pointer
112 /// to the underlying memory.
allocate(size_t Size)113 static WrapperFunctionResult allocate(size_t Size) {
114 // Reset.
115 WrapperFunctionResult WFR;
116 WFR.R.Size = Size;
117 if (WFR.R.Size > sizeof(WFR.R.Data.Value))
118 WFR.R.Data.ValuePtr = (char *)malloc(WFR.R.Size);
119 return WFR;
120 }
121
122 /// Copy from the given char range.
copyFrom(const char * Source,size_t Size)123 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
124 auto WFR = allocate(Size);
125 memcpy(WFR.data(), Source, Size);
126 return WFR;
127 }
128
129 /// Copy from the given null-terminated string (includes the null-terminator).
copyFrom(const char * Source)130 static WrapperFunctionResult copyFrom(const char *Source) {
131 return copyFrom(Source, strlen(Source) + 1);
132 }
133
134 /// Copy from the given std::string (includes the null terminator).
copyFrom(const std::string & Source)135 static WrapperFunctionResult copyFrom(const std::string &Source) {
136 return copyFrom(Source.c_str());
137 }
138
139 /// Create an out-of-band error by copying the given string.
createOutOfBandError(const char * Msg)140 static WrapperFunctionResult createOutOfBandError(const char *Msg) {
141 // Reset.
142 WrapperFunctionResult WFR;
143 char *Tmp = (char *)malloc(strlen(Msg) + 1);
144 strcpy(Tmp, Msg);
145 WFR.R.Data.ValuePtr = Tmp;
146 return WFR;
147 }
148
149 /// Create an out-of-band error by copying the given string.
createOutOfBandError(const std::string & Msg)150 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
151 return createOutOfBandError(Msg.c_str());
152 }
153
154 /// If this value is an out-of-band error then this returns the error message,
155 /// otherwise returns nullptr.
getOutOfBandError()156 const char *getOutOfBandError() const {
157 return R.Size == 0 ? R.Data.ValuePtr : nullptr;
158 }
159
160 private:
init(CWrapperFunctionResult & R)161 static void init(CWrapperFunctionResult &R) {
162 R.Data.ValuePtr = nullptr;
163 R.Size = 0;
164 }
165
166 CWrapperFunctionResult R;
167 };
168
169 namespace detail {
170
171 template <typename SPSArgListT, typename... ArgTs>
172 WrapperFunctionResult
serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args)173 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
174 auto Result = WrapperFunctionResult::allocate(SPSArgListT::size(Args...));
175 SPSOutputBuffer OB(Result.data(), Result.size());
176 if (!SPSArgListT::serialize(OB, Args...))
177 return WrapperFunctionResult::createOutOfBandError(
178 "Error serializing arguments to blob in call");
179 return Result;
180 }
181
/// Invokes a handler with the elements of an argument tuple and hands back
/// whatever the handler returns.
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  /// Calls H with the tuple elements selected by Is..., in order. The
  /// elements are passed as lvalues of the tuple.
  template <typename HandlerT, typename ArgTupleT, std::size_t... Is>
  static decltype(auto) call(HandlerT &&H, ArgTupleT &Args,
                             std::index_sequence<Is...> /*Indices*/) {
    return std::forward<HandlerT>(H)(std::get<Is>(Args)...);
  }
};
190
191 template <> class WrapperFunctionHandlerCaller<void> {
192 public:
193 template <typename HandlerT, typename ArgTupleT, std::size_t... I>
call(HandlerT && H,ArgTupleT & Args,std::index_sequence<I...>)194 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
195 std::index_sequence<I...>) {
196 std::forward<HandlerT>(H)(std::get<I>(Args)...);
197 return SPSEmpty();
198 }
199 };
200
/// Primary template, selected for callable objects with a (non-overloaded)
/// operator(): derives from the helper instantiated on the type of
/// &operator(), which the specializations below reduce to a plain function
/// type.
template <typename WrapperFunctionImplT,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()),
          ResultSerializer, SPSTagTs...> {};
207
208 template <typename RetT, typename... ArgTs,
209 template <typename> class ResultSerializer, typename... SPSTagTs>
210 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
211 SPSTagTs...> {
212 public:
213 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
214 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
215
216 template <typename HandlerT>
apply(HandlerT && H,const char * ArgData,size_t ArgSize)217 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
218 size_t ArgSize) {
219 ArgTuple Args;
220 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
221 return WrapperFunctionResult::createOutOfBandError(
222 "Could not deserialize arguments for wrapper function call");
223
224 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
225 std::forward<HandlerT>(H), Args, ArgIndices{});
226
227 return ResultSerializer<decltype(HandlerResult)>::serialize(
228 std::move(HandlerResult));
229 }
230
231 private:
232 template <std::size_t... I>
deserialize(const char * ArgData,size_t ArgSize,ArgTuple & Args,std::index_sequence<I...>)233 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
234 std::index_sequence<I...>) {
235 SPSInputBuffer IB(ArgData, ArgSize);
236 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
237 }
238 };
239
// Map function pointers to function types: RetT (*)(ArgTs...) is handled by
// the RetT(ArgTs...) specialization.
template <typename RetT, typename... ArgTs,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
                                   SPSTagTs...>
    : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
                                          SPSTagTs...> {};
247
// Map non-const member function types to function types: the class type is
// dropped and RetT(ArgTs...) is used instead.
template <typename ClassT, typename RetT, typename... ArgTs,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
                                   SPSTagTs...>
    : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
                                          SPSTagTs...> {};
255
// Map const member function types to function types: the class type and the
// const qualifier are dropped and RetT(ArgTs...) is used instead.
template <typename ClassT, typename RetT, typename... ArgTs,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
                                   ResultSerializer, SPSTagTs...>
    : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
                                          SPSTagTs...> {};
263
/// Primary template for asynchronous handlers, selected for callable objects
/// with a (non-overloaded) operator(): derives from the helper instantiated on
/// the type of &operator(), which the specializations below reduce to a plain
/// function type.
template <typename WrapperFunctionImplT,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper
    : public WrapperFunctionAsyncHandlerHelper<
          decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()),
          ResultSerializer, SPSTagTs...> {};
270
271 template <typename RetT, typename SendResultT, typename... ArgTs,
272 template <typename> class ResultSerializer, typename... SPSTagTs>
273 class WrapperFunctionAsyncHandlerHelper<RetT(SendResultT, ArgTs...),
274 ResultSerializer, SPSTagTs...> {
275 public:
276 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
277 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
278
279 template <typename HandlerT, typename SendWrapperFunctionResultT>
applyAsync(HandlerT && H,SendWrapperFunctionResultT && SendWrapperFunctionResult,const char * ArgData,size_t ArgSize)280 static void applyAsync(HandlerT &&H,
281 SendWrapperFunctionResultT &&SendWrapperFunctionResult,
282 const char *ArgData, size_t ArgSize) {
283 ArgTuple Args;
284 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) {
285 SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError(
286 "Could not deserialize arguments for wrapper function call"));
287 return;
288 }
289
290 auto SendResult =
291 [SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable {
292 using ResultT = decltype(Result);
293 SendWFR(ResultSerializer<ResultT>::serialize(std::move(Result)));
294 };
295
296 callAsync(std::forward<HandlerT>(H), std::move(SendResult), std::move(Args),
297 ArgIndices{});
298 }
299
300 private:
301 template <std::size_t... I>
deserialize(const char * ArgData,size_t ArgSize,ArgTuple & Args,std::index_sequence<I...>)302 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
303 std::index_sequence<I...>) {
304 SPSInputBuffer IB(ArgData, ArgSize);
305 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
306 }
307
308 template <typename HandlerT, typename SerializeAndSendResultT,
309 typename ArgTupleT, std::size_t... I>
callAsync(HandlerT && H,SerializeAndSendResultT && SerializeAndSendResult,ArgTupleT Args,std::index_sequence<I...>)310 static void callAsync(HandlerT &&H,
311 SerializeAndSendResultT &&SerializeAndSendResult,
312 ArgTupleT Args, std::index_sequence<I...>) {
313 (void)Args; // Silence a buggy GCC warning.
314 return std::forward<HandlerT>(H)(std::move(SerializeAndSendResult),
315 std::move(std::get<I>(Args))...);
316 }
317 };
318
// Map function pointers to function types: RetT (*)(ArgTs...) is handled by
// the function-type specialization.
template <typename RetT, typename... ArgTs,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
                                        SPSTagTs...>
    : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
                                               SPSTagTs...> {};
326
// Map non-const member function types to function types: the class type is
// dropped and RetT(ArgTs...) is used instead.
template <typename ClassT, typename RetT, typename... ArgTs,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...),
                                        ResultSerializer, SPSTagTs...>
    : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
                                               SPSTagTs...> {};
334
// Map const member function types to function types: the class type and the
// const qualifier are dropped and RetT(ArgTs...) is used instead.
template <typename ClassT, typename RetT, typename... ArgTs,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
                                        ResultSerializer, SPSTagTs...>
    : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
                                               SPSTagTs...> {};
342
343 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
344 public:
serialize(RetT Result)345 static WrapperFunctionResult serialize(RetT Result) {
346 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
347 Result);
348 }
349 };
350
351 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
352 public:
serialize(Error Err)353 static WrapperFunctionResult serialize(Error Err) {
354 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
355 toSPSSerializable(std::move(Err)));
356 }
357 };
358
359 template <typename SPSRetTagT>
360 class ResultSerializer<SPSRetTagT, ErrorSuccess> {
361 public:
serialize(ErrorSuccess Err)362 static WrapperFunctionResult serialize(ErrorSuccess Err) {
363 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
364 toSPSSerializable(std::move(Err)));
365 }
366 };
367
368 template <typename SPSRetTagT, typename T>
369 class ResultSerializer<SPSRetTagT, Expected<T>> {
370 public:
serialize(Expected<T> E)371 static WrapperFunctionResult serialize(Expected<T> E) {
372 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
373 toSPSSerializable(std::move(E)));
374 }
375 };
376
377 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
378 public:
makeValue()379 static RetT makeValue() { return RetT(); }
makeSafe(RetT & Result)380 static void makeSafe(RetT &Result) {}
381
deserialize(RetT & Result,const char * ArgData,size_t ArgSize)382 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
383 SPSInputBuffer IB(ArgData, ArgSize);
384 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
385 return make_error<StringError>(
386 "Error deserializing return value from blob in call",
387 inconvertibleErrorCode());
388 return Error::success();
389 }
390 };
391
392 template <> class ResultDeserializer<SPSError, Error> {
393 public:
makeValue()394 static Error makeValue() { return Error::success(); }
makeSafe(Error & Err)395 static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
396
deserialize(Error & Err,const char * ArgData,size_t ArgSize)397 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
398 SPSInputBuffer IB(ArgData, ArgSize);
399 SPSSerializableError BSE;
400 if (!SPSArgList<SPSError>::deserialize(IB, BSE))
401 return make_error<StringError>(
402 "Error deserializing return value from blob in call",
403 inconvertibleErrorCode());
404 Err = fromSPSSerializable(std::move(BSE));
405 return Error::success();
406 }
407 };
408
409 template <typename SPSTagT, typename T>
410 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
411 public:
makeValue()412 static Expected<T> makeValue() { return T(); }
makeSafe(Expected<T> & E)413 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
414
deserialize(Expected<T> & E,const char * ArgData,size_t ArgSize)415 static Error deserialize(Expected<T> &E, const char *ArgData,
416 size_t ArgSize) {
417 SPSInputBuffer IB(ArgData, ArgSize);
418 SPSSerializableExpected<T> BSE;
419 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
420 return make_error<StringError>(
421 "Error deserializing return value from blob in call",
422 inconvertibleErrorCode());
423 E = fromSPSSerializable(std::move(BSE));
424 return Error::success();
425 }
426 };
427
/// Intentionally empty primary template.
/// NOTE(review): presumably the comment inside is meant to be seen when a use
/// of this primary template fails to compile -- confirm against users.
template <typename SPSRetTagT, typename RetT> class AsyncCallResultHelper {
  // Did you forget to use Error / Expected in your handler?
};
431
432 } // end namespace detail
433
/// Primary template: declared only. The API is provided by the
/// specializations for function-type signatures, SPSRetTagT(SPSTagTs...) and
/// void(SPSTagTs...).
template <typename SPSSignature> class WrapperFunction;
435
436 template <typename SPSRetTagT, typename... SPSTagTs>
437 class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
438 private:
439 template <typename RetT>
440 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
441
442 public:
443 /// Call a wrapper function. Caller should be callable as
444 /// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize);
445 template <typename CallerFn, typename RetT, typename... ArgTs>
call(const CallerFn & Caller,RetT & Result,const ArgTs &...Args)446 static Error call(const CallerFn &Caller, RetT &Result,
447 const ArgTs &...Args) {
448
449 // RetT might be an Error or Expected value. Set the checked flag now:
450 // we don't want the user to have to check the unused result if this
451 // operation fails.
452 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
453
454 auto ArgBuffer =
455 detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
456 Args...);
457 if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
458 return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
459
460 WrapperFunctionResult ResultBuffer =
461 Caller(ArgBuffer.data(), ArgBuffer.size());
462 if (auto ErrMsg = ResultBuffer.getOutOfBandError())
463 return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
464
465 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
466 Result, ResultBuffer.data(), ResultBuffer.size());
467 }
468
469 /// Call an async wrapper function.
470 /// Caller should be callable as
471 /// void Fn(unique_function<void(WrapperFunctionResult)> SendResult,
472 /// WrapperFunctionResult ArgBuffer);
473 template <typename AsyncCallerFn, typename SendDeserializedResultFn,
474 typename... ArgTs>
callAsync(AsyncCallerFn && Caller,SendDeserializedResultFn && SendDeserializedResult,const ArgTs &...Args)475 static void callAsync(AsyncCallerFn &&Caller,
476 SendDeserializedResultFn &&SendDeserializedResult,
477 const ArgTs &...Args) {
478 using RetT = typename std::tuple_element<
479 1, typename detail::WrapperFunctionHandlerHelper<
480 std::remove_reference_t<SendDeserializedResultFn>,
481 ResultSerializer, SPSRetTagT>::ArgTuple>::type;
482
483 auto ArgBuffer =
484 detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
485 Args...);
486 if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) {
487 SendDeserializedResult(
488 make_error<StringError>(ErrMsg, inconvertibleErrorCode()),
489 detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue());
490 return;
491 }
492
493 auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)](
494 WrapperFunctionResult R) mutable {
495 RetT RetVal = detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue();
496 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(RetVal);
497
498 if (auto *ErrMsg = R.getOutOfBandError()) {
499 SDR(make_error<StringError>(ErrMsg, inconvertibleErrorCode()),
500 std::move(RetVal));
501 return;
502 }
503
504 SPSInputBuffer IB(R.data(), R.size());
505 if (auto Err = detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
506 RetVal, R.data(), R.size()))
507 SDR(std::move(Err), std::move(RetVal));
508
509 SDR(Error::success(), std::move(RetVal));
510 };
511
512 Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size());
513 }
514
515 /// Handle a call to a wrapper function.
516 template <typename HandlerT>
handle(const char * ArgData,size_t ArgSize,HandlerT && Handler)517 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
518 HandlerT &&Handler) {
519 using WFHH =
520 detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
521 ResultSerializer, SPSTagTs...>;
522 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
523 }
524
525 /// Handle a call to an async wrapper function.
526 template <typename HandlerT, typename SendResultT>
handleAsync(const char * ArgData,size_t ArgSize,HandlerT && Handler,SendResultT && SendResult)527 static void handleAsync(const char *ArgData, size_t ArgSize,
528 HandlerT &&Handler, SendResultT &&SendResult) {
529 using WFAHH = detail::WrapperFunctionAsyncHandlerHelper<
530 std::remove_reference_t<HandlerT>, ResultSerializer, SPSTagTs...>;
531 WFAHH::applyAsync(std::forward<HandlerT>(Handler),
532 std::forward<SendResultT>(SendResult), ArgData, ArgSize);
533 }
534
535 private:
makeSerializable(const T & Value)536 template <typename T> static const T &makeSerializable(const T &Value) {
537 return Value;
538 }
539
makeSerializable(Error Err)540 static detail::SPSSerializableError makeSerializable(Error Err) {
541 return detail::toSPSSerializable(std::move(Err));
542 }
543
544 template <typename T>
makeSerializable(Expected<T> E)545 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
546 return detail::toSPSSerializable(std::move(E));
547 }
548 };
549
550 template <typename... SPSTagTs>
551 class WrapperFunction<void(SPSTagTs...)>
552 : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
553
554 public:
555 template <typename CallerFn, typename... ArgTs>
call(const CallerFn & Caller,const ArgTs &...Args)556 static Error call(const CallerFn &Caller, const ArgTs &...Args) {
557 SPSEmpty BE;
558 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(Caller, BE, Args...);
559 }
560
561 template <typename AsyncCallerFn, typename SendDeserializedResultFn,
562 typename... ArgTs>
callAsync(AsyncCallerFn && Caller,SendDeserializedResultFn && SendDeserializedResult,const ArgTs &...Args)563 static void callAsync(AsyncCallerFn &&Caller,
564 SendDeserializedResultFn &&SendDeserializedResult,
565 const ArgTs &...Args) {
566 WrapperFunction<SPSEmpty(SPSTagTs...)>::callAsync(
567 std::forward<AsyncCallerFn>(Caller),
568 [SDR = std::move(SendDeserializedResult)](Error SerializeErr,
569 SPSEmpty E) mutable {
570 SDR(std::move(SerializeErr));
571 },
572 Args...);
573 }
574
575 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
576 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handleAsync;
577 };
578
579 /// A function object that takes an ExecutorAddr as its first argument,
580 /// casts that address to a ClassT*, then calls the given method on that
581 /// pointer passing in the remaining function arguments. This utility
582 /// removes some of the boilerplate from writing wrappers for method calls.
583 ///
584 /// @code{.cpp}
585 /// class MyClass {
586 /// public:
587 /// void myMethod(uint32_t, bool) { ... }
588 /// };
589 ///
590 /// // SPS Method signature -- note MyClass object address as first argument.
591 /// using SPSMyMethodWrapperSignature =
592 /// SPSTuple<SPSExecutorAddr, uint32_t, bool>;
593 ///
594 /// WrapperFunctionResult
595 /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
596 /// return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
597 /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
598 /// }
599 /// @endcode
600 ///
601 template <typename RetT, typename ClassT, typename... ArgTs>
602 class MethodWrapperHandler {
603 public:
604 using MethodT = RetT (ClassT::*)(ArgTs...);
MethodWrapperHandler(MethodT M)605 MethodWrapperHandler(MethodT M) : M(M) {}
operator()606 RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
607 return (ObjAddr.toPtr<ClassT*>()->*M)(std::forward<ArgTs>(Args)...);
608 }
609
610 private:
611 MethodT M;
612 };
613
614 /// Create a MethodWrapperHandler object from the given method pointer.
615 template <typename RetT, typename ClassT, typename... ArgTs>
616 MethodWrapperHandler<RetT, ClassT, ArgTs...>
makeMethodWrapperHandler(RetT (ClassT::* Method)(ArgTs...))617 makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
618 return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
619 }
620
621 /// Represents a serialized wrapper function call.
622 /// Serializing calls themselves allows us to batch them: We can make one
623 /// "run-wrapper-functions" utility and send it a list of calls to run.
624 ///
625 /// The motivating use-case for this API is JITLink allocation actions, where
626 /// we want to run multiple functions to finalize linked memory without having
627 /// to make separate IPC calls for each one.
628 class WrapperFunctionCall {
629 public:
630 using ArgDataBufferType = SmallVector<char, 24>;
631
632 /// Create a WrapperFunctionCall using the given SPS serializer to serialize
633 /// the arguments.
634 template <typename SPSSerializer, typename... ArgTs>
Create(ExecutorAddr FnAddr,const ArgTs &...Args)635 static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
636 const ArgTs &...Args) {
637 ArgDataBufferType ArgData;
638 ArgData.resize(SPSSerializer::size(Args...));
639 SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(),
640 ArgData.size());
641 if (SPSSerializer::serialize(OB, Args...))
642 return WrapperFunctionCall(FnAddr, std::move(ArgData));
643 return make_error<StringError>("Cannot serialize arguments for "
644 "AllocActionCall",
645 inconvertibleErrorCode());
646 }
647
648 WrapperFunctionCall() = default;
649
650 /// Create a WrapperFunctionCall from a target function and arg buffer.
WrapperFunctionCall(ExecutorAddr FnAddr,ArgDataBufferType ArgData)651 WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
652 : FnAddr(FnAddr), ArgData(std::move(ArgData)) {}
653
654 /// Returns the address to be called.
getCallee()655 const ExecutorAddr &getCallee() const { return FnAddr; }
656
657 /// Returns the argument data.
getArgData()658 const ArgDataBufferType &getArgData() const { return ArgData; }
659
660 /// WrapperFunctionCalls convert to true if the callee is non-null.
661 explicit operator bool() const { return !!FnAddr; }
662
663 /// Run call returning raw WrapperFunctionResult.
run()664 shared::WrapperFunctionResult run() const {
665 using FnTy =
666 shared::CWrapperFunctionResult(const char *ArgData, size_t ArgSize);
667 return shared::WrapperFunctionResult(
668 FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
669 }
670
671 /// Run call and deserialize result using SPS.
672 template <typename SPSRetT, typename RetT>
673 std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
runWithSPSRet(RetT & RetVal)674 runWithSPSRet(RetT &RetVal) const {
675 auto WFR = run();
676 if (const char *ErrMsg = WFR.getOutOfBandError())
677 return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
678 shared::SPSInputBuffer IB(WFR.data(), WFR.size());
679 if (!shared::SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
680 return make_error<StringError>("Could not deserialize result from "
681 "serialized wrapper function call",
682 inconvertibleErrorCode());
683 return Error::success();
684 }
685
686 /// Overload for SPS functions returning void.
687 template <typename SPSRetT>
688 std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
runWithSPSRet()689 runWithSPSRet() const {
690 shared::SPSEmpty E;
691 return runWithSPSRet<shared::SPSEmpty>(E);
692 }
693
694 /// Run call and deserialize an SPSError result. SPSError returns and
695 /// deserialization failures are merged into the returned error.
runWithSPSRetErrorMerged()696 Error runWithSPSRetErrorMerged() const {
697 detail::SPSSerializableError RetErr;
698 if (auto Err = runWithSPSRet<SPSError>(RetErr))
699 return Err;
700 return detail::fromSPSSerializable(std::move(RetErr));
701 }
702
703 private:
704 orc::ExecutorAddr FnAddr;
705 ArgDataBufferType ArgData;
706 };
707
/// SPS tag for a WrapperFunctionCall: the callee address followed by the
/// argument bytes.
using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
709
710 template <>
711 class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
712 public:
size(const WrapperFunctionCall & WFC)713 static size_t size(const WrapperFunctionCall &WFC) {
714 return SPSWrapperFunctionCall::AsArgList::size(WFC.getCallee(),
715 WFC.getArgData());
716 }
717
serialize(SPSOutputBuffer & OB,const WrapperFunctionCall & WFC)718 static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
719 return SPSWrapperFunctionCall::AsArgList::serialize(OB, WFC.getCallee(),
720 WFC.getArgData());
721 }
722
deserialize(SPSInputBuffer & IB,WrapperFunctionCall & WFC)723 static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
724 ExecutorAddr FnAddr;
725 WrapperFunctionCall::ArgDataBufferType ArgData;
726 if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData))
727 return false;
728 WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));
729 return true;
730 }
731 };
732
733 } // end namespace shared
734 } // end namespace orc
735 } // end namespace llvm
736
737 #endif // LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H
738