1 //===- ExecutionEngine.cpp - C API for MLIR JIT ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/ExecutionEngine.h"
10 #include "mlir/CAPI/ExecutionEngine.h"
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Support.h"
13 #include "mlir/ExecutionEngine/OptUtils.h"
14 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
15 #include "llvm/ExecutionEngine/Orc/Mangling.h"
16 #include "llvm/Support/TargetSelect.h"
17 
18 using namespace mlir;
19 
20 extern "C" MlirExecutionEngine
mlirExecutionEngineCreate(MlirModule op,int optLevel,int numPaths,const MlirStringRef * sharedLibPaths)21 mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
22                           const MlirStringRef *sharedLibPaths) {
23   static bool initOnce = [] {
24     llvm::InitializeNativeTarget();
25     llvm::InitializeNativeTargetAsmPrinter();
26     return true;
27   }();
28   (void)initOnce;
29 
30   mlir::registerLLVMDialectTranslation(*unwrap(op)->getContext());
31 
32   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
33   if (!tmBuilderOrError) {
34     llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
35     return MlirExecutionEngine{nullptr};
36   }
37   auto tmOrError = tmBuilderOrError->createTargetMachine();
38   if (!tmOrError) {
39     llvm::errs() << "Failed to create a TargetMachine for the host\n";
40     return MlirExecutionEngine{nullptr};
41   }
42 
43   SmallVector<StringRef> libPaths;
44   for (unsigned i = 0; i < static_cast<unsigned>(numPaths); ++i)
45     libPaths.push_back(sharedLibPaths[i].data);
46 
47   // Create a transformer to run all LLVM optimization passes at the
48   // specified optimization level.
49   auto llvmOptLevel = static_cast<llvm::CodeGenOpt::Level>(optLevel);
50   auto transformer = mlir::makeLLVMPassesTransformer(
51       /*passes=*/{}, llvmOptLevel, /*targetMachine=*/tmOrError->get());
52   auto jitOrError =
53       ExecutionEngine::create(unwrap(op), /*llvmModuleBuilder=*/{}, transformer,
54                               llvmOptLevel, libPaths);
55   if (!jitOrError) {
56     consumeError(jitOrError.takeError());
57     return MlirExecutionEngine{nullptr};
58   }
59   return wrap(jitOrError->release());
60 }
61 
mlirExecutionEngineDestroy(MlirExecutionEngine jit)62 extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) {
63   delete (unwrap(jit));
64 }
65 
66 extern "C" MlirLogicalResult
mlirExecutionEngineInvokePacked(MlirExecutionEngine jit,MlirStringRef name,void ** arguments)67 mlirExecutionEngineInvokePacked(MlirExecutionEngine jit, MlirStringRef name,
68                                 void **arguments) {
69   const std::string ifaceName = ("_mlir_ciface_" + unwrap(name)).str();
70   llvm::Error error = unwrap(jit)->invokePacked(
71       ifaceName, MutableArrayRef<void *>{arguments, (size_t)0});
72   if (error)
73     return wrap(failure());
74   return wrap(success());
75 }
76 
mlirExecutionEngineLookup(MlirExecutionEngine jit,MlirStringRef name)77 extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
78                                            MlirStringRef name) {
79   auto expectedFPtr = unwrap(jit)->lookup(unwrap(name));
80   if (!expectedFPtr)
81     return nullptr;
82   return reinterpret_cast<void *>(*expectedFPtr);
83 }
84 
mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,MlirStringRef name,void * sym)85 extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
86                                                   MlirStringRef name,
87                                                   void *sym) {
88   unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
89     llvm::orc::SymbolMap symbolMap;
90     symbolMap[interner(unwrap(name))] =
91         llvm::JITEvaluatedSymbol::fromPointer(sym);
92     return symbolMap;
93   });
94 }
95 
mlirExecutionEngineDumpToObjectFile(MlirExecutionEngine jit,MlirStringRef name)96 extern "C" void mlirExecutionEngineDumpToObjectFile(MlirExecutionEngine jit,
97                                                     MlirStringRef name) {
98   unwrap(jit)->dumpToObjectFile(unwrap(name));
99 }
100