1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/relay/expr_functor.h
22  * \brief A more powerful visitor which enables defining arbitrary function
23  * signatures with type based dispatch on first argument.
24  */
25 #ifndef TVM_RELAY_EXPR_FUNCTOR_H_
26 #define TVM_RELAY_EXPR_FUNCTOR_H_
27 
#include <tvm/node/functor.h>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include "./expr.h"
#include "./adt.h"
#include "./op.h"
#include "./error.h"
36 
37 namespace tvm {
38 namespace relay {
39 
/*!
 * \brief A dynamic functor that dispatches on the type of its first Expr argument.
 *  You can use this as a more powerful Visitor, since it allows you to
 *  define the function signature of the visit functions.
 *
 * \sa tvm/ir_functor.h
 *
 * \tparam FType function signature
 *  This type is only defined for FType with function signature R(const Expr&,
 * Args...)
 */
51 template <typename FType>
52 class ExprFunctor;
53 
// Default body for the overridable VisitExpr_ overloads: passes the node
// pointer `op` and the extra arguments on to VisitExprDefault_.
// It expects `op`, `args` and the pack `Args` to be in scope at the use site.
#define EXPR_FUNCTOR_DEFAULT \
  { return VisitExprDefault_(op, std::forward<Args>(args)...); }

// Adds a dispatch entry for node type OP to the variable named `vtable`.
// The entry downcasts the incoming node to `const OP*` and calls the
// matching VisitExpr_ overload on `self`.
// It expects `vtable`, `TSelf` and the pack `Args` to be in scope.
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)                            \
  vtable.template set_dispatch<OP>(                                \
      [](const ObjectRef& n, TSelf* self, Args... args) {          \
        return self->VisitExpr_(static_cast<const OP*>(n.get()),   \
                                std::forward<Args>(args)...);      \
      });
64 
65 template <typename R, typename... Args>
66 class ExprFunctor<R(const Expr& n, Args...)> {
67  private:
68   using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
69   using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
70 
71  public:
72   /*! \brief the result type of this functor */
73   using result_type = R;
74   /*! \brief virtual destructor */
~ExprFunctor()75   virtual ~ExprFunctor() {}
76   /*!
77    * \brief Same as call.
78    * \param n The expression node.
79    * \param args Additional arguments.
80    * \return The result of the call
81    */
operator()82   R operator()(const Expr& n, Args... args) {
83     return VisitExpr(n, std::forward<Args>(args)...);
84   }
85   /*!
86    * \brief The functor call.
87    * \param n The expression node.
88    * \param args Additional arguments.
89    * \return The result of the call
90    */
VisitExpr(const Expr & n,Args...args)91   virtual R VisitExpr(const Expr& n, Args... args) {
92     CHECK(n.defined());
93     static FType vtable = InitVTable();
94     return vtable(n, this, std::forward<Args>(args)...);
95   }
96   // Functions that can be overriden by subclass
97   virtual R VisitExpr_(const ConstantNode* op,
98                        Args... args) EXPR_FUNCTOR_DEFAULT;
99   virtual R VisitExpr_(const TupleNode* op,
100                        Args... args) EXPR_FUNCTOR_DEFAULT;
101   virtual R VisitExpr_(const VarNode* op,
102                        Args... args) EXPR_FUNCTOR_DEFAULT;
103   virtual R VisitExpr_(const GlobalVarNode* op,
104                        Args... args) EXPR_FUNCTOR_DEFAULT;
105   virtual R VisitExpr_(const FunctionNode* op,
106                        Args... args) EXPR_FUNCTOR_DEFAULT;
107   virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
108   virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
109   virtual R VisitExpr_(const IfNode* op,
110                        Args... args) EXPR_FUNCTOR_DEFAULT;
111   virtual R VisitExpr_(const OpNode* op,
112                        Args... args) EXPR_FUNCTOR_DEFAULT;
113   virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
114   virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
115   virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
116   virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
117   virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
118   virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
VisitExprDefault_(const Node * op,Args...)119   virtual R VisitExprDefault_(const Node* op, Args...) {
120     LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
121     throw;
122   }
123 
124  private:
125   // initialize the vtable.
InitVTable()126   static FType InitVTable() {
127     FType vtable;
128     // Set dispatch
129     RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode);
130     RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
131     RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
132     RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
133     RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
134     RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
135     RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
136     RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
137     RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
138     RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
139     RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
140     RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
141     RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
142     RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
143     RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
144     return vtable;
145   }
146 };
147 
148 /*!
149  * \brief A simple visitor wrapper around ExprFunctor.
150  *  Recursively visit the content.
151  *
152  * ExprVisitor treats Expr as dataflow graph,
153  * and only visit each Expr node once.
154  */
155 class ExprVisitor
156     : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
157  public:
158   void VisitExpr(const Expr& expr) override;
159   void VisitExpr_(const VarNode* op) override;
160   void VisitExpr_(const GlobalVarNode* op) override;
161   void VisitExpr_(const ConstantNode* op) override;
162   void VisitExpr_(const TupleNode* op) override;
163   void VisitExpr_(const FunctionNode* op) override;
164   void VisitExpr_(const CallNode* op) override;
165   void VisitExpr_(const LetNode* op) override;
166   void VisitExpr_(const IfNode* op) override;
167   void VisitExpr_(const OpNode* op) override;
168   void VisitExpr_(const TupleGetItemNode* op) override;
169   void VisitExpr_(const RefCreateNode* op) override;
170   void VisitExpr_(const RefReadNode* op) override;
171   void VisitExpr_(const RefWriteNode* op) override;
172   void VisitExpr_(const ConstructorNode* op) override;
173   void VisitExpr_(const MatchNode* op) override;
174   virtual void VisitType(const Type& t);
175   virtual void VisitClause(const Clause& c);
176   virtual void VisitPattern(const Pattern& c);
177 
178  protected:
179   // Internal visiting counter
180   std::unordered_map<const Node*, size_t> visit_counter_;
181 };
182 
183 /*!
184  * \brief A wrapper around ExprFunctor which functionally updates the AST.
185  *
186  * ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once.
187  * The mutated results are memoized in a map and reused so that
188  * local transformation on the dataflow preserves the graph structure.
189  */
190 class ExprMutator
191     : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
192  public:
193   /*!
194    * \brief Mutate is alias for VisitExpr
195    * \return expr.
196    */
Mutate(const Expr & expr)197   Expr Mutate(const Expr& expr) {
198     return this->VisitExpr(expr);
199   }
200   Expr VisitExpr(const Expr& expr) override;
201   Expr VisitExpr_(const VarNode* op) override;
202   Expr VisitExpr_(const ConstantNode* op) override;
203   Expr VisitExpr_(const GlobalVarNode* op) override;
204   Expr VisitExpr_(const OpNode* op) override;
205   Expr VisitExpr_(const TupleNode* op) override;
206   Expr VisitExpr_(const FunctionNode* op) override;
207   Expr VisitExpr_(const CallNode* call_node) override;
208   Expr VisitExpr_(const LetNode* op) override;
209   Expr VisitExpr_(const IfNode* op) override;
210   Expr VisitExpr_(const TupleGetItemNode* op) override;
211   Expr VisitExpr_(const RefCreateNode* op) override;
212   Expr VisitExpr_(const RefReadNode* op) override;
213   Expr VisitExpr_(const RefWriteNode* op) override;
214   Expr VisitExpr_(const ConstructorNode* op) override;
215   Expr VisitExpr_(const MatchNode* op) override;
216 
217   /*!
218    * \brief Used to visit the types inside of expressions.
219    *
220    * Can be overloaded to transform the types in arbitrary
221    * ways, one way would be to define a sub-class of type
222    * visitor for types which transform them appropriately.
223    */
224   virtual Type VisitType(const Type& t);
225   virtual Clause VisitClause(const Clause& c);
226   virtual Pattern VisitPattern(const Pattern& c);
227 
228  protected:
229   /*! \brief Internal map used for memoization. */
230   std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
231 };
232 
233 /*!
234  * \brief recursively visit the ir in post DFS order node, apply fvisit
235  * Each node is guaranteed to be visited only once.
236  * \param node The ir to be visited.
237  * \param fvisit The visitor function to be applied.
238  */
239 void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
240 
241 }  // namespace relay
242 }  // namespace tvm
243 #endif  // TVM_RELAY_EXPR_FUNCTOR_H_
244