1 #ifndef HALIDE_CODEGEN_D3D12_COMPUTE_DEV_H 2 #define HALIDE_CODEGEN_D3D12_COMPUTE_DEV_H 3 4 /** \file 5 * Defines the code-generator for producing D3D12-compatible HLSL kernel code 6 */ 7 8 #include <sstream> 9 10 #include "CodeGen_C.h" 11 #include "CodeGen_GPU_Dev.h" 12 #include "Target.h" 13 14 namespace Halide { 15 namespace Internal { 16 17 class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev { 18 public: 19 CodeGen_D3D12Compute_Dev(Target target); 20 21 /** Compile a GPU kernel into the module. This may be called many times 22 * with different kernels, which will all be accumulated into a single 23 * source module shared by a given Halide pipeline. */ 24 void add_kernel(Stmt stmt, 25 const std::string &name, 26 const std::vector<DeviceArgument> &args) override; 27 28 /** (Re)initialize the GPU kernel module. This is separate from compile, 29 * since a GPU device module will often have many kernels compiled into it 30 * for a single pipeline. */ 31 void init_module() override; 32 33 std::vector<char> compile_to_src() override; 34 35 std::string get_current_kernel_name() override; 36 37 void dump() override; 38 39 std::string print_gpu_name(const std::string &name) override; 40 api_unique_name()41 std::string api_unique_name() override { 42 return "d3d12compute"; 43 } 44 45 protected: 46 friend struct StoragePackUnpack; 47 48 class CodeGen_D3D12Compute_C : public CodeGen_C { 49 public: CodeGen_D3D12Compute_C(std::ostream & s,Target t)50 CodeGen_D3D12Compute_C(std::ostream &s, Target t) 51 : CodeGen_C(s, t) { 52 } 53 void add_kernel(Stmt stmt, 54 const std::string &name, 55 const std::vector<DeviceArgument> &args); 56 57 protected: 58 friend struct StoragePackUnpack; 59 60 std::string print_type(Type type, AppendSpaceIfNeeded space_option = DoNotAppendSpace) override; 61 std::string print_storage_type(Type type); 62 std::string print_type_maybe_storage(Type type, bool storage, AppendSpaceIfNeeded space); 63 std::string print_reinterpret(Type type, const Expr &e) override; 64 std::string print_extern_call(const Call *op) override; 65 66 std::string print_vanilla_cast(Type type, const std::string &value_expr); 67 std::string print_reinforced_cast(Type type, const std::string &value_expr); 68 std::string print_cast(Type target_type, Type source_type, const std::string &value_expr); 69 std::string print_reinterpret_cast(Type type, const std::string &value_expr); 70 71 std::string print_assignment(Type t, const std::string &rhs) override; 72 73 using CodeGen_C::visit; 74 void visit(const Evaluate *op) override; 75 void visit(const Min *) override; 76 void visit(const Max *) override; 77 void visit(const Div *) override; 78 void visit(const Mod *) override; 79 void visit(const For *) override; 80 void visit(const Ramp *op) override; 81 void visit(const Broadcast *op) override; 82 void visit(const Call *op) override; 83 void visit(const Load *op) override; 84 void visit(const Store *op) override; 85 void visit(const Select *op) override; 86 void visit(const Allocate *op) override; 87 void visit(const Free *op) override; 88 void visit(const Cast *op) override; 89 void visit(const Atomic *op) override; 90 void visit(const FloatImm *op) override; 91 92 Scope<> groupshared_allocations; 93 }; 94 95 std::ostringstream src_stream; 96 std::string cur_kernel_name; 97 CodeGen_D3D12Compute_C d3d12compute_c; 98 }; 99 100 } // namespace Internal 101 } // namespace Halide 102 103 #endif 104