1 #ifndef HALIDE_CODEGEN_METAL_DEV_H 2 #define HALIDE_CODEGEN_METAL_DEV_H 3 4 /** \file 5 * Defines the code-generator for producing Apple Metal shading language kernel code 6 */ 7 8 #include <sstream> 9 10 #include "CodeGen_C.h" 11 #include "CodeGen_GPU_Dev.h" 12 #include "Target.h" 13 14 namespace Halide { 15 namespace Internal { 16 17 class CodeGen_Metal_Dev : public CodeGen_GPU_Dev { 18 public: 19 CodeGen_Metal_Dev(Target target); 20 21 /** Compile a GPU kernel into the module. This may be called many times 22 * with different kernels, which will all be accumulated into a single 23 * source module shared by a given Halide pipeline. */ 24 void add_kernel(Stmt stmt, 25 const std::string &name, 26 const std::vector<DeviceArgument> &args) override; 27 28 /** (Re)initialize the GPU kernel module. This is separate from compile, 29 * since a GPU device module will often have many kernels compiled into it 30 * for a single pipeline. */ 31 void init_module() override; 32 33 std::vector<char> compile_to_src() override; 34 35 std::string get_current_kernel_name() override; 36 37 void dump() override; 38 39 std::string print_gpu_name(const std::string &name) override; 40 api_unique_name()41 std::string api_unique_name() override { 42 return "metal"; 43 } 44 45 protected: 46 class CodeGen_Metal_C : public CodeGen_C { 47 public: CodeGen_Metal_C(std::ostream & s,Target t)48 CodeGen_Metal_C(std::ostream &s, Target t) 49 : CodeGen_C(s, t) { 50 } 51 void add_kernel(const Stmt &stmt, 52 const std::string &name, 53 const std::vector<DeviceArgument> &args); 54 55 protected: 56 using CodeGen_C::visit; 57 std::string print_type(Type type, AppendSpaceIfNeeded space_option = DoNotAppendSpace) override; 58 // Vectors in Metal come in two varieties, regular and packed. 59 // For storage allocations and pointers used in address arithmetic, 60 // packed types must be used. For temporaries, constructors, etc. 61 // regular types must be used. 62 // This concept also potentially applies to half types, which are 63 // often only supported for storage, not arithmetic, 64 // hence the method name. 65 std::string print_storage_type(Type type); 66 std::string print_type_maybe_storage(Type type, bool storage, AppendSpaceIfNeeded space); 67 std::string print_reinterpret(Type type, const Expr &e) override; 68 std::string print_extern_call(const Call *op) override; 69 70 std::string get_memory_space(const std::string &); 71 72 std::string shared_name; 73 74 void visit(const Min *) override; 75 void visit(const Max *) override; 76 void visit(const Div *) override; 77 void visit(const Mod *) override; 78 void visit(const For *) override; 79 void visit(const Ramp *op) override; 80 void visit(const Broadcast *op) override; 81 void visit(const Call *op) override; 82 void visit(const Load *op) override; 83 void visit(const Store *op) override; 84 void visit(const Select *op) override; 85 void visit(const Allocate *op) override; 86 void visit(const Free *op) override; 87 void visit(const Cast *op) override; 88 void visit(const Atomic *op) override; 89 }; 90 91 std::ostringstream src_stream; 92 std::string cur_kernel_name; 93 CodeGen_Metal_C metal_c; 94 }; 95 96 } // namespace Internal 97 } // namespace Halide 98 99 #endif 100