1 //===-- wrapper_function_utils_test.cpp -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file is a part of the ORC runtime.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "wrapper_function_utils.h"
#include "gtest/gtest.h"

#include <cstring>
#include <string>
15 
16 using namespace __orc_rt;
17 
/// Payload shared by the WrapperFunctionResult tests in this file. Internal
/// linkage via `static` keeps it private to this translation unit.
static constexpr const char *TestString = "test string";
21 
TEST(WrapperFunctionUtilsTest, DefaultWrapperFunctionResult) {
  // A default-constructed result should report no payload bytes and carry
  // no out-of-band error string.
  WrapperFunctionResult Default;
  EXPECT_TRUE(Default.empty());
  EXPECT_EQ(Default.size(), 0U);
  EXPECT_EQ(Default.getOutOfBandError(), nullptr);
}
28 
TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCStruct) {
  // Build a C-level result from TestString with the C API helper, then wrap
  // it in the C++ WrapperFunctionResult. The wrapper should expose the same
  // bytes, including the trailing NUL counted in the expected size.
  orc_rt_CWrapperFunctionResult CR =
      orc_rt_CreateCWrapperFunctionResultFromString(TestString);
  WrapperFunctionResult R(CR);
  EXPECT_EQ(R.size(), strlen(TestString) + 1);
  // EXPECT_STREQ handles a null data() safely and prints both strings on
  // mismatch. The previous strcmp(...) == 0 form did neither.
  EXPECT_STREQ(R.data(), TestString);
  EXPECT_FALSE(R.empty());
  EXPECT_EQ(R.getOutOfBandError(), nullptr);
}
38 
TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromRange) {
  // Copy an explicit (pointer, length) range that includes the NUL
  // terminator, so the stored bytes can be compared as a C string.
  auto R = WrapperFunctionResult::copyFrom(TestString, strlen(TestString) + 1);
  EXPECT_EQ(R.size(), strlen(TestString) + 1);
  // EXPECT_STREQ is null-safe and reports both strings on failure.
  EXPECT_STREQ(R.data(), TestString);
  EXPECT_FALSE(R.empty());
  EXPECT_EQ(R.getOutOfBandError(), nullptr);
}
46 
TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCString) {
  // copyFrom(const char *) is expected to copy the string plus its NUL
  // terminator (hence the + 1 in the size check).
  auto R = WrapperFunctionResult::copyFrom(TestString);
  EXPECT_EQ(R.size(), strlen(TestString) + 1);
  // EXPECT_STREQ is null-safe and reports both strings on failure.
  EXPECT_STREQ(R.data(), TestString);
  EXPECT_FALSE(R.empty());
  EXPECT_EQ(R.getOutOfBandError(), nullptr);
}
54 
TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromStdString) {
  // copyFrom(const std::string &) is expected to store the characters plus
  // a NUL terminator (hence the + 1 in the size check).
  auto R = WrapperFunctionResult::copyFrom(std::string(TestString));
  EXPECT_EQ(R.size(), strlen(TestString) + 1);
  // EXPECT_STREQ is null-safe and reports both strings on failure.
  EXPECT_STREQ(R.data(), TestString);
  EXPECT_FALSE(R.empty());
  EXPECT_EQ(R.getOutOfBandError(), nullptr);
}
62 
TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromOutOfBandError) {
  // An out-of-band error result is non-empty and returns the error text
  // from getOutOfBandError().
  auto R = WrapperFunctionResult::createOutOfBandError(TestString);
  EXPECT_FALSE(R.empty());
  // getOutOfBandError() may return nullptr if the result is wrong, and
  // strcmp(nullptr, ...) is undefined behavior. EXPECT_STREQ handles null
  // and reports a normal test failure instead.
  EXPECT_STREQ(R.getOutOfBandError(), TestString);
}
68 
TEST(WrapperFunctionUtilsTest, WrapperFunctionCCallCreateEmpty) {
  // Creating a call with an empty argument list against a default-constructed
  // (presumably null) target address should succeed.
  auto Call = WrapperFunctionCall::Create<SPSArgList<>>(ExecutorAddr());
  EXPECT_TRUE(static_cast<bool>(Call));
}
72 
/// Handler with no arguments and no result, used to exercise void calls.
static void voidNoop() {
  // Intentionally does nothing.
}
74 
voidNoopWrapper(const char * ArgData,size_t ArgSize)75 static orc_rt_CWrapperFunctionResult voidNoopWrapper(const char *ArgData,
76                                                      size_t ArgSize) {
77   return WrapperFunction<void()>::handle(ArgData, ArgSize, voidNoop).release();
78 }
79 
addWrapper(const char * ArgData,size_t ArgSize)80 static orc_rt_CWrapperFunctionResult addWrapper(const char *ArgData,
81                                                 size_t ArgSize) {
82   return WrapperFunction<int32_t(int32_t, int32_t)>::handle(
83              ArgData, ArgSize,
84              [](int32_t X, int32_t Y) -> int32_t { return X + Y; })
85       .release();
86 }
87 
88 extern "C" __orc_rt_Opaque __orc_rt_jit_dispatch_ctx{};
89 
90 extern "C" orc_rt_CWrapperFunctionResult
__orc_rt_jit_dispatch(__orc_rt_Opaque * Ctx,const void * FnTag,const char * ArgData,size_t ArgSize)91 __orc_rt_jit_dispatch(__orc_rt_Opaque *Ctx, const void *FnTag,
92                       const char *ArgData, size_t ArgSize) {
93   using WrapperFunctionType =
94       orc_rt_CWrapperFunctionResult (*)(const char *, size_t);
95 
96   return reinterpret_cast<WrapperFunctionType>(const_cast<void *>(FnTag))(
97       ArgData, ArgSize);
98 }
99 
TEST(WrapperFunctionUtilsTest, WrapperFunctionCallVoidNoopAndHandle) {
  // Round-trip a no-argument, no-result call to voidNoopWrapper (presumably
  // via the __orc_rt_jit_dispatch stub). Success is an empty Error.
  auto Err = WrapperFunction<void()>::call((void *)&voidNoopWrapper);
  EXPECT_FALSE(!!Err);
}
103 
TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAddWrapperAndHandle) {
  // Zero-initialize so that, if the call fails and never writes Result, the
  // EXPECT_EQ below compares a defined value rather than reading an
  // indeterminate one.
  int32_t Result = 0;
  // Call addWrapper with (1, 2). A falsy Error means success, and the
  // decoded return value is written into Result.
  EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call(
      (void *)&addWrapper, Result, 1, 2));
  EXPECT_EQ(Result, int32_t{3});
}
110 
/// Small stateful class used to test method-wrapper handlers. addMethod adds
/// its argument to the value captured at construction.
class AddClass {
public:
  // explicit: prevents an accidental implicit int32_t -> AddClass conversion.
  explicit AddClass(int32_t X) : X(X) {}
  // Intentionally non-const: makeMethodWrapperHandler is given
  // &AddClass::addMethod as a non-const member pointer.
  int32_t addMethod(int32_t Y) { return X + Y; }

private:
  int32_t X;
};
119 
addMethodWrapper(const char * ArgData,size_t ArgSize)120 static orc_rt_CWrapperFunctionResult addMethodWrapper(const char *ArgData,
121                                                       size_t ArgSize) {
122   return WrapperFunction<int32_t(SPSExecutorAddr, int32_t)>::handle(
123              ArgData, ArgSize, makeMethodWrapperHandler(&AddClass::addMethod))
124       .release();
125 }
126 
TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) {
  // Zero-initialize so a failed call cannot leave EXPECT_EQ reading an
  // indeterminate value.
  int32_t Result = 0;
  AddClass AddObj(1);
  // Pass the object's address plus one int32_t. addMethodWrapper calls
  // AddObj.addMethod(2) and the result (1 + 2) is written into Result.
  EXPECT_FALSE(!!WrapperFunction<int32_t(SPSExecutorAddr, int32_t)>::call(
      (void *)&addMethodWrapper, Result, ExecutorAddr::fromPtr(&AddObj), 2));
  EXPECT_EQ(Result, int32_t{3});
}
134 
sumArrayWrapper(const char * ArgData,size_t ArgSize)135 static orc_rt_CWrapperFunctionResult sumArrayWrapper(const char *ArgData,
136                                                      size_t ArgSize) {
137   return WrapperFunction<int8_t(SPSExecutorAddrRange)>::handle(
138              ArgData, ArgSize,
139              [](ExecutorAddrRange R) {
140                int8_t Sum = 0;
141                for (char C : R.toSpan<char>())
142                  Sum += C;
143                return Sum;
144              })
145       .release();
146 }
147 
TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) {
  {
    // Check wrapper function calls: sum a 4-byte local array through
    // sumArrayWrapper and inspect the raw one-byte encoded result.
    char A[] = {1, 2, 3, 4};

    auto WFC =
        cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>(
            ExecutorAddr::fromPtr(sumArrayWrapper),
            ExecutorAddrRange(ExecutorAddr::fromPtr(A),
                              ExecutorAddrDiff(sizeof(A)))));

    WrapperFunctionResult WFR(WFC.run());
    // ASSERT, not EXPECT: if the result is not exactly one byte, do not go
    // on to index data()[0], which could read past an empty buffer.
    ASSERT_EQ(WFR.size(), 1U);
    EXPECT_EQ(WFR.data()[0], 10);
  }

  {
    // Check calls to void functions.
    auto WFC =
        cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>(
            ExecutorAddr::fromPtr(voidNoopWrapper), ExecutorAddrRange()));
    auto Err = WFC.runWithSPSRet<void>();
    EXPECT_FALSE(!!Err);
  }

  {
    // Check calls with arguments and return values.
    auto WFC =
        cantFail(WrapperFunctionCall::Create<SPSArgList<int32_t, int32_t>>(
            ExecutorAddr::fromPtr(addWrapper), 2, 4));

    int32_t Result = 0;
    auto Err = WFC.runWithSPSRet<int32_t>(Result);
    EXPECT_FALSE(!!Err);
    EXPECT_EQ(Result, 6);
  }
}
185