1 //===- Invoke.cpp ------------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
12 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
13 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
14 #include "mlir/Dialect/Linalg/Passes.h"
15 #include "mlir/ExecutionEngine/CRunnerUtils.h"
16 #include "mlir/ExecutionEngine/ExecutionEngine.h"
17 #include "mlir/ExecutionEngine/MemRefUtils.h"
18 #include "mlir/ExecutionEngine/RunnerUtils.h"
19 #include "mlir/IR/MLIRContext.h"
20 #include "mlir/InitAllDialects.h"
21 #include "mlir/Parser.h"
22 #include "mlir/Pass/PassManager.h"
23 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
24 #include "mlir/Target/LLVMIR/Export.h"
25 #include "llvm/Support/TargetSelect.h"
26 #include "llvm/Support/raw_ostream.h"
27 
28 #include "gmock/gmock.h"
29 
30 using namespace mlir;
31 
32 static struct LLVMInitializer {
LLVMInitializerLLVMInitializer33   LLVMInitializer() {
34     llvm::InitializeNativeTarget();
35     llvm::InitializeNativeTargetAsmPrinter();
36   }
37 } initializer;
38 
39 /// Simple conversion pipeline for the purpose of testing sources written in
40 /// dialects lowering to LLVM Dialect.
lowerToLLVMDialect(ModuleOp module)41 static LogicalResult lowerToLLVMDialect(ModuleOp module) {
42   PassManager pm(module.getContext());
43   pm.addPass(mlir::createMemRefToLLVMPass());
44   pm.addPass(mlir::createLowerToLLVMPass());
45   return pm.run(module);
46 }
47 
// The JIT isn't supported on Windows at this time.
49 #ifndef _WIN32
50 
/// JIT-compiles `foo(x) = x + x` on i32 and checks the value written back
/// through the C-interface result argument.
TEST(MLIRExecutionEngine, AddInteger) {
  std::string moduleStr = R"mlir(
  func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
    %res = std.addi %arg0, %arg0 : i32
    return %res : i32
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  OwningModuleRef module = parseSourceString(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // On failure, consume the error while reporting it. An unconsumed
  // llvm::Expected/llvm::Error aborts on destruction in builds with
  // LLVM_ENABLE_ABI_BREAKING_CHECKS, which would turn this assertion failure
  // into a crash with no diagnostic. The streamed message is only evaluated
  // when the assertion fails.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
  // The result of the function must be passed as output argument.
  int result = 0;
  llvm::Error error =
      jit->invoke("foo", 42, ExecutionEngine::Result<int>(result));
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  ASSERT_EQ(result, 42 + 42);
}
75 
/// JIT-compiles `foo(a, b) = a - b` on f32 and checks the value written back
/// through the C-interface result argument.
TEST(MLIRExecutionEngine, SubtractFloat) {
  std::string moduleStr = R"mlir(
  func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
    %res = std.subf %arg0, %arg1 : f32
    return %res : f32
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  OwningModuleRef module = parseSourceString(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // Consume the error on failure so it is reported rather than aborting on
  // destruction (unchecked llvm::Expected/llvm::Error in assertion builds).
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
  // The result of the function must be passed as output argument.
  float result = -1;
  llvm::Error error =
      jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  ASSERT_EQ(result, 42.f);
}
100 
/// Checks host-side access to a rank-0 OwningMemRef, then has JIT-compiled
/// code store 42.0 into it and verifies the store is visible on the host. Also
/// checks that iterating a rank-0 memref only ever yields its single element.
TEST(NativeMemRefJit, ZeroRankMemref) {
  OwningMemRef<float, 0> A({});
  A[{}] = 42.;
  ASSERT_EQ(*A->data, 42);
  A[{}] = 0;
  std::string moduleStr = R"mlir(
  func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
    %cst42 = constant 42.0 : f32
    memref.store %cst42, %arg0[] : memref<f32>
    return
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  auto module = parseSourceString(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // Consume the error on failure so it is reported rather than aborting on
  // destruction (unchecked llvm::Expected/llvm::Error in assertion builds).
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  auto jit = std::move(jitOrError.get());

  llvm::Error error = jit->invoke("zero_ranked", &*A);
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  EXPECT_EQ((A[{}]), 42.);
  for (float &elt : *A)
    EXPECT_EQ(&elt, &(A[{}]));
}
130 
/// Fills a 9-element rank-1 OwningMemRef with 1..9, checking that iteration
/// order matches index order. JIT-compiled code then stores 42.0 at index 5.
/// Finally it checks that only that element changed.
TEST(NativeMemRefJit, RankOneMemref) {
  int64_t shape[] = {9};
  OwningMemRef<float, 1> A(shape);
  int count = 1;
  for (float &elt : *A) {
    EXPECT_EQ(&elt, &(A[{count - 1}]));
    elt = count++;
  }

  std::string moduleStr = R"mlir(
  func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
    %cst42 = constant 42.0 : f32
    %cst5 = constant 5 : index
    memref.store %cst42, %arg0[%cst5] : memref<?xf32>
    return
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  auto module = parseSourceString(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // Consume the error on failure so it is reported rather than aborting on
  // destruction (unchecked llvm::Expected/llvm::Error in assertion builds).
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  auto jit = std::move(jitOrError.get());

  llvm::Error error = jit->invoke("one_ranked", &*A);
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  count = 1;
  for (float &elt : *A) {
    // count is 1-based, so count == 6 is index 5, the element JIT code wrote.
    if (count == 6)
      EXPECT_EQ(elt, 42.);
    else
      EXPECT_EQ(elt, count);
    count++;
  }
}
170 
/// Builds a K x M rank-2 OwningMemRef backed by a larger (K+1) x (M+1)
/// allocation, then checks sizes, strides and element addressing. It passes
/// the same memref as both arguments of a JIT-compiled function that stores
/// 42.0 at [1][2] of the first and [2][1] of the second, and checks both
/// stores on the host.
TEST(NativeMemRefJit, BasicMemref) {
  constexpr int K = 3;
  constexpr int M = 7;
  // Prepare arguments beforehand.
  auto init = [=](float &elt, ArrayRef<int64_t> indices) {
    assert(indices.size() == 2);
    elt = M * indices[0] + indices[1];
  };
  int64_t shape[] = {K, M};
  int64_t shapeAlloc[] = {K + 1, M + 1};
  OwningMemRef<float, 2> A(shape, shapeAlloc, init);
  ASSERT_EQ(A->sizes[0], K);
  ASSERT_EQ(A->sizes[1], M);
  ASSERT_EQ(A->strides[0], M + 1);
  ASSERT_EQ(A->strides[1], 1);
  for (int i = 0; i < K; ++i) {
    for (int j = 0; j < M; ++j) {
      EXPECT_EQ((A[{i, j}]), i * M + j);
      EXPECT_EQ(&(A[{i, j}]), &((*A)[i][j]));
    }
  }
  std::string moduleStr = R"mlir(
  func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
    %x = constant 2 : index
    %y = constant 1 : index
    %cst42 = constant 42.0 : f32
    memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
    memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
    return
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  OwningModuleRef module = parseSourceString(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // Consume the error on failure so it is reported rather than aborting on
  // destruction (unchecked llvm::Expected/llvm::Error in assertion builds).
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());

  llvm::Error error = jit->invoke("rank2_memref", &*A, &*A);
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  EXPECT_EQ(((*A)[1][2]), 42.);
  EXPECT_EQ((A[{2, 1}]), 42.);
}
218 
219 // A helper function that will be called from the JIT
memref_multiply(::StridedMemRefType<float,2> * memref,int32_t coefficient)220 static void memref_multiply(::StridedMemRefType<float, 2> *memref,
221                             int32_t coefficient) {
222   for (float &elt : *memref)
223     elt *= coefficient;
224 }
225 
/// Checks that JIT-compiled code can call back into a host function. The
/// external `@callback` declaration is bound to `memref_multiply` by
/// registering `_mlir_ciface_callback` with the engine. The test then verifies
/// that every element of the memref passed through `@caller_for_callback` was
/// multiplied by the coefficient.
TEST(NativeMemRefJit, JITCallback) {
  constexpr int K = 2;
  constexpr int M = 2;
  int64_t shape[] = {K, M};
  int64_t shapeAlloc[] = {K + 1, M + 1};
  OwningMemRef<float, 2> A(shape, shapeAlloc);
  int count = 1;
  for (float &elt : *A)
    elt = count++;

  std::string moduleStr = R"mlir(
  func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32)  attributes { llvm.emit_c_interface }
  func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
    %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
    call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
    return
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  auto module = parseSourceString(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // Consume the error on failure so it is reported rather than aborting on
  // destruction (unchecked llvm::Expected/llvm::Error in assertion builds).
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  auto jit = std::move(jitOrError.get());
  // Define any extra symbols so they're available at runtime.
  jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
    llvm::orc::SymbolMap symbolMap;
    symbolMap[interner("_mlir_ciface_callback")] =
        llvm::JITEvaluatedSymbol::fromPointer(memref_multiply);
    return symbolMap;
  });

  // Integer literal: the coefficient is an i32 on the MLIR side.
  int32_t coefficient = 3;
  llvm::Error error = jit->invoke("caller_for_callback", &*A, coefficient);
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  count = 1;
  for (float elt : *A)
    ASSERT_EQ(elt, coefficient * count++);
}
269 
270 #endif // _WIN32
271