1 //===- mlir-vulkan-runner.cpp - MLIR Vulkan Execution Driver --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This is a command line utility that executes an MLIR file on the Vulkan by 10 // translating MLIR GPU module to SPIR-V and host part to LLVM IR before 11 // JIT-compiling and executing the latter. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h" 16 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" 17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 18 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" 19 #include "mlir/Dialect/GPU/Passes.h" 20 #include "mlir/Dialect/SPIRV/Passes.h" 21 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 22 #include "mlir/ExecutionEngine/JitRunner.h" 23 #include "mlir/ExecutionEngine/OptUtils.h" 24 #include "mlir/InitAllDialects.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Pass/PassManager.h" 27 #include "llvm/Support/InitLLVM.h" 28 #include "llvm/Support/TargetSelect.h" 29 30 using namespace mlir; 31 runMLIRPasses(ModuleOp module)32static LogicalResult runMLIRPasses(ModuleOp module) { 33 PassManager passManager(module.getContext()); 34 applyPassManagerCLOptions(passManager); 35 36 passManager.addPass(createGpuKernelOutliningPass()); 37 passManager.addPass(createLegalizeStdOpsForSPIRVLoweringPass()); 38 passManager.addPass(createConvertGPUToSPIRVPass()); 39 OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>(); 40 modulePM.addPass(spirv::createLowerABIAttributesPass()); 41 modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass()); 42 passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass()); 43 LowerToLLVMOptions llvmOptions = { 44 /*useBarePtrCallConv =*/false, 45 /*emitCWrappers = */ true, 46 /*indexBitwidth =*/kDeriveIndexBitwidthFromDataLayout}; 47 passManager.addPass(createLowerToLLVMPass(llvmOptions)); 48 passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass()); 49 return passManager.run(module); 50 } 51 main(int argc,char ** argv)52int main(int argc, char **argv) { 53 llvm::llvm_shutdown_obj x; 54 registerPassManagerCLOptions(); 55 56 mlir::registerAllDialects(); 57 llvm::InitLLVM y(argc, argv); 58 llvm::InitializeNativeTarget(); 59 llvm::InitializeNativeTargetAsmPrinter(); 60 mlir::initializeLLVMPasses(); 61 62 return mlir::JitRunnerMain(argc, argv, &runMLIRPasses); 63 } 64