1 //===- GPUToCUDAPass.h - MLIR CUDA runtime support --------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
9 #define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
10 
11 #include "mlir/Support/LLVM.h"
12 #include <functional>
13 #include <memory>
14 #include <string>
15 #include <vector>
16 
17 namespace mlir {
18 
// Forward declarations. The declarations in this header only name these types
// (as a std::function parameter type and as template arguments), so their
// full definitions are not needed here.
19 class Location;
20 class ModuleOp;
21 
22 namespace LLVM {
// NOTE(review): LLVMDialect is not referenced by any declaration in this
// header -- confirm whether this forward declaration is still needed.
23 class LLVMDialect;
24 } // namespace LLVM
25 
26 template <typename T> class OpPassBase;
27 
/// Owning pointer to a byte buffer; this is what a CubinGenerator returns.
/// NOTE(review): presumably the buffer holds the compiled cubin image and a
/// null pointer signals failure -- confirm against the pass implementation.
28 using OwnedCubin = std::unique_ptr<std::vector<char>>;
/// Callback type taken by createConvertGPUKernelToCubinPass, which invokes it
/// to produce the binary blob (the cubin) for a kernel.
/// NOTE(review): the arguments are presumably the PTX text, a location for
/// diagnostics, and the kernel name, in that order -- confirm with the caller.
29 using CubinGenerator =
30     std::function<OwnedCubin(const std::string &, Location, StringRef)>;
31 
32 /// Creates a pass to convert kernel functions into CUBIN blobs.
33 ///
34 /// For each function annotated with the 'nvvm.kernel' attribute, this pass
35 /// copies the function body into a new LLVM module, compiles that module to
36 /// PTX with the help of the NVPTX backend, and invokes the provided
37 /// cubinGenerator to produce a binary blob (the cubin). The blob is attached
38 /// to the kernel function as a string attribute named 'nvvm.cubin', and the
39 /// function body is then removed (i.e., it is turned into a declaration).
40 /// The returned pass operates on a ModuleOp and is owned by the caller.
41 std::unique_ptr<OpPassBase<ModuleOp>>
42 createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator);
43 
44 /// Creates a pass to convert a gpu.launch_func operation into a sequence of
45 /// CUDA calls.
46 ///
47 /// Rather than generating code that calls CUDA directly, the pass uses a small
48 /// wrapper library that exports a stable and conveniently typed ABI on top of
49 /// CUDA. The returned pass operates on a ModuleOp and is owned by the caller.
50 std::unique_ptr<OpPassBase<ModuleOp>>
51 createConvertGpuLaunchFuncToCudaCallsPass();
52 
53 } // namespace mlir
54 
55 #endif // MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
56