1 //===- Invoke.cpp ------------------------------------*- C++ -*-===// 2 // 3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 10 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 12 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 13 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 14 #include "mlir/Dialect/Linalg/Passes.h" 15 #include "mlir/ExecutionEngine/CRunnerUtils.h" 16 #include "mlir/ExecutionEngine/ExecutionEngine.h" 17 #include "mlir/ExecutionEngine/MemRefUtils.h" 18 #include "mlir/ExecutionEngine/RunnerUtils.h" 19 #include "mlir/IR/MLIRContext.h" 20 #include "mlir/InitAllDialects.h" 21 #include "mlir/Parser.h" 22 #include "mlir/Pass/PassManager.h" 23 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" 24 #include "mlir/Target/LLVMIR/Export.h" 25 #include "llvm/Support/TargetSelect.h" 26 #include "llvm/Support/raw_ostream.h" 27 28 #include "gmock/gmock.h" 29 30 using namespace mlir; 31 32 static struct LLVMInitializer { 33 LLVMInitializer() { 34 llvm::InitializeNativeTarget(); 35 llvm::InitializeNativeTargetAsmPrinter(); 36 } 37 } initializer; 38 39 /// Simple conversion pipeline for the purpose of testing sources written in 40 /// dialects lowering to LLVM Dialect. 41 static LogicalResult lowerToLLVMDialect(ModuleOp module) { 42 PassManager pm(module.getContext()); 43 pm.addPass(mlir::createMemRefToLLVMPass()); 44 pm.addPass(mlir::createLowerToLLVMPass()); 45 return pm.run(module); 46 } 47 48 // The JIT isn't supported on Windows at that time 49 #ifndef _WIN32 50 51 TEST(MLIRExecutionEngine, AddInteger) { 52 std::string moduleStr = R"mlir( 53 func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } { 54 %res = std.addi %arg0, %arg0 : i32 55 return %res : i32 56 } 57 )mlir"; 58 DialectRegistry registry; 59 registerAllDialects(registry); 60 registerLLVMDialectTranslation(registry); 61 MLIRContext context(registry); 62 OwningModuleRef module = parseSourceString(moduleStr, &context); 63 ASSERT_TRUE(!!module); 64 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 65 auto jitOrError = ExecutionEngine::create(*module); 66 ASSERT_TRUE(!!jitOrError); 67 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 68 // The result of the function must be passed as output argument. 69 int result = 0; 70 llvm::Error error = 71 jit->invoke("foo", 42, ExecutionEngine::Result<int>(result)); 72 ASSERT_TRUE(!error); 73 ASSERT_EQ(result, 42 + 42); 74 } 75 76 TEST(MLIRExecutionEngine, SubtractFloat) { 77 std::string moduleStr = R"mlir( 78 func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } { 79 %res = std.subf %arg0, %arg1 : f32 80 return %res : f32 81 } 82 )mlir"; 83 DialectRegistry registry; 84 registerAllDialects(registry); 85 registerLLVMDialectTranslation(registry); 86 MLIRContext context(registry); 87 OwningModuleRef module = parseSourceString(moduleStr, &context); 88 ASSERT_TRUE(!!module); 89 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 90 auto jitOrError = ExecutionEngine::create(*module); 91 ASSERT_TRUE(!!jitOrError); 92 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 93 // The result of the function must be passed as output argument. 94 float result = -1; 95 llvm::Error error = 96 jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result)); 97 ASSERT_TRUE(!error); 98 ASSERT_EQ(result, 42.f); 99 } 100 101 TEST(NativeMemRefJit, ZeroRankMemref) { 102 OwningMemRef<float, 0> A({}); 103 A[{}] = 42.; 104 ASSERT_EQ(*A->data, 42); 105 A[{}] = 0; 106 std::string moduleStr = R"mlir( 107 func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } { 108 %cst42 = constant 42.0 : f32 109 memref.store %cst42, %arg0[] : memref<f32> 110 return 111 } 112 )mlir"; 113 DialectRegistry registry; 114 registerAllDialects(registry); 115 registerLLVMDialectTranslation(registry); 116 MLIRContext context(registry); 117 auto module = parseSourceString(moduleStr, &context); 118 ASSERT_TRUE(!!module); 119 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 120 auto jitOrError = ExecutionEngine::create(*module); 121 ASSERT_TRUE(!!jitOrError); 122 auto jit = std::move(jitOrError.get()); 123 124 llvm::Error error = jit->invoke("zero_ranked", &*A); 125 ASSERT_TRUE(!error); 126 EXPECT_EQ((A[{}]), 42.); 127 for (float &elt : *A) 128 EXPECT_EQ(&elt, &(A[{}])); 129 } 130 131 TEST(NativeMemRefJit, RankOneMemref) { 132 int64_t shape[] = {9}; 133 OwningMemRef<float, 1> A(shape); 134 int count = 1; 135 for (float &elt : *A) { 136 EXPECT_EQ(&elt, &(A[{count - 1}])); 137 elt = count++; 138 } 139 140 std::string moduleStr = R"mlir( 141 func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } { 142 %cst42 = constant 42.0 : f32 143 %cst5 = constant 5 : index 144 memref.store %cst42, %arg0[%cst5] : memref<?xf32> 145 return 146 } 147 )mlir"; 148 DialectRegistry registry; 149 registerAllDialects(registry); 150 registerLLVMDialectTranslation(registry); 151 MLIRContext context(registry); 152 auto module = parseSourceString(moduleStr, &context); 153 ASSERT_TRUE(!!module); 154 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 155 auto jitOrError = ExecutionEngine::create(*module); 156 ASSERT_TRUE(!!jitOrError); 157 auto jit = std::move(jitOrError.get()); 158 159 llvm::Error error = jit->invoke("one_ranked", &*A); 160 ASSERT_TRUE(!error); 161 count = 1; 162 for (float &elt : *A) { 163 if (count == 6) 164 EXPECT_EQ(elt, 42.); 165 else 166 EXPECT_EQ(elt, count); 167 count++; 168 } 169 } 170 171 TEST(NativeMemRefJit, BasicMemref) { 172 constexpr int K = 3; 173 constexpr int M = 7; 174 // Prepare arguments beforehand. 175 auto init = [=](float &elt, ArrayRef<int64_t> indices) { 176 assert(indices.size() == 2); 177 elt = M * indices[0] + indices[1]; 178 }; 179 int64_t shape[] = {K, M}; 180 int64_t shapeAlloc[] = {K + 1, M + 1}; 181 OwningMemRef<float, 2> A(shape, shapeAlloc, init); 182 ASSERT_EQ(A->sizes[0], K); 183 ASSERT_EQ(A->sizes[1], M); 184 ASSERT_EQ(A->strides[0], M + 1); 185 ASSERT_EQ(A->strides[1], 1); 186 for (int i = 0; i < K; ++i) { 187 for (int j = 0; j < M; ++j) { 188 EXPECT_EQ((A[{i, j}]), i * M + j); 189 EXPECT_EQ(&(A[{i, j}]), &((*A)[i][j])); 190 } 191 } 192 std::string moduleStr = R"mlir( 193 func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } { 194 %x = constant 2 : index 195 %y = constant 1 : index 196 %cst42 = constant 42.0 : f32 197 memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32> 198 memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32> 199 return 200 } 201 )mlir"; 202 DialectRegistry registry; 203 registerAllDialects(registry); 204 registerLLVMDialectTranslation(registry); 205 MLIRContext context(registry); 206 OwningModuleRef module = parseSourceString(moduleStr, &context); 207 ASSERT_TRUE(!!module); 208 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 209 auto jitOrError = ExecutionEngine::create(*module); 210 ASSERT_TRUE(!!jitOrError); 211 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 212 213 llvm::Error error = jit->invoke("rank2_memref", &*A, &*A); 214 ASSERT_TRUE(!error); 215 EXPECT_EQ(((*A)[1][2]), 42.); 216 EXPECT_EQ((A[{2, 1}]), 42.); 217 } 218 219 // A helper function that will be called from the JIT 220 static void memref_multiply(::StridedMemRefType<float, 2> *memref, 221 int32_t coefficient) { 222 for (float &elt : *memref) 223 elt *= coefficient; 224 } 225 226 TEST(NativeMemRefJit, JITCallback) { 227 constexpr int K = 2; 228 constexpr int M = 2; 229 int64_t shape[] = {K, M}; 230 int64_t shapeAlloc[] = {K + 1, M + 1}; 231 OwningMemRef<float, 2> A(shape, shapeAlloc); 232 int count = 1; 233 for (float &elt : *A) 234 elt = count++; 235 236 std::string moduleStr = R"mlir( 237 func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } 238 func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } { 239 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> 240 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () 241 return 242 } 243 )mlir"; 244 DialectRegistry registry; 245 registerAllDialects(registry); 246 registerLLVMDialectTranslation(registry); 247 MLIRContext context(registry); 248 auto module = parseSourceString(moduleStr, &context); 249 ASSERT_TRUE(!!module); 250 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 251 auto jitOrError = ExecutionEngine::create(*module); 252 ASSERT_TRUE(!!jitOrError); 253 auto jit = std::move(jitOrError.get()); 254 // Define any extra symbols so they're available at runtime. 255 jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { 256 llvm::orc::SymbolMap symbolMap; 257 symbolMap[interner("_mlir_ciface_callback")] = 258 llvm::JITEvaluatedSymbol::fromPointer(memref_multiply); 259 return symbolMap; 260 }); 261 262 int32_t coefficient = 3.; 263 llvm::Error error = jit->invoke("caller_for_callback", &*A, coefficient); 264 ASSERT_TRUE(!error); 265 count = 1; 266 for (float elt : *A) 267 ASSERT_EQ(elt, coefficient * count++); 268 } 269 270 #endif // _WIN32 271