1 //===- GPUToCUDAPass.h - MLIR CUDA runtime support --------------*- C++ -*-===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 #ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_ 9 #define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_ 10 11 #include "mlir/Support/LLVM.h" 12 #include <functional> 13 #include <memory> 14 #include <string> 15 #include <vector> 16 17 namespace mlir { 18 19 class Location; 20 class ModuleOp; 21 22 namespace LLVM { 23 class LLVMDialect; 24 } // namespace LLVM 25 26 template <typename T> class OpPassBase; 27 28 using OwnedCubin = std::unique_ptr<std::vector<char>>; 29 using CubinGenerator = 30 std::function<OwnedCubin(const std::string &, Location, StringRef)>; 31 32 /// Creates a pass to convert kernel functions into CUBIN blobs. 33 /// 34 /// This transformation takes the body of each function that is annotated with 35 /// the 'nvvm.kernel' attribute, copies it to a new LLVM module, compiles the 36 /// module with help of the nvptx backend to PTX and then invokes the provided 37 /// cubinGenerator to produce a binary blob (the cubin). Such blob is then 38 /// attached as a string attribute named 'nvvm.cubin' to the kernel function. 39 /// After the transformation, the body of the kernel function is removed (i.e., 40 /// it is turned into a declaration). 41 std::unique_ptr<OpPassBase<ModuleOp>> 42 createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator); 43 44 /// Creates a pass to convert a gpu.launch_func operation into a sequence of 45 /// CUDA calls. 46 /// 47 /// This pass does not generate code to call CUDA directly but instead uses a 48 /// small wrapper library that exports a stable and conveniently typed ABI 49 /// on top of CUDA. 50 std::unique_ptr<OpPassBase<ModuleOp>> 51 createConvertGpuLaunchFuncToCudaCallsPass(); 52 53 } // namespace mlir 54 55 #endif // MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_ 56