1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file tvm/ir_visitor.h 22 * \brief Visitor to quickly visit IR trees 23 */ 24 #ifndef TVM_IR_VISITOR_H_ 25 #define TVM_IR_VISITOR_H_ 26 27 #include "ir.h" 28 #include "tvm/node/functor.h" 29 30 namespace tvm { 31 namespace ir { 32 33 /*! 34 * \brief a base class for visitor to iterative traverse the IR 35 * 36 * This IRVisitor is implemented via NodeFunctor 37 * This enables extensions of possible new Node. 38 * 39 * \sa ExprFunctor, StmtFunctor, PostOrderVisit 40 * 41 * \note If you need to return values during Visit: 42 * - If it is mutation of the IR, use IRMutator 43 * - If you want to return other things, consider use ExprFunctor/StmtFunctor 44 * - Watch out for possible bug pattern if you use IRVisitor to simulate returns. 45 * 46 * \code 47 * 48 * // This is an example code to show cases for traps in IRVisitor 49 * // The use case is to count number of Variables in the ir tree. 50 * class MyCounter : public IRVisitor { 51 * public: 52 * int Count(const ObjectRef& n) { 53 * ret_ = 0; 54 * this->Visit(n); 55 * return ret_; 56 * } 57 * void Visit_(const Variable* op) final { 58 * ret_ = 1; 59 * } 60 * void Visit_(const Add* op) final { 61 * ret_ = count(op->a) + count(op->b); 62 * } 63 64 * private: 65 * int ret_; 66 * }; 67 * MyCounter counter; 68 * Var x("x"); 69 * // this returns 2 70 * CHECK_EQ(counter.Count(x + x), 2); 71 * // Think what is the result of the following count 72 * counter.count(Max::make(x, x)); 73 * // The result is actually 1 74 * // This is because Visit is not overriden for Max 75 * // so it simply calls Visit for the left and right children 76 * // and because Count is not called, ret_ is not cleared. 77 * // There can also be cases where ret_ is forgetten to be set. 78 * 79 * // These traps may not happen if we program carefully 80 * // But it is recommended to use ExprFunctor, which allows direct 81 * // return the value, this helps us to avoid such problems. 82 * 83 * \endcode 84 */ 85 class TVM_DLL IRVisitor { 86 public: 87 /*! 88 * \brief recursively visit an IR node 89 */ Visit(const NodeRef & node)90 virtual void Visit(const NodeRef& node) { 91 static const FVisit& f = vtable(); 92 if (node.defined()) f(node, this); 93 } 94 /*! \brief destructor */ ~IRVisitor()95 virtual ~IRVisitor() {} 96 /*! \brief functor type of visitor */ 97 using FVisit = NodeFunctor<void(const ObjectRef&, IRVisitor*)>; 98 /*! \return internal vtable*/ 99 static FVisit& vtable(); 100 // overloadable visit function. 101 virtual void Visit_(const Variable* op); 102 virtual void Visit_(const LetStmt* op); 103 virtual void Visit_(const AttrStmt* op); 104 virtual void Visit_(const IfThenElse* op); 105 virtual void Visit_(const For* op); 106 virtual void Visit_(const Allocate* op); 107 virtual void Visit_(const Load* op); 108 virtual void Visit_(const Store* op); 109 virtual void Visit_(const Let* op); 110 virtual void Visit_(const Free* op); 111 virtual void Visit_(const Call* op); 112 virtual void Visit_(const Add* op); 113 virtual void Visit_(const Sub* op); 114 virtual void Visit_(const Mul* op); 115 virtual void Visit_(const Div* op); 116 virtual void Visit_(const Mod* op); 117 virtual void Visit_(const FloorDiv* op); 118 virtual void Visit_(const FloorMod* op); 119 virtual void Visit_(const Min* op); 120 virtual void Visit_(const Max* op); 121 virtual void Visit_(const EQ* op); 122 virtual void Visit_(const NE* op); 123 virtual void Visit_(const LT* op); 124 virtual void Visit_(const LE* op); 125 virtual void Visit_(const GT* op); 126 virtual void Visit_(const GE* op); 127 virtual void Visit_(const And* op); 128 virtual void Visit_(const Or* op); 129 virtual void Visit_(const Reduce* op); 130 virtual void Visit_(const Cast* op); 131 virtual void Visit_(const Not* op); 132 virtual void Visit_(const Select* op); 133 virtual void Visit_(const Ramp* op); 134 virtual void Visit_(const Shuffle* op); 135 virtual void Visit_(const Broadcast* op); 136 virtual void Visit_(const AssertStmt* op); 137 virtual void Visit_(const ProducerConsumer* op); 138 virtual void Visit_(const Provide* op); 139 virtual void Visit_(const Realize* op); 140 virtual void Visit_(const Prefetch* op); 141 virtual void Visit_(const Block* op); 142 virtual void Visit_(const Evaluate* op); 143 virtual void Visit_(const IntImm* op); 144 virtual void Visit_(const UIntImm* op); 145 virtual void Visit_(const FloatImm* op); 146 virtual void Visit_(const StringImm* op); 147 }; 148 149 /*! 150 * \brief recursively visit the ir in post DFS order node, apply fvisit 151 * Each node is guaranteed to be visited only once. 152 * \param node The ir to be visited. 153 * \param fvisit The visitor function to be applied. 154 */ 155 TVM_DLL void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit); 156 157 } // namespace ir 158 } // namespace tvm 159 160 #endif // TVM_IR_VISITOR_H_ 161