1 //===-- wrapper_function_utils_test.cpp -----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file is a part of the ORC runtime. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "wrapper_function_utils.h" 14 #include "gtest/gtest.h" 15 16 using namespace __orc_rt; 17 18 namespace { 19 constexpr const char *TestString = "test string"; 20 } // end anonymous namespace 21 22 TEST(WrapperFunctionUtilsTest, DefaultWrapperFunctionResult) { 23 WrapperFunctionResult R; 24 EXPECT_TRUE(R.empty()); 25 EXPECT_EQ(R.size(), 0U); 26 EXPECT_EQ(R.getOutOfBandError(), nullptr); 27 } 28 29 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCStruct) { 30 orc_rt_CWrapperFunctionResult CR = 31 orc_rt_CreateCWrapperFunctionResultFromString(TestString); 32 WrapperFunctionResult R(CR); 33 EXPECT_EQ(R.size(), strlen(TestString) + 1); 34 EXPECT_TRUE(strcmp(R.data(), TestString) == 0); 35 EXPECT_FALSE(R.empty()); 36 EXPECT_EQ(R.getOutOfBandError(), nullptr); 37 } 38 39 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromRange) { 40 auto R = WrapperFunctionResult::copyFrom(TestString, strlen(TestString) + 1); 41 EXPECT_EQ(R.size(), strlen(TestString) + 1); 42 EXPECT_TRUE(strcmp(R.data(), TestString) == 0); 43 EXPECT_FALSE(R.empty()); 44 EXPECT_EQ(R.getOutOfBandError(), nullptr); 45 } 46 47 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCString) { 48 auto R = WrapperFunctionResult::copyFrom(TestString); 49 EXPECT_EQ(R.size(), strlen(TestString) + 1); 50 EXPECT_TRUE(strcmp(R.data(), TestString) == 0); 51 EXPECT_FALSE(R.empty()); 52 EXPECT_EQ(R.getOutOfBandError(), nullptr); 53 } 54 55 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromStdString) { 56 auto R = WrapperFunctionResult::copyFrom(std::string(TestString)); 57 EXPECT_EQ(R.size(), strlen(TestString) + 1); 58 EXPECT_TRUE(strcmp(R.data(), TestString) == 0); 59 EXPECT_FALSE(R.empty()); 60 EXPECT_EQ(R.getOutOfBandError(), nullptr); 61 } 62 63 TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromOutOfBandError) { 64 auto R = WrapperFunctionResult::createOutOfBandError(TestString); 65 EXPECT_FALSE(R.empty()); 66 EXPECT_TRUE(strcmp(R.getOutOfBandError(), TestString) == 0); 67 } 68 69 TEST(WrapperFunctionUtilsTest, WrapperFunctionCCallCreateEmpty) { 70 EXPECT_TRUE(!!WrapperFunctionCall::Create<SPSArgList<>>(ExecutorAddr())); 71 } 72 73 static void voidNoop() {} 74 75 static orc_rt_CWrapperFunctionResult voidNoopWrapper(const char *ArgData, 76 size_t ArgSize) { 77 return WrapperFunction<void()>::handle(ArgData, ArgSize, voidNoop).release(); 78 } 79 80 static orc_rt_CWrapperFunctionResult addWrapper(const char *ArgData, 81 size_t ArgSize) { 82 return WrapperFunction<int32_t(int32_t, int32_t)>::handle( 83 ArgData, ArgSize, 84 [](int32_t X, int32_t Y) -> int32_t { return X + Y; }) 85 .release(); 86 } 87 88 extern "C" __orc_rt_Opaque __orc_rt_jit_dispatch_ctx{}; 89 90 extern "C" orc_rt_CWrapperFunctionResult 91 __orc_rt_jit_dispatch(__orc_rt_Opaque *Ctx, const void *FnTag, 92 const char *ArgData, size_t ArgSize) { 93 using WrapperFunctionType = 94 orc_rt_CWrapperFunctionResult (*)(const char *, size_t); 95 96 return reinterpret_cast<WrapperFunctionType>(const_cast<void *>(FnTag))( 97 ArgData, ArgSize); 98 } 99 100 TEST(WrapperFunctionUtilsTest, WrapperFunctionCallVoidNoopAndHandle) { 101 EXPECT_FALSE(!!WrapperFunction<void()>::call((void *)&voidNoopWrapper)); 102 } 103 104 TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAddWrapperAndHandle) { 105 int32_t Result; 106 EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call( 107 (void *)&addWrapper, Result, 1, 2)); 108 EXPECT_EQ(Result, (int32_t)3); 109 } 110 111 class AddClass { 112 public: 113 AddClass(int32_t X) : X(X) {} 114 int32_t addMethod(int32_t Y) { return X + Y; } 115 116 private: 117 int32_t X; 118 }; 119 120 static orc_rt_CWrapperFunctionResult addMethodWrapper(const char *ArgData, 121 size_t ArgSize) { 122 return WrapperFunction<int32_t(SPSExecutorAddr, int32_t)>::handle( 123 ArgData, ArgSize, makeMethodWrapperHandler(&AddClass::addMethod)) 124 .release(); 125 } 126 127 TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) { 128 int32_t Result; 129 AddClass AddObj(1); 130 EXPECT_FALSE(!!WrapperFunction<int32_t(SPSExecutorAddr, int32_t)>::call( 131 (void *)&addMethodWrapper, Result, ExecutorAddr::fromPtr(&AddObj), 2)); 132 EXPECT_EQ(Result, (int32_t)3); 133 } 134 135 static orc_rt_CWrapperFunctionResult sumArrayWrapper(const char *ArgData, 136 size_t ArgSize) { 137 return WrapperFunction<int8_t(SPSExecutorAddrRange)>::handle( 138 ArgData, ArgSize, 139 [](ExecutorAddrRange R) { 140 int8_t Sum = 0; 141 for (char C : R.toSpan<char>()) 142 Sum += C; 143 return Sum; 144 }) 145 .release(); 146 } 147 148 TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) { 149 { 150 // Check wrapper function calls. 151 char A[] = {1, 2, 3, 4}; 152 153 auto WFC = 154 cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>( 155 ExecutorAddr::fromPtr(sumArrayWrapper), 156 ExecutorAddrRange(ExecutorAddr::fromPtr(A), 157 ExecutorAddrDiff(sizeof(A))))); 158 159 WrapperFunctionResult WFR(WFC.run()); 160 EXPECT_EQ(WFR.size(), 1U); 161 EXPECT_EQ(WFR.data()[0], 10); 162 } 163 164 { 165 // Check calls to void functions. 166 auto WFC = 167 cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>( 168 ExecutorAddr::fromPtr(voidNoopWrapper), ExecutorAddrRange())); 169 auto Err = WFC.runWithSPSRet<void>(); 170 EXPECT_FALSE(!!Err); 171 } 172 173 { 174 // Check calls with arguments and return values. 175 auto WFC = 176 cantFail(WrapperFunctionCall::Create<SPSArgList<int32_t, int32_t>>( 177 ExecutorAddr::fromPtr(addWrapper), 2, 4)); 178 179 int32_t Result = 0; 180 auto Err = WFC.runWithSPSRet<int32_t>(Result); 181 EXPECT_FALSE(!!Err); 182 EXPECT_EQ(Result, 6); 183 } 184 } 185