1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file tvm/relay/expr_functor.h 22 * \brief A more powerful visitor which enables defining arbitrary function 23 * signatures with type based dispatch on first argument. 24 */ 25 #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ 26 #define TVM_RELAY_EXPR_FUNCTOR_H_ 27 28 #include <tvm/node/functor.h> 29 #include <string> 30 #include <utility> 31 #include <unordered_map> 32 #include "./expr.h" 33 #include "./adt.h" 34 #include "./op.h" 35 #include "./error.h" 36 37 namespace tvm { 38 namespace relay { 39 40 /*! 41 * \brief A dynamical functor that dispatches on in the first Expr argument. 42 * You can use this as a more powerful Visitor, since it allows you to 43 * define function signatures of Visit Function. 44 * 45 * \sa tvm/ir_functor.h 46 * 47 * \tparam FType function signiture 48 * This type is only defined for FType with function signature R(const Expr&, 49 * Args...) 50 */ 51 template <typename FType> 52 class ExprFunctor; 53 54 // functions to be overriden. 55 #define EXPR_FUNCTOR_DEFAULT \ 56 { return VisitExprDefault_(op, std::forward<Args>(args)...); } 57 58 #define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ 59 vtable.template set_dispatch<OP>( \ 60 [](const ObjectRef& n, TSelf* self, Args... args) { \ 61 return self->VisitExpr_(static_cast<const OP*>(n.get()), \ 62 std::forward<Args>(args)...); \ 63 }); 64 65 template <typename R, typename... Args> 66 class ExprFunctor<R(const Expr& n, Args...)> { 67 private: 68 using TSelf = ExprFunctor<R(const Expr& n, Args...)>; 69 using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; 70 71 public: 72 /*! \brief the result type of this functor */ 73 using result_type = R; 74 /*! \brief virtual destructor */ ~ExprFunctor()75 virtual ~ExprFunctor() {} 76 /*! 77 * \brief Same as call. 78 * \param n The expression node. 79 * \param args Additional arguments. 80 * \return The result of the call 81 */ operator()82 R operator()(const Expr& n, Args... args) { 83 return VisitExpr(n, std::forward<Args>(args)...); 84 } 85 /*! 86 * \brief The functor call. 87 * \param n The expression node. 88 * \param args Additional arguments. 89 * \return The result of the call 90 */ VisitExpr(const Expr & n,Args...args)91 virtual R VisitExpr(const Expr& n, Args... args) { 92 CHECK(n.defined()); 93 static FType vtable = InitVTable(); 94 return vtable(n, this, std::forward<Args>(args)...); 95 } 96 // Functions that can be overriden by subclass 97 virtual R VisitExpr_(const ConstantNode* op, 98 Args... args) EXPR_FUNCTOR_DEFAULT; 99 virtual R VisitExpr_(const TupleNode* op, 100 Args... args) EXPR_FUNCTOR_DEFAULT; 101 virtual R VisitExpr_(const VarNode* op, 102 Args... args) EXPR_FUNCTOR_DEFAULT; 103 virtual R VisitExpr_(const GlobalVarNode* op, 104 Args... args) EXPR_FUNCTOR_DEFAULT; 105 virtual R VisitExpr_(const FunctionNode* op, 106 Args... args) EXPR_FUNCTOR_DEFAULT; 107 virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; 108 virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; 109 virtual R VisitExpr_(const IfNode* op, 110 Args... args) EXPR_FUNCTOR_DEFAULT; 111 virtual R VisitExpr_(const OpNode* op, 112 Args... args) EXPR_FUNCTOR_DEFAULT; 113 virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; 114 virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; 115 virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; 116 virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; 117 virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; 118 virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; VisitExprDefault_(const Node * op,Args...)119 virtual R VisitExprDefault_(const Node* op, Args...) { 120 LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); 121 throw; 122 } 123 124 private: 125 // initialize the vtable. InitVTable()126 static FType InitVTable() { 127 FType vtable; 128 // Set dispatch 129 RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode); 130 RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode); 131 RELAY_EXPR_FUNCTOR_DISPATCH(VarNode); 132 RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode); 133 RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode); 134 RELAY_EXPR_FUNCTOR_DISPATCH(CallNode); 135 RELAY_EXPR_FUNCTOR_DISPATCH(LetNode); 136 RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); 137 RELAY_EXPR_FUNCTOR_DISPATCH(OpNode); 138 RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); 139 RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode); 140 RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode); 141 RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); 142 RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode); 143 RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode); 144 return vtable; 145 } 146 }; 147 148 /*! 149 * \brief A simple visitor wrapper around ExprFunctor. 150 * Recursively visit the content. 151 * 152 * ExprVisitor treats Expr as dataflow graph, 153 * and only visit each Expr node once. 154 */ 155 class ExprVisitor 156 : public ::tvm::relay::ExprFunctor<void(const Expr& n)> { 157 public: 158 void VisitExpr(const Expr& expr) override; 159 void VisitExpr_(const VarNode* op) override; 160 void VisitExpr_(const GlobalVarNode* op) override; 161 void VisitExpr_(const ConstantNode* op) override; 162 void VisitExpr_(const TupleNode* op) override; 163 void VisitExpr_(const FunctionNode* op) override; 164 void VisitExpr_(const CallNode* op) override; 165 void VisitExpr_(const LetNode* op) override; 166 void VisitExpr_(const IfNode* op) override; 167 void VisitExpr_(const OpNode* op) override; 168 void VisitExpr_(const TupleGetItemNode* op) override; 169 void VisitExpr_(const RefCreateNode* op) override; 170 void VisitExpr_(const RefReadNode* op) override; 171 void VisitExpr_(const RefWriteNode* op) override; 172 void VisitExpr_(const ConstructorNode* op) override; 173 void VisitExpr_(const MatchNode* op) override; 174 virtual void VisitType(const Type& t); 175 virtual void VisitClause(const Clause& c); 176 virtual void VisitPattern(const Pattern& c); 177 178 protected: 179 // Internal visiting counter 180 std::unordered_map<const Node*, size_t> visit_counter_; 181 }; 182 183 /*! 184 * \brief A wrapper around ExprFunctor which functionally updates the AST. 185 * 186 * ExprMutator treats Expr as dataflow graph, and only Mutate each Expr once. 187 * The mutated results are memoized in a map and reused so that 188 * local transformation on the dataflow preserves the graph structure. 189 */ 190 class ExprMutator 191 : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> { 192 public: 193 /*! 194 * \brief Mutate is alias for VisitExpr 195 * \return expr. 196 */ Mutate(const Expr & expr)197 Expr Mutate(const Expr& expr) { 198 return this->VisitExpr(expr); 199 } 200 Expr VisitExpr(const Expr& expr) override; 201 Expr VisitExpr_(const VarNode* op) override; 202 Expr VisitExpr_(const ConstantNode* op) override; 203 Expr VisitExpr_(const GlobalVarNode* op) override; 204 Expr VisitExpr_(const OpNode* op) override; 205 Expr VisitExpr_(const TupleNode* op) override; 206 Expr VisitExpr_(const FunctionNode* op) override; 207 Expr VisitExpr_(const CallNode* call_node) override; 208 Expr VisitExpr_(const LetNode* op) override; 209 Expr VisitExpr_(const IfNode* op) override; 210 Expr VisitExpr_(const TupleGetItemNode* op) override; 211 Expr VisitExpr_(const RefCreateNode* op) override; 212 Expr VisitExpr_(const RefReadNode* op) override; 213 Expr VisitExpr_(const RefWriteNode* op) override; 214 Expr VisitExpr_(const ConstructorNode* op) override; 215 Expr VisitExpr_(const MatchNode* op) override; 216 217 /*! 218 * \brief Used to visit the types inside of expressions. 219 * 220 * Can be overloaded to transform the types in arbitrary 221 * ways, one way would be to define a sub-class of type 222 * visitor for types which transform them appropriately. 223 */ 224 virtual Type VisitType(const Type& t); 225 virtual Clause VisitClause(const Clause& c); 226 virtual Pattern VisitPattern(const Pattern& c); 227 228 protected: 229 /*! \brief Internal map used for memoization. */ 230 std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_; 231 }; 232 233 /*! 234 * \brief recursively visit the ir in post DFS order node, apply fvisit 235 * Each node is guaranteed to be visited only once. 236 * \param node The ir to be visited. 237 * \param fvisit The visitor function to be applied. 238 */ 239 void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit); 240 241 } // namespace relay 242 } // namespace tvm 243 #endif // TVM_RELAY_EXPR_FUNCTOR_H_ 244