1 #ifndef HALIDE_CODEGEN_PTX_DEV_H
2 #define HALIDE_CODEGEN_PTX_DEV_H
3 
4 /** \file
 * Defines the code-generator for producing CUDA device code (PTX)
6  */
7 
8 #include "CodeGen_GPU_Dev.h"
9 #include "CodeGen_GPU_Host.h"
10 #include "CodeGen_LLVM.h"
11 
12 namespace llvm {
13 class BasicBlock;
14 }
15 
16 namespace Halide {
17 namespace Internal {
18 
19 /** A code generator that emits GPU code from a given Halide stmt. */
20 class CodeGen_PTX_Dev : public CodeGen_LLVM, public CodeGen_GPU_Dev {
21 public:
22     friend class CodeGen_GPU_Host<CodeGen_X86>;
23     friend class CodeGen_GPU_Host<CodeGen_ARM>;
24 
25     /** Create a PTX device code generator. */
26     CodeGen_PTX_Dev(Target host);
27     ~CodeGen_PTX_Dev() override;
28 
29     void add_kernel(Stmt stmt,
30                     const std::string &name,
31                     const std::vector<DeviceArgument> &args) override;
32 
33     static void test();
34 
35     std::vector<char> compile_to_src() override;
36     std::string get_current_kernel_name() override;
37 
38     void dump() override;
39 
40     std::string print_gpu_name(const std::string &name) override;
41 
api_unique_name()42     std::string api_unique_name() override {
43         return "cuda";
44     }
45 
46 protected:
47     using CodeGen_LLVM::visit;
48 
49     /** (Re)initialize the PTX module. This is separate from compile, since
50      * a PTX device module will often have many kernels compiled into it for
51      * a single pipeline. */
52     /* override */ void init_module() override;
53 
54     /** We hold onto the basic block at the start of the device
55      * function in order to inject allocas */
56     llvm::BasicBlock *entry_block;
57 
58     /** Nodes for which we need to override default behavior for the GPU runtime */
59     // @{
60     void visit(const Call *) override;
61     void visit(const For *) override;
62     void visit(const Allocate *) override;
63     void visit(const Free *) override;
64     void visit(const AssertStmt *) override;
65     void visit(const Load *) override;
66     void visit(const Store *) override;
67     void visit(const Atomic *) override;
68     void codegen_vector_reduce(const VectorReduce *op, const Expr &init) override;
69     // @}
70 
71     std::string march() const;
72     std::string mcpu() const override;
73     std::string mattrs() const override;
74     bool use_soft_float_abi() const override;
75     int native_vector_bits() const override;
promote_indices()76     bool promote_indices() const override {
77         return false;
78     }
79 
upgrade_type_for_arithmetic(const Type & t)80     Type upgrade_type_for_arithmetic(const Type &t) const override {
81         return t;
82     }
83     Type upgrade_type_for_storage(const Type &t) const override;
84 
85     /** Map from simt variable names (e.g. foo.__block_id_x) to the llvm
86      * ptx intrinsic functions to call to get them. */
87     std::string simt_intrinsic(const std::string &name);
88 
89     bool supports_atomic_add(const Type &t) const override;
90 };
91 
92 }  // namespace Internal
93 }  // namespace Halide
94 
95 #endif
96