1 #ifndef HALIDE_CODEGEN_PTX_DEV_H 2 #define HALIDE_CODEGEN_PTX_DEV_H 3 4 /** \file 5 * Defines the code-generator for producing CUDA host code 6 */ 7 8 #include "CodeGen_GPU_Dev.h" 9 #include "CodeGen_GPU_Host.h" 10 #include "CodeGen_LLVM.h" 11 12 namespace llvm { 13 class BasicBlock; 14 } 15 16 namespace Halide { 17 namespace Internal { 18 19 /** A code generator that emits GPU code from a given Halide stmt. */ 20 class CodeGen_PTX_Dev : public CodeGen_LLVM, public CodeGen_GPU_Dev { 21 public: 22 friend class CodeGen_GPU_Host<CodeGen_X86>; 23 friend class CodeGen_GPU_Host<CodeGen_ARM>; 24 25 /** Create a PTX device code generator. */ 26 CodeGen_PTX_Dev(Target host); 27 ~CodeGen_PTX_Dev() override; 28 29 void add_kernel(Stmt stmt, 30 const std::string &name, 31 const std::vector<DeviceArgument> &args) override; 32 33 static void test(); 34 35 std::vector<char> compile_to_src() override; 36 std::string get_current_kernel_name() override; 37 38 void dump() override; 39 40 std::string print_gpu_name(const std::string &name) override; 41 api_unique_name()42 std::string api_unique_name() override { 43 return "cuda"; 44 } 45 46 protected: 47 using CodeGen_LLVM::visit; 48 49 /** (Re)initialize the PTX module. This is separate from compile, since 50 * a PTX device module will often have many kernels compiled into it for 51 * a single pipeline. */ 52 /* override */ void init_module() override; 53 54 /** We hold onto the basic block at the start of the device 55 * function in order to inject allocas */ 56 llvm::BasicBlock *entry_block; 57 58 /** Nodes for which we need to override default behavior for the GPU runtime */ 59 // @{ 60 void visit(const Call *) override; 61 void visit(const For *) override; 62 void visit(const Allocate *) override; 63 void visit(const Free *) override; 64 void visit(const AssertStmt *) override; 65 void visit(const Load *) override; 66 void visit(const Store *) override; 67 void visit(const Atomic *) override; 68 void codegen_vector_reduce(const VectorReduce *op, const Expr &init) override; 69 // @} 70 71 std::string march() const; 72 std::string mcpu() const override; 73 std::string mattrs() const override; 74 bool use_soft_float_abi() const override; 75 int native_vector_bits() const override; promote_indices()76 bool promote_indices() const override { 77 return false; 78 } 79 upgrade_type_for_arithmetic(const Type & t)80 Type upgrade_type_for_arithmetic(const Type &t) const override { 81 return t; 82 } 83 Type upgrade_type_for_storage(const Type &t) const override; 84 85 /** Map from simt variable names (e.g. foo.__block_id_x) to the llvm 86 * ptx intrinsic functions to call to get them. */ 87 std::string simt_intrinsic(const std::string &name); 88 89 bool supports_atomic_add(const Type &t) const override; 90 }; 91 92 } // namespace Internal 93 } // namespace Halide 94 95 #endif 96