1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file is a part of the ORC runtime support library.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H
15 
16 #include "orc_rt/c_api.h"
17 #include "common.h"
18 #include "error.h"
19 #include "executor_address.h"
20 #include "simple_packed_serialization.h"
21 #include <type_traits>
22 
23 namespace __orc_rt {
24 
25 /// C++ wrapper function result: Same as CWrapperFunctionResult but
26 /// auto-releases memory.
27 class WrapperFunctionResult {
28 public:
29   /// Create a default WrapperFunctionResult.
30   WrapperFunctionResult() { orc_rt_CWrapperFunctionResultInit(&R); }
31 
32   /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This
33   /// instance takes ownership of the result object and will automatically
34   /// call dispose on the result upon destruction.
35   WrapperFunctionResult(orc_rt_CWrapperFunctionResult R) : R(R) {}
36 
37   WrapperFunctionResult(const WrapperFunctionResult &) = delete;
38   WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
39 
40   WrapperFunctionResult(WrapperFunctionResult &&Other) {
41     orc_rt_CWrapperFunctionResultInit(&R);
42     std::swap(R, Other.R);
43   }
44 
45   WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
46     orc_rt_CWrapperFunctionResult Tmp;
47     orc_rt_CWrapperFunctionResultInit(&Tmp);
48     std::swap(Tmp, Other.R);
49     std::swap(R, Tmp);
50     return *this;
51   }
52 
53   ~WrapperFunctionResult() { orc_rt_DisposeCWrapperFunctionResult(&R); }
54 
55   /// Relinquish ownership of and return the
56   /// orc_rt_CWrapperFunctionResult.
57   orc_rt_CWrapperFunctionResult release() {
58     orc_rt_CWrapperFunctionResult Tmp;
59     orc_rt_CWrapperFunctionResultInit(&Tmp);
60     std::swap(R, Tmp);
61     return Tmp;
62   }
63 
64   /// Get a pointer to the data contained in this instance.
65   char *data() { return orc_rt_CWrapperFunctionResultData(&R); }
66 
67   /// Returns the size of the data contained in this instance.
68   size_t size() const { return orc_rt_CWrapperFunctionResultSize(&R); }
69 
70   /// Returns true if this value is equivalent to a default-constructed
71   /// WrapperFunctionResult.
72   bool empty() const { return orc_rt_CWrapperFunctionResultEmpty(&R); }
73 
74   /// Create a WrapperFunctionResult with the given size and return a pointer
75   /// to the underlying memory.
76   static WrapperFunctionResult allocate(size_t Size) {
77     WrapperFunctionResult R;
78     R.R = orc_rt_CWrapperFunctionResultAllocate(Size);
79     return R;
80   }
81 
82   /// Copy from the given char range.
83   static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
84     return orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size);
85   }
86 
87   /// Copy from the given null-terminated string (includes the null-terminator).
88   static WrapperFunctionResult copyFrom(const char *Source) {
89     return orc_rt_CreateCWrapperFunctionResultFromString(Source);
90   }
91 
92   /// Copy from the given std::string (includes the null terminator).
93   static WrapperFunctionResult copyFrom(const std::string &Source) {
94     return copyFrom(Source.c_str());
95   }
96 
97   /// Create an out-of-band error by copying the given string.
98   static WrapperFunctionResult createOutOfBandError(const char *Msg) {
99     return orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg);
100   }
101 
102   /// Create an out-of-band error by copying the given string.
103   static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
104     return createOutOfBandError(Msg.c_str());
105   }
106 
107   template <typename SPSArgListT, typename... ArgTs>
108   static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {
109     auto Result = allocate(SPSArgListT::size(Args...));
110     SPSOutputBuffer OB(Result.data(), Result.size());
111     if (!SPSArgListT::serialize(OB, Args...))
112       return createOutOfBandError(
113           "Error serializing arguments to blob in call");
114     return Result;
115   }
116 
117   /// If this value is an out-of-band error then this returns the error message,
118   /// otherwise returns nullptr.
119   const char *getOutOfBandError() const {
120     return orc_rt_CWrapperFunctionResultGetOutOfBandError(&R);
121   }
122 
123 private:
124   orc_rt_CWrapperFunctionResult R;
125 };
126 
127 namespace detail {
128 
/// Calls a handler with the elements of an argument tuple and passes the
/// handler's return value straight through. This primary template does not
/// reference RetT itself; the void specialization handles handlers that
/// return nothing.
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  /// Expand Args into an argument list using the indices Idx and invoke H
  /// with it. decltype(auto) keeps the handler's exact return type, including
  /// any reference qualifier.
  template <typename HandlerT, typename ArgTupleT, std::size_t... Idx>
  static auto call(HandlerT &&H, ArgTupleT &Args, std::index_sequence<Idx...>)
      -> decltype(auto) {
    return std::forward<HandlerT>(H)(std::get<Idx>(Args)...);
  }
};
137 
138 template <> class WrapperFunctionHandlerCaller<void> {
139 public:
140   template <typename HandlerT, typename ArgTupleT, std::size_t... I>
141   static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
142                        std::index_sequence<I...>) {
143     std::forward<HandlerT>(H)(std::get<I>(Args)...);
144     return SPSEmpty();
145   }
146 };
147 
148 template <typename WrapperFunctionImplT,
149           template <typename> class ResultSerializer, typename... SPSTagTs>
150 class WrapperFunctionHandlerHelper
151     : public WrapperFunctionHandlerHelper<
152           decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()),
153           ResultSerializer, SPSTagTs...> {};
154 
155 template <typename RetT, typename... ArgTs,
156           template <typename> class ResultSerializer, typename... SPSTagTs>
157 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
158                                    SPSTagTs...> {
159 public:
160   using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
161   using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
162 
163   template <typename HandlerT>
164   static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
165                                      size_t ArgSize) {
166     ArgTuple Args;
167     if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
168       return WrapperFunctionResult::createOutOfBandError(
169           "Could not deserialize arguments for wrapper function call");
170 
171     auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
172         std::forward<HandlerT>(H), Args, ArgIndices{});
173 
174     return ResultSerializer<decltype(HandlerResult)>::serialize(
175         std::move(HandlerResult));
176   }
177 
178 private:
179   template <std::size_t... I>
180   static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
181                           std::index_sequence<I...>) {
182     SPSInputBuffer IB(ArgData, ArgSize);
183     return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
184   }
185 };
186 
187 // Map function pointers to function types.
188 template <typename RetT, typename... ArgTs,
189           template <typename> class ResultSerializer, typename... SPSTagTs>
190 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
191                                    SPSTagTs...>
192     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
193                                           SPSTagTs...> {};
194 
195 // Map non-const member function types to function types.
196 template <typename ClassT, typename RetT, typename... ArgTs,
197           template <typename> class ResultSerializer, typename... SPSTagTs>
198 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
199                                    SPSTagTs...>
200     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
201                                           SPSTagTs...> {};
202 
203 // Map const member function types to function types.
204 template <typename ClassT, typename RetT, typename... ArgTs,
205           template <typename> class ResultSerializer, typename... SPSTagTs>
206 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
207                                    ResultSerializer, SPSTagTs...>
208     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
209                                           SPSTagTs...> {};
210 
211 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
212 public:
213   static WrapperFunctionResult serialize(RetT Result) {
214     return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);
215   }
216 };
217 
218 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
219 public:
220   static WrapperFunctionResult serialize(Error Err) {
221     return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
222         toSPSSerializable(std::move(Err)));
223   }
224 };
225 
226 template <typename SPSRetTagT, typename T>
227 class ResultSerializer<SPSRetTagT, Expected<T>> {
228 public:
229   static WrapperFunctionResult serialize(Expected<T> E) {
230     return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
231         toSPSSerializable(std::move(E)));
232   }
233 };
234 
235 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
236 public:
237   static void makeSafe(RetT &Result) {}
238 
239   static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
240     SPSInputBuffer IB(ArgData, ArgSize);
241     if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
242       return make_error<StringError>(
243           "Error deserializing return value from blob in call");
244     return Error::success();
245   }
246 };
247 
248 template <> class ResultDeserializer<SPSError, Error> {
249 public:
250   static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
251 
252   static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
253     SPSInputBuffer IB(ArgData, ArgSize);
254     SPSSerializableError BSE;
255     if (!SPSArgList<SPSError>::deserialize(IB, BSE))
256       return make_error<StringError>(
257           "Error deserializing return value from blob in call");
258     Err = fromSPSSerializable(std::move(BSE));
259     return Error::success();
260   }
261 };
262 
263 template <typename SPSTagT, typename T>
264 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
265 public:
266   static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
267 
268   static Error deserialize(Expected<T> &E, const char *ArgData,
269                            size_t ArgSize) {
270     SPSInputBuffer IB(ArgData, ArgSize);
271     SPSSerializableExpected<T> BSE;
272     if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
273       return make_error<StringError>(
274           "Error deserializing return value from blob in call");
275     E = fromSPSSerializable(std::move(BSE));
276     return Error::success();
277   }
278 };
279 
280 } // end namespace detail
281 
282 template <typename SPSSignature> class WrapperFunction;
283 
284 template <typename SPSRetTagT, typename... SPSTagTs>
285 class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
286 private:
287   template <typename RetT>
288   using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
289 
290 public:
291   template <typename RetT, typename... ArgTs>
292   static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) {
293 
294     // RetT might be an Error or Expected value. Set the checked flag now:
295     // we don't want the user to have to check the unused result if this
296     // operation fails.
297     detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
298 
299     // Since the functions cannot be zero/unresolved on Windows, the following
300     // reference taking would always be non-zero, thus generating a compiler
301     // warning otherwise.
302 #if !defined(_WIN32)
303     if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx))
304       return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set");
305     if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch))
306       return make_error<StringError>("__orc_rt_jit_dispatch not set");
307 #endif
308     auto ArgBuffer =
309         WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
310     if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
311       return make_error<StringError>(ErrMsg);
312 
313     WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch(
314         &__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size());
315     if (auto ErrMsg = ResultBuffer.getOutOfBandError())
316       return make_error<StringError>(ErrMsg);
317 
318     return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
319         Result, ResultBuffer.data(), ResultBuffer.size());
320   }
321 
322   template <typename HandlerT>
323   static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
324                                       HandlerT &&Handler) {
325     using WFHH =
326         detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
327                                              ResultSerializer, SPSTagTs...>;
328     return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
329   }
330 
331 private:
332   template <typename T> static const T &makeSerializable(const T &Value) {
333     return Value;
334   }
335 
336   static detail::SPSSerializableError makeSerializable(Error Err) {
337     return detail::toSPSSerializable(std::move(Err));
338   }
339 
340   template <typename T>
341   static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
342     return detail::toSPSSerializable(std::move(E));
343   }
344 };
345 
346 template <typename... SPSTagTs>
347 class WrapperFunction<void(SPSTagTs...)>
348     : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
349 public:
350   template <typename... ArgTs>
351   static Error call(const void *FnTag, const ArgTs &...Args) {
352     SPSEmpty BE;
353     return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...);
354   }
355 
356   using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
357 };
358 
359 /// A function object that takes an ExecutorAddr as its first argument,
360 /// casts that address to a ClassT*, then calls the given method on that
361 /// pointer passing in the remaining function arguments. This utility
362 /// removes some of the boilerplate from writing wrappers for method calls.
363 ///
364 ///   @code{.cpp}
365 ///   class MyClass {
366 ///   public:
367 ///     void myMethod(uint32_t, bool) { ... }
368 ///   };
369 ///
370 ///   // SPS Method signature -- note MyClass object address as first argument.
371 ///   using SPSMyMethodWrapperSignature =
372 ///     SPSTuple<SPSExecutorAddr, uint32_t, bool>;
373 ///
374 ///   WrapperFunctionResult
375 ///   myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
376 ///     return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
377 ///        ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
378 ///   }
379 ///   @endcode
380 ///
381 template <typename RetT, typename ClassT, typename... ArgTs>
382 class MethodWrapperHandler {
383 public:
384   using MethodT = RetT (ClassT::*)(ArgTs...);
385   MethodWrapperHandler(MethodT M) : M(M) {}
386   RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
387     return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...);
388   }
389 
390 private:
391   MethodT M;
392 };
393 
394 /// Create a MethodWrapperHandler object from the given method pointer.
395 template <typename RetT, typename ClassT, typename... ArgTs>
396 MethodWrapperHandler<RetT, ClassT, ArgTs...>
397 makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
398   return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
399 }
400 
401 /// Represents a call to a wrapper function.
402 class WrapperFunctionCall {
403 public:
404   // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a
405   // smallvector.
406   using ArgDataBufferType = std::vector<char>;
407 
408   /// Create a WrapperFunctionCall using the given SPS serializer to serialize
409   /// the arguments.
410   template <typename SPSSerializer, typename... ArgTs>
411   static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
412                                               const ArgTs &...Args) {
413     ArgDataBufferType ArgData;
414     ArgData.resize(SPSSerializer::size(Args...));
415     SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(),
416                        ArgData.size());
417     if (SPSSerializer::serialize(OB, Args...))
418       return WrapperFunctionCall(FnAddr, std::move(ArgData));
419     return make_error<StringError>("Cannot serialize arguments for "
420                                    "AllocActionCall");
421   }
422 
423   WrapperFunctionCall() = default;
424 
425   /// Create a WrapperFunctionCall from a target function and arg buffer.
426   WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
427       : FnAddr(FnAddr), ArgData(std::move(ArgData)) {}
428 
429   /// Returns the address to be called.
430   const ExecutorAddr &getCallee() const { return FnAddr; }
431 
432   /// Returns the argument data.
433   const ArgDataBufferType &getArgData() const { return ArgData; }
434 
435   /// WrapperFunctionCalls convert to true if the callee is non-null.
436   explicit operator bool() const { return !!FnAddr; }
437 
438   /// Run call returning raw WrapperFunctionResult.
439   WrapperFunctionResult run() const {
440     using FnTy =
441         orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize);
442     return WrapperFunctionResult(
443         FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
444   }
445 
446   /// Run call and deserialize result using SPS.
447   template <typename SPSRetT, typename RetT>
448   std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
449   runWithSPSRet(RetT &RetVal) const {
450     auto WFR = run();
451     if (const char *ErrMsg = WFR.getOutOfBandError())
452       return make_error<StringError>(ErrMsg);
453     SPSInputBuffer IB(WFR.data(), WFR.size());
454     if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
455       return make_error<StringError>("Could not deserialize result from "
456                                      "serialized wrapper function call");
457     return Error::success();
458   }
459 
460   /// Overload for SPS functions returning void.
461   template <typename SPSRetT>
462   std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
463   runWithSPSRet() const {
464     SPSEmpty E;
465     return runWithSPSRet<SPSEmpty>(E);
466   }
467 
468   /// Run call and deserialize an SPSError result. SPSError returns and
469   /// deserialization failures are merged into the returned error.
470   Error runWithSPSRetErrorMerged() const {
471     detail::SPSSerializableError RetErr;
472     if (auto Err = runWithSPSRet<SPSError>(RetErr))
473       return Err;
474     return detail::fromSPSSerializable(std::move(RetErr));
475   }
476 
477 private:
478   ExecutorAddr FnAddr;
479   std::vector<char> ArgData;
480 };
481 
482 using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
483 
484 template <>
485 class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
486 public:
487   static size_t size(const WrapperFunctionCall &WFC) {
488     return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size(
489         WFC.getCallee(), WFC.getArgData());
490   }
491 
492   static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
493     return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize(
494         OB, WFC.getCallee(), WFC.getArgData());
495   }
496 
497   static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
498     ExecutorAddr FnAddr;
499     WrapperFunctionCall::ArgDataBufferType ArgData;
500     if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData))
501       return false;
502     WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));
503     return true;
504   }
505 };
506 
507 } // end namespace __orc_rt
508 
509 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H
510