1 #ifndef HALIDE_CODEGEN_D3D12_COMPUTE_DEV_H
2 #define HALIDE_CODEGEN_D3D12_COMPUTE_DEV_H
3 
4 /** \file
5  * Defines the code-generator for producing D3D12-compatible HLSL kernel code
6  */
7 
8 #include <sstream>
9 
10 #include "CodeGen_C.h"
11 #include "CodeGen_GPU_Dev.h"
12 #include "Target.h"
13 
14 namespace Halide {
15 namespace Internal {
16 
17 class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev {
18 public:
19     CodeGen_D3D12Compute_Dev(Target target);
20 
21     /** Compile a GPU kernel into the module. This may be called many times
22      * with different kernels, which will all be accumulated into a single
23      * source module shared by a given Halide pipeline. */
24     void add_kernel(Stmt stmt,
25                     const std::string &name,
26                     const std::vector<DeviceArgument> &args) override;
27 
28     /** (Re)initialize the GPU kernel module. This is separate from compile,
29      * since a GPU device module will often have many kernels compiled into it
30      * for a single pipeline. */
31     void init_module() override;
32 
33     std::vector<char> compile_to_src() override;
34 
35     std::string get_current_kernel_name() override;
36 
37     void dump() override;
38 
39     std::string print_gpu_name(const std::string &name) override;
40 
api_unique_name()41     std::string api_unique_name() override {
42         return "d3d12compute";
43     }
44 
45 protected:
46     friend struct StoragePackUnpack;
47 
48     class CodeGen_D3D12Compute_C : public CodeGen_C {
49     public:
CodeGen_D3D12Compute_C(std::ostream & s,Target t)50         CodeGen_D3D12Compute_C(std::ostream &s, Target t)
51             : CodeGen_C(s, t) {
52         }
53         void add_kernel(Stmt stmt,
54                         const std::string &name,
55                         const std::vector<DeviceArgument> &args);
56 
57     protected:
58         friend struct StoragePackUnpack;
59 
60         std::string print_type(Type type, AppendSpaceIfNeeded space_option = DoNotAppendSpace) override;
61         std::string print_storage_type(Type type);
62         std::string print_type_maybe_storage(Type type, bool storage, AppendSpaceIfNeeded space);
63         std::string print_reinterpret(Type type, const Expr &e) override;
64         std::string print_extern_call(const Call *op) override;
65 
66         std::string print_vanilla_cast(Type type, const std::string &value_expr);
67         std::string print_reinforced_cast(Type type, const std::string &value_expr);
68         std::string print_cast(Type target_type, Type source_type, const std::string &value_expr);
69         std::string print_reinterpret_cast(Type type, const std::string &value_expr);
70 
71         std::string print_assignment(Type t, const std::string &rhs) override;
72 
73         using CodeGen_C::visit;
74         void visit(const Evaluate *op) override;
75         void visit(const Min *) override;
76         void visit(const Max *) override;
77         void visit(const Div *) override;
78         void visit(const Mod *) override;
79         void visit(const For *) override;
80         void visit(const Ramp *op) override;
81         void visit(const Broadcast *op) override;
82         void visit(const Call *op) override;
83         void visit(const Load *op) override;
84         void visit(const Store *op) override;
85         void visit(const Select *op) override;
86         void visit(const Allocate *op) override;
87         void visit(const Free *op) override;
88         void visit(const Cast *op) override;
89         void visit(const Atomic *op) override;
90         void visit(const FloatImm *op) override;
91 
92         Scope<> groupshared_allocations;
93     };
94 
95     std::ostringstream src_stream;
96     std::string cur_kernel_name;
97     CodeGen_D3D12Compute_C d3d12compute_c;
98 };
99 
100 }  // namespace Internal
101 }  // namespace Halide
102 
103 #endif
104