1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file codegen_hybrid.h
 * \brief Code generator that dumps TVM IR as Hybrid Script.
23  */
24 #ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
25 #define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
26 
#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <tvm/schedule.h>
#include <map>
#include <ostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
37 
38 namespace tvm {
39 namespace contrib {
40 
41 using namespace ir;
42 /*!
43  * \brief A base class to generate Hybrid Script.
44  *
45  * **NOTE** CodeGenHybrid does not aim at generating Python scripts consumed by Python2/3.
46  * For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``.
47  */
48 class CodeGenHybrid :
49       public ExprFunctor<void(const Expr&, std::ostream&)>,
50       public StmtFunctor<void(const Stmt&)> {
51  public:
52   /*!
53    * \brief Dump the given function body to hybrid script.
54    * \param stmt The function body to be dumped to hybrid script.
55    * \param inputs Input tensors of this schedule.
56    * \param outputs Output tensors of this schedule.
57    * \param name The name of the function.
58    */
59   void DumpStmt(const Stmt &stmt, const Array<NodeRef> &inputs, const Array<Tensor> &outputs,
60                 const std::string &name = "hybrid_func");
61   /*!
62    * \brief Finalize the compilation and return the code.
63    * \return The code.
64    */
65   std::string Finish();
66   /*! \brief Reserve keywords in avoid of name conflict. */
67   void ReserveKeywords();
68   /*!
69    * \brief Print the Stmt n to CodeGenHybrid->stream
70    * \param n The statement to be printed.
71    */
PrintStmt(const Stmt & n)72   void PrintStmt(const Stmt &n) {
73     this->VisitStmt(n);
74   }
75   /*!
76    * \brief Print the expression n(or its ssa id if in ssa mode) into os
77    * \param n The expression to be printed.
78    * \param os The output stream
79    */
PrintExpr(const Expr & n,std::ostream & os)80   void PrintExpr(const Expr &n, std::ostream &os) {
81     this->VisitExpr(n, os);
82   }
83   /*!
84    * \brief Same as PrintExpr, but simply returns result string
85    * \param n The expression to be printed.
86    */
PrintExpr(const Expr & n)87   std::string PrintExpr(const Expr &n) {
88     std::ostringstream os;
89     PrintExpr(n, os);
90     return os.str();
91   }
92   // expression
93   void VisitExpr_(const Variable* op, std::ostream& os) override;  // NOLINT(*)
94   void VisitExpr_(const Load* op, std::ostream& os) override;  // NOLINT(*)
95   void VisitExpr_(const Let* op, std::ostream& os) override;  // NOLINT(*)
96   void VisitExpr_(const Call* op, std::ostream& os) override;  // NOLINT(*)
97   void VisitExpr_(const Add* op, std::ostream& os) override;  // NOLINT(*)
98   void VisitExpr_(const Sub* op, std::ostream& os) override;  // NOLINT(*)
99   void VisitExpr_(const Mul* op, std::ostream& os) override;  // NOLINT(*)
100   void VisitExpr_(const Div* op, std::ostream& os) override;  // NOLINT(*)
101   void VisitExpr_(const Mod* op, std::ostream& os) override;  // NOLINT(*)
102   void VisitExpr_(const FloorDiv* op, std::ostream& os) override;  // NOLINT(*)
103   void VisitExpr_(const FloorMod* op, std::ostream& os) override;  // NOLINT(*)
104   void VisitExpr_(const Min* op, std::ostream& os) override;  // NOLINT(*)
105   void VisitExpr_(const Max* op, std::ostream& os) override;  // NOLINT(*)
106   void VisitExpr_(const EQ* op, std::ostream& os) override;  // NOLINT(*)
107   void VisitExpr_(const NE* op, std::ostream& os) override;  // NOLINT(*)
108   void VisitExpr_(const LT* op, std::ostream& os) override;  // NOLINT(*)
109   void VisitExpr_(const LE* op, std::ostream& os) override;  // NOLINT(*)
110   void VisitExpr_(const GT* op, std::ostream& os) override;  // NOLINT(*)
111   void VisitExpr_(const GE* op, std::ostream& os) override;  // NOLINT(*)
112   void VisitExpr_(const And* op, std::ostream& os) override;  // NOLINT(*)
113   void VisitExpr_(const Or* op, std::ostream& os) override;  // NOLINT(*)
114   void VisitExpr_(const Cast* op, std::ostream& os) override;  // NOLINT(*)
115   void VisitExpr_(const Not* op, std::ostream& os) override;  // NOLINT(*)
116   void VisitExpr_(const Select* op, std::ostream& os) override;  // NOLINT(*)
117   void VisitExpr_(const Ramp* op, std::ostream& os) override;  // NOLINT(*)
118   void VisitExpr_(const Broadcast* op, std::ostream& os) override;  // NOLINT(*)
119   void VisitExpr_(const IntImm* op, std::ostream& os) override;  // NOLINT(*)
120   void VisitExpr_(const UIntImm* op, std::ostream& os) override;  // NOLINT(*)
121   void VisitExpr_(const FloatImm* op, std::ostream& os) override;  // NOLINT(*)
122   void VisitExpr_(const StringImm* op, std::ostream& os) override;  // NOLINT(*)
123   // statment
124   void VisitStmt_(const LetStmt* op) override;
125   void VisitStmt_(const Store* op) override;
126   void VisitStmt_(const Provide* op) override;
127   void VisitStmt_(const For* op) override;
128   void VisitStmt_(const IfThenElse* op) override;
129   void VisitStmt_(const Allocate* op) override;
130   void VisitStmt_(const Realize* op) override;
131   void VisitStmt_(const AttrStmt* op) override;
132   void VisitStmt_(const AssertStmt* op) override;
133   void VisitStmt_(const Evaluate* op) override;
134   void VisitStmt_(const Block* op) override;
135   void VisitStmt_(const ProducerConsumer* op) override;
136   /*!
137    * \brief Print Type represetnation of type t.
138    * \param t The type representation.
139    * \param os The stream to print the ctype into
140    */
141   virtual void PrintType(Type t, std::ostream& os); // NOLINT(*)
142 
143  private:
144   /*! \brief The current indent of the code dump. */
145   int indent_{0};
146   /*! \brief The tab size of code indent. */
147   const int tab_{4};
148   /*! \brief Print the current indent spaces. */
149   inline void PrintIndent();
150   /*! \brief Keys are ids allocated, and values are the suffix to prevent double-name.  */
151   std::map<std::string, int> ids_allocated_;
152   /*!
153    * \brief Keys are either (tensors, value_index) or (variables, 0).
154    *        Values are the corresponding IDs.*/
155   std::map<std::pair<const Node *, int>, std::string> id_map_;
156   /*! \brief Variables (keys) binded to the threads (values). */
157   std::map<const Variable *, std::string> binds_;
158   /*!
159    * \brief Find an unallocated name for the given prefix.
160    * \param prefix The given prefix.
161    */
162   std::string GetUniqueName(std::string prefix);
163   /*! \brief The output code string builder. */
164   std::stringstream stream;
165   /*!
166    * \brief Get or allocate the ID for the given variable.
167    * \param v The given variable.
168    */
169   std::string GetVarID(const Variable *v);
170   /*!
171    * \brief Get or allocate the ID for the given tensor.
172    * \param func The tensor to allocate a name.
173    * \param value_index The value index of the given tensor.
174    */
175   std::string GetTensorID(const FunctionRef &func, int value_index);
176   /*! \brief the storage scope of allocation */
177   std::map<FunctionRef, std::string> alloc_storage_scope_;
178 };
179 
180 }  // namespace contrib
181 }  // namespace tvm
182 #endif  // TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
183