1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file codegen_hybrid.h 22 * \brief Common utilities to generated C style code. 23 */ 24 #ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ 25 #define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ 26 27 #include <tvm/ir.h> 28 #include <tvm/ir_functor_ext.h> 29 #include <tvm/codegen.h> 30 #include <tvm/lowered_func.h> 31 #include <tvm/schedule.h> 32 #include <map> 33 #include <string> 34 #include <unordered_map> 35 #include <utility> 36 #include <vector> 37 38 namespace tvm { 39 namespace contrib { 40 41 using namespace ir; 42 /*! 43 * \brief A base class to generate Hybrid Script. 44 * 45 * **NOTE** CodeGenHybrid does not aim at generating Python scripts consumed by Python2/3. 46 * For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``. 47 */ 48 class CodeGenHybrid : 49 public ExprFunctor<void(const Expr&, std::ostream&)>, 50 public StmtFunctor<void(const Stmt&)> { 51 public: 52 /*! 53 * \brief Dump the given function body to hybrid script. 54 * \param stmt The function body to be dumped to hybrid script. 55 * \param inputs Input tensors of this schedule. 56 * \param outputs Output tensors of this schedule. 57 * \param name The name of the function. 58 */ 59 void DumpStmt(const Stmt &stmt, const Array<NodeRef> &inputs, const Array<Tensor> &outputs, 60 const std::string &name = "hybrid_func"); 61 /*! 62 * \brief Finalize the compilation and return the code. 63 * \return The code. 64 */ 65 std::string Finish(); 66 /*! \brief Reserve keywords in avoid of name conflict. */ 67 void ReserveKeywords(); 68 /*! 69 * \brief Print the Stmt n to CodeGenHybrid->stream 70 * \param n The statement to be printed. 71 */ PrintStmt(const Stmt & n)72 void PrintStmt(const Stmt &n) { 73 this->VisitStmt(n); 74 } 75 /*! 76 * \brief Print the expression n(or its ssa id if in ssa mode) into os 77 * \param n The expression to be printed. 78 * \param os The output stream 79 */ PrintExpr(const Expr & n,std::ostream & os)80 void PrintExpr(const Expr &n, std::ostream &os) { 81 this->VisitExpr(n, os); 82 } 83 /*! 84 * \brief Same as PrintExpr, but simply returns result string 85 * \param n The expression to be printed. 86 */ PrintExpr(const Expr & n)87 std::string PrintExpr(const Expr &n) { 88 std::ostringstream os; 89 PrintExpr(n, os); 90 return os.str(); 91 } 92 // expression 93 void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*) 94 void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*) 95 void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*) 96 void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*) 97 void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*) 98 void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*) 99 void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*) 100 void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*) 101 void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*) 102 void VisitExpr_(const FloorDiv* op, std::ostream& os) override; // NOLINT(*) 103 void VisitExpr_(const FloorMod* op, std::ostream& os) override; // NOLINT(*) 104 void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*) 105 void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*) 106 void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*) 107 void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*) 108 void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*) 109 void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*) 110 void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*) 111 void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*) 112 void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*) 113 void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*) 114 void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*) 115 void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*) 116 void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*) 117 void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*) 118 void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*) 119 void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*) 120 void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*) 121 void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*) 122 void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*) 123 // statment 124 void VisitStmt_(const LetStmt* op) override; 125 void VisitStmt_(const Store* op) override; 126 void VisitStmt_(const Provide* op) override; 127 void VisitStmt_(const For* op) override; 128 void VisitStmt_(const IfThenElse* op) override; 129 void VisitStmt_(const Allocate* op) override; 130 void VisitStmt_(const Realize* op) override; 131 void VisitStmt_(const AttrStmt* op) override; 132 void VisitStmt_(const AssertStmt* op) override; 133 void VisitStmt_(const Evaluate* op) override; 134 void VisitStmt_(const Block* op) override; 135 void VisitStmt_(const ProducerConsumer* op) override; 136 /*! 137 * \brief Print Type represetnation of type t. 138 * \param t The type representation. 139 * \param os The stream to print the ctype into 140 */ 141 virtual void PrintType(Type t, std::ostream& os); // NOLINT(*) 142 143 private: 144 /*! \brief The current indent of the code dump. */ 145 int indent_{0}; 146 /*! \brief The tab size of code indent. */ 147 const int tab_{4}; 148 /*! \brief Print the current indent spaces. */ 149 inline void PrintIndent(); 150 /*! \brief Keys are ids allocated, and values are the suffix to prevent double-name. */ 151 std::map<std::string, int> ids_allocated_; 152 /*! 153 * \brief Keys are either (tensors, value_index) or (variables, 0). 154 * Values are the corresponding IDs.*/ 155 std::map<std::pair<const Node *, int>, std::string> id_map_; 156 /*! \brief Variables (keys) binded to the threads (values). */ 157 std::map<const Variable *, std::string> binds_; 158 /*! 159 * \brief Find an unallocated name for the given prefix. 160 * \param prefix The given prefix. 161 */ 162 std::string GetUniqueName(std::string prefix); 163 /*! \brief The output code string builder. */ 164 std::stringstream stream; 165 /*! 166 * \brief Get or allocate the ID for the given variable. 167 * \param v The given variable. 168 */ 169 std::string GetVarID(const Variable *v); 170 /*! 171 * \brief Get or allocate the ID for the given tensor. 172 * \param func The tensor to allocate a name. 173 * \param value_index The value index of the given tensor. 174 */ 175 std::string GetTensorID(const FunctionRef &func, int value_index); 176 /*! \brief the storage scope of allocation */ 177 std::map<FunctionRef, std::string> alloc_storage_scope_; 178 }; 179 180 } // namespace contrib 181 } // namespace tvm 182 #endif // TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ 183