1 //===- WrapperFunctionUtils.h - Utilities for wrapper functions -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A buffer for serialized results.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_EXECUTIONENGINE_ORC_WRAPPERFUNCTIONUTILS_H
14 #define LLVM_EXECUTIONENGINE_ORC_WRAPPERFUNCTIONUTILS_H
15 
16 #include "llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h"
17 #include "llvm/Support/Error.h"
18 
19 #include <type_traits>
20 
21 namespace llvm {
22 namespace orc {
23 namespace shared {
24 
namespace detail {

// DO NOT USE DIRECTLY.
// Must be kept in-sync with compiler-rt/lib/orc/c-api.h.
//
// Storage for the bytes of a wrapper-function result: either a pointer to
// out-of-line storage, or an inline buffer exactly as wide as that pointer.
// Which member is active is decided by CWrapperFunctionResult::Size.
union CWrapperFunctionResultDataUnion {
  char *ValuePtr;
  char Value[sizeof(ValuePtr)];
};

// DO NOT USE DIRECTLY.
// Must be kept in-sync with compiler-rt/lib/orc/c-api.h.
//
// Encoding, as implemented by WrapperFunctionResult below:
//   Size >  sizeof(Data.Value)  -- Data.ValuePtr points at Size malloc'd bytes.
//   0 < Size <= sizeof(Value)   -- the bytes live inline in Data.Value.
//   Size == 0, ValuePtr == null -- empty result.
//   Size == 0, ValuePtr != null -- out-of-band error; ValuePtr is a malloc'd,
//                                  null-terminated message.
typedef struct {
  CWrapperFunctionResultDataUnion Data;
  size_t Size;
} CWrapperFunctionResult;

} // end namespace detail

/// C++ wrapper function result: Same as CWrapperFunctionResult but
/// auto-releases memory.
///
/// The type is move-only. Any out-of-line storage (large payloads and
/// out-of-band error messages) is released with free() on destruction.
class WrapperFunctionResult {
public:
  /// Create a default (empty) WrapperFunctionResult.
  WrapperFunctionResult() { init(R); }

  /// Create a WrapperFunctionResult by taking ownership of a
  /// detail::CWrapperFunctionResult.
  ///
  /// The argument is received by value, so the caller's copy of the struct is
  /// left untouched; the caller must not free or reuse any pointer it holds.
  ///
  /// Warning: This should only be used by clients writing wrapper-function
  /// caller utilities (like TargetProcessControl).
  WrapperFunctionResult(detail::CWrapperFunctionResult R) : R(R) {}

  WrapperFunctionResult(const WrapperFunctionResult &) = delete;
  WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;

  /// Move-construct, leaving Other empty.
  WrapperFunctionResult(WrapperFunctionResult &&Other) noexcept {
    init(R);
    std::swap(R, Other.R);
  }

  /// Move-assign, leaving Other empty. The previous contents of this instance
  /// are released when the temporary goes out of scope. Self-assignment is
  /// safe: the contents round-trip through the temporary.
  WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) noexcept {
    WrapperFunctionResult Tmp(std::move(Other));
    std::swap(R, Tmp.R);
    return *this;
  }

  ~WrapperFunctionResult() {
    // Only out-of-line payloads and out-of-band error strings own heap memory.
    if ((R.Size > sizeof(R.Data.Value)) ||
        (R.Size == 0 && R.Data.ValuePtr != nullptr))
      free(R.Data.ValuePtr);
  }

  /// Release ownership of the contained detail::CWrapperFunctionResult.
  /// This instance is left empty.
  /// Warning: Do not use -- this method will be removed in the future. It only
  /// exists to temporarily support some code that will eventually be moved to
  /// the ORC runtime.
  detail::CWrapperFunctionResult release() {
    detail::CWrapperFunctionResult Tmp;
    init(Tmp);
    std::swap(R, Tmp);
    return Tmp;
  }

  /// Get a pointer to the data contained in this instance.
  /// Must not be called on an out-of-band error value.
  const char *data() const {
    assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
           "Cannot get data for out-of-band error value");
    return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value;
  }

  /// Returns the size of the data contained in this instance.
  /// Must not be called on an out-of-band error value.
  size_t size() const {
    assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
           "Cannot get data for out-of-band error value");
    return R.Size;
  }

  /// Returns true if this value is equivalent to a default-constructed
  /// WrapperFunctionResult.
  bool empty() const { return R.Size == 0 && R.Data.ValuePtr == nullptr; }

  /// Reset WFR to hold Size bytes and return a pointer to the underlying
  /// memory (inline storage for small sizes, malloc'd storage otherwise).
  /// The previous contents of WFR are released. The returned bytes are not
  /// initialized by this function when they are heap-allocated.
  static char *allocate(WrapperFunctionResult &WFR, size_t Size) {
    // Reset.
    WFR = WrapperFunctionResult();
    WFR.R.Size = Size;
    char *DataPtr;
    if (WFR.R.Size > sizeof(WFR.R.Data.Value)) {
      DataPtr = (char *)malloc(WFR.R.Size);
      WFR.R.Data.ValuePtr = DataPtr;
    } else
      DataPtr = WFR.R.Data.Value;
    return DataPtr;
  }

  /// Copy from the given char range.
  static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
    WrapperFunctionResult WFR;
    char *DataPtr = allocate(WFR, Size);
    // memcpy requires valid pointers even for a zero length; skip the call so
    // that an empty range with a null Source is well-defined.
    if (Size != 0)
      memcpy(DataPtr, Source, Size);
    return WFR;
  }

  /// Copy from the given null-terminated string (includes the null-terminator).
  static WrapperFunctionResult copyFrom(const char *Source) {
    return copyFrom(Source, strlen(Source) + 1);
  }

  /// Copy from the given std::string (includes the null terminator).
  ///
  /// The length comes from the string itself, so embedded null characters are
  /// copied rather than truncating the payload.
  static WrapperFunctionResult copyFrom(const std::string &Source) {
    return copyFrom(Source.c_str(), Source.size() + 1);
  }

  /// Create an out-of-band error by copying the given string.
  static WrapperFunctionResult createOutOfBandError(const char *Msg) {
    WrapperFunctionResult WFR;
    const size_t MsgBytes = strlen(Msg) + 1; // Include the terminator.
    char *Tmp = (char *)malloc(MsgBytes);
    memcpy(Tmp, Msg, MsgBytes);
    // Size stays zero: a non-null ValuePtr with Size == 0 marks the error.
    WFR.R.Data.ValuePtr = Tmp;
    return WFR;
  }

  /// Create an out-of-band error by copying the given string.
  static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
    return createOutOfBandError(Msg.c_str());
  }

  /// If this value is an out-of-band error then this returns the error message,
  /// otherwise returns nullptr.
  const char *getOutOfBandError() const {
    return R.Size == 0 ? R.Data.ValuePtr : nullptr;
  }

private:
  /// Put a raw result struct into the empty state.
  static void init(detail::CWrapperFunctionResult &R) {
    R.Data.ValuePtr = nullptr;
    R.Size = 0;
  }

  detail::CWrapperFunctionResult R;
};
171 
172 namespace detail {
173 
174 template <typename SPSArgListT, typename... ArgTs>
175 WrapperFunctionResult
176 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
177   WrapperFunctionResult Result;
178   char *DataPtr =
179       WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...));
180   SPSOutputBuffer OB(DataPtr, Result.size());
181   if (!SPSArgListT::serialize(OB, Args...))
182     return WrapperFunctionResult::createOutOfBandError(
183         "Error serializing arguments to blob in call");
184   return Result;
185 }
186 
/// Invokes a handler with the elements of an argument tuple, expanded in index
/// order, and passes the handler's result (including its reference-ness)
/// straight back to the caller.
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  template <typename FnT, typename TupleT, std::size_t... Idx>
  static decltype(auto) call(FnT &&Handler, TupleT &ArgStorage,
                             std::index_sequence<Idx...>) {
    // Tuple elements are passed as lvalues; the handler itself is forwarded.
    return std::forward<FnT>(Handler)(std::get<Idx>(ArgStorage)...);
  }
};
195 
196 template <> class WrapperFunctionHandlerCaller<void> {
197 public:
198   template <typename HandlerT, typename ArgTupleT, std::size_t... I>
199   static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
200                        std::index_sequence<I...>) {
201     std::forward<HandlerT>(H)(std::get<I>(Args)...);
202     return SPSEmpty();
203   }
204 };
205 
/// Primary template: treats the handler as a callable object and defers to the
/// specialization selected by the type of its operator(). Function types,
/// function pointers and member-function pointers are covered by the
/// specializations that follow.
template <typename CallableT, template <typename> class ResultSerializer,
          typename... SPSTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference_t<CallableT>::operator()),
          ResultSerializer, SPSTagTs...> {};
212 
213 template <typename RetT, typename... ArgTs,
214           template <typename> class ResultSerializer, typename... SPSTagTs>
215 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
216                                    SPSTagTs...> {
217 public:
218   using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
219   using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
220 
221   template <typename HandlerT>
222   static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
223                                      size_t ArgSize) {
224     ArgTuple Args;
225     if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
226       return WrapperFunctionResult::createOutOfBandError(
227           "Could not deserialize arguments for wrapper function call");
228 
229     auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
230         std::forward<HandlerT>(H), Args, ArgIndices{});
231 
232     return ResultSerializer<decltype(HandlerResult)>::serialize(
233         std::move(HandlerResult));
234   }
235 
236 private:
237   template <std::size_t... I>
238   static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
239                           std::index_sequence<I...>) {
240     SPSInputBuffer IB(ArgData, ArgSize);
241     return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
242   }
243 };
244 
245 // Map function pointers to function types.
246 template <typename RetT, typename... ArgTs,
247           template <typename> class ResultSerializer, typename... SPSTagTs>
248 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
249                                    SPSTagTs...>
250     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
251                                           SPSTagTs...> {};
252 
253 // Map non-const member function types to function types.
254 template <typename ClassT, typename RetT, typename... ArgTs,
255           template <typename> class ResultSerializer, typename... SPSTagTs>
256 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
257                                    SPSTagTs...>
258     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
259                                           SPSTagTs...> {};
260 
261 // Map const member function types to function types.
262 template <typename ClassT, typename RetT, typename... ArgTs,
263           template <typename> class ResultSerializer, typename... SPSTagTs>
264 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
265                                    ResultSerializer, SPSTagTs...>
266     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
267                                           SPSTagTs...> {};
268 
/// Primary template for asynchronous handlers: treats the handler as a
/// callable object and defers to the specialization selected by the type of
/// its operator().
template <typename CallableT, template <typename> class ResultSerializer,
          typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper
    : public WrapperFunctionAsyncHandlerHelper<
          decltype(&std::remove_reference_t<CallableT>::operator()),
          ResultSerializer, SPSTagTs...> {};
275 
276 template <typename RetT, typename SendResultT, typename... ArgTs,
277           template <typename> class ResultSerializer, typename... SPSTagTs>
278 class WrapperFunctionAsyncHandlerHelper<RetT(SendResultT, ArgTs...),
279                                         ResultSerializer, SPSTagTs...> {
280 public:
281   using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
282   using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
283 
284   template <typename HandlerT, typename SendWrapperFunctionResultT>
285   static void applyAsync(HandlerT &&H,
286                          SendWrapperFunctionResultT &&SendWrapperFunctionResult,
287                          const char *ArgData, size_t ArgSize) {
288     ArgTuple Args;
289     if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) {
290       SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError(
291           "Could not deserialize arguments for wrapper function call"));
292       return;
293     }
294 
295     auto SendResult =
296         [SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable {
297           using ResultT = decltype(Result);
298           SendWFR(ResultSerializer<ResultT>::serialize(std::move(Result)));
299         };
300 
301     callAsync(std::forward<HandlerT>(H), std::move(SendResult), std::move(Args),
302               ArgIndices{});
303   }
304 
305 private:
306   template <std::size_t... I>
307   static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
308                           std::index_sequence<I...>) {
309     SPSInputBuffer IB(ArgData, ArgSize);
310     return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
311   }
312 
313   template <typename HandlerT, typename SerializeAndSendResultT,
314             typename ArgTupleT, std::size_t... I>
315   static void callAsync(HandlerT &&H,
316                         SerializeAndSendResultT &&SerializeAndSendResult,
317                         ArgTupleT Args, std::index_sequence<I...>) {
318     return std::forward<HandlerT>(H)(std::move(SerializeAndSendResult),
319                                      std::move(std::get<I>(Args))...);
320   }
321 };
322 
323 // Map function pointers to function types.
324 template <typename RetT, typename... ArgTs,
325           template <typename> class ResultSerializer, typename... SPSTagTs>
326 class WrapperFunctionAsyncHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
327                                         SPSTagTs...>
328     : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
329                                                SPSTagTs...> {};
330 
331 // Map non-const member function types to function types.
332 template <typename ClassT, typename RetT, typename... ArgTs,
333           template <typename> class ResultSerializer, typename... SPSTagTs>
334 class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...),
335                                         ResultSerializer, SPSTagTs...>
336     : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
337                                                SPSTagTs...> {};
338 
339 // Map const member function types to function types.
340 template <typename ClassT, typename RetT, typename... ArgTs,
341           template <typename> class ResultSerializer, typename... SPSTagTs>
342 class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
343                                         ResultSerializer, SPSTagTs...>
344     : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
345                                                SPSTagTs...> {};
346 
347 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
348 public:
349   static WrapperFunctionResult serialize(RetT Result) {
350     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
351         Result);
352   }
353 };
354 
355 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
356 public:
357   static WrapperFunctionResult serialize(Error Err) {
358     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
359         toSPSSerializable(std::move(Err)));
360   }
361 };
362 
363 template <typename SPSRetTagT, typename T>
364 class ResultSerializer<SPSRetTagT, Expected<T>> {
365 public:
366   static WrapperFunctionResult serialize(Expected<T> E) {
367     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
368         toSPSSerializable(std::move(E)));
369   }
370 };
371 
372 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
373 public:
374   static RetT makeValue() { return RetT(); }
375   static void makeSafe(RetT &Result) {}
376 
377   static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
378     SPSInputBuffer IB(ArgData, ArgSize);
379     if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
380       return make_error<StringError>(
381           "Error deserializing return value from blob in call",
382           inconvertibleErrorCode());
383     return Error::success();
384   }
385 };
386 
387 template <> class ResultDeserializer<SPSError, Error> {
388 public:
389   static Error makeValue() { return Error::success(); }
390   static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
391 
392   static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
393     SPSInputBuffer IB(ArgData, ArgSize);
394     SPSSerializableError BSE;
395     if (!SPSArgList<SPSError>::deserialize(IB, BSE))
396       return make_error<StringError>(
397           "Error deserializing return value from blob in call",
398           inconvertibleErrorCode());
399     Err = fromSPSSerializable(std::move(BSE));
400     return Error::success();
401   }
402 };
403 
404 template <typename SPSTagT, typename T>
405 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
406 public:
407   static Expected<T> makeValue() { return T(); }
408   static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
409 
410   static Error deserialize(Expected<T> &E, const char *ArgData,
411                            size_t ArgSize) {
412     SPSInputBuffer IB(ArgData, ArgSize);
413     SPSSerializableExpected<T> BSE;
414     if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
415       return make_error<StringError>(
416           "Error deserializing return value from blob in call",
417           inconvertibleErrorCode());
418     E = fromSPSSerializable(std::move(BSE));
419     return Error::success();
420   }
421 };
422 
/// Deliberately empty primary template. No specialization is visible in this
/// header; presumably ending up here means the handler's result type is not
/// one of the supported ones (see the hint below) -- confirm against users.
template <typename SPSTagT, typename ResultT> class AsyncCallResultHelper {
  // Did you forget to use Error / Expected in your handler?
};
426 
427 } // end namespace detail
428 
429 template <typename SPSSignature> class WrapperFunction;
430 
431 template <typename SPSRetTagT, typename... SPSTagTs>
432 class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
433 private:
434   template <typename RetT>
435   using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
436 
437 public:
438   /// Call a wrapper function. Caller should be callable as
439   /// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize);
440   template <typename CallerFn, typename RetT, typename... ArgTs>
441   static Error call(const CallerFn &Caller, RetT &Result,
442                     const ArgTs &...Args) {
443 
444     // RetT might be an Error or Expected value. Set the checked flag now:
445     // we don't want the user to have to check the unused result if this
446     // operation fails.
447     detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
448 
449     auto ArgBuffer =
450         detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
451             Args...);
452     if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
453       return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
454 
455     WrapperFunctionResult ResultBuffer =
456         Caller(ArgBuffer.data(), ArgBuffer.size());
457     if (auto ErrMsg = ResultBuffer.getOutOfBandError())
458       return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
459 
460     return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
461         Result, ResultBuffer.data(), ResultBuffer.size());
462   }
463 
464   /// Call an async wrapper function.
465   /// Caller should be callable as
466   /// void Fn(unique_function<void(WrapperFunctionResult)> SendResult,
467   ///         WrapperFunctionResult ArgBuffer);
468   template <typename AsyncCallerFn, typename SendDeserializedResultFn,
469             typename... ArgTs>
470   static void callAsync(AsyncCallerFn &&Caller,
471                         SendDeserializedResultFn &&SendDeserializedResult,
472                         const ArgTs &...Args) {
473     using RetT = typename std::tuple_element<
474         1, typename detail::WrapperFunctionHandlerHelper<
475                std::remove_reference_t<SendDeserializedResultFn>,
476                ResultSerializer, SPSRetTagT>::ArgTuple>::type;
477 
478     auto ArgBuffer =
479         detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
480             Args...);
481     if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) {
482       SendDeserializedResult(
483           make_error<StringError>(ErrMsg, inconvertibleErrorCode()),
484           detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue());
485       return;
486     }
487 
488     auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)](
489                                     WrapperFunctionResult R) {
490       RetT RetVal = detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue();
491       detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(RetVal);
492 
493       SPSInputBuffer IB(R.data(), R.size());
494       if (auto Err = detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
495               RetVal, R.data(), R.size()))
496         SDR(std::move(Err), std::move(RetVal));
497 
498       SDR(Error::success(), std::move(RetVal));
499     };
500 
501     Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size());
502   }
503 
504   /// Handle a call to a wrapper function.
505   template <typename HandlerT>
506   static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
507                                       HandlerT &&Handler) {
508     using WFHH =
509         detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
510                                              ResultSerializer, SPSTagTs...>;
511     return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
512   }
513 
514   /// Handle a call to an async wrapper function.
515   template <typename HandlerT, typename SendResultT>
516   static void handleAsync(const char *ArgData, size_t ArgSize,
517                           HandlerT &&Handler, SendResultT &&SendResult) {
518     using WFAHH = detail::WrapperFunctionAsyncHandlerHelper<
519         std::remove_reference_t<HandlerT>, ResultSerializer, SPSTagTs...>;
520     WFAHH::applyAsync(std::forward<HandlerT>(Handler),
521                       std::forward<SendResultT>(SendResult), ArgData, ArgSize);
522   }
523 
524 private:
525   template <typename T> static const T &makeSerializable(const T &Value) {
526     return Value;
527   }
528 
529   static detail::SPSSerializableError makeSerializable(Error Err) {
530     return detail::toSPSSerializable(std::move(Err));
531   }
532 
533   template <typename T>
534   static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
535     return detail::toSPSSerializable(std::move(E));
536   }
537 };
538 
539 template <typename... SPSTagTs>
540 class WrapperFunction<void(SPSTagTs...)>
541     : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
542 
543 public:
544   template <typename CallerFn, typename... ArgTs>
545   static Error call(const CallerFn &Caller, const ArgTs &...Args) {
546     SPSEmpty BE;
547     return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(Caller, BE, Args...);
548   }
549 
550   using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
551   using WrapperFunction<SPSEmpty(SPSTagTs...)>::handleAsync;
552 };
553 
554 } // end namespace shared
555 } // end namespace orc
556 } // end namespace llvm
557 
558 #endif // LLVM_EXECUTIONENGINE_ORC_WRAPPERFUNCTIONUTILS_H
559