1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/ir_visitor.h
22  * \brief Visitor to quickly visit IR trees
23  */
24 #ifndef TVM_IR_VISITOR_H_
25 #define TVM_IR_VISITOR_H_
26 
27 #include "ir.h"
28 #include "tvm/node/functor.h"
29 
30 namespace tvm {
31 namespace ir {
32 
33 /*!
 * \brief A base class for visitors that iteratively traverse the IR.
35  *
 *  IRVisitor is implemented via NodeFunctor, which allows it to be
 *  extended to support new Node types.
38  *
39  * \sa ExprFunctor, StmtFunctor, PostOrderVisit
40  *
 * \note If you need to return values during Visit:
 *  - If you are mutating the IR, use IRMutator.
 *  - If you want to return other values, consider using ExprFunctor/StmtFunctor.
 *  - Watch out for the bug patterns shown below if you use IRVisitor to simulate returns.
45  *
46  * \code
47  *
 * // This example illustrates some traps in IRVisitor.
 * // The goal is to count the number of Variables in the IR tree.
50  * class MyCounter : public IRVisitor {
51  *  public:
52  *   int Count(const ObjectRef& n) {
53  *     ret_ = 0;
54  *     this->Visit(n);
55  *     return ret_;
56  *   }
57  *   void Visit_(const Variable* op) final {
58  *     ret_ = 1;
59  *   }
60  *   void Visit_(const Add* op) final {
 *     ret_ = Count(op->a) + Count(op->b);
62  *   }
 *
64  *  private:
65  *   int ret_;
66  * };
67  * MyCounter counter;
68  * Var x("x");
69  * // this returns 2
70  * CHECK_EQ(counter.Count(x + x), 2);
71  * // Think what is the result of the following count
 * counter.Count(Max::make(x, x));
 * // The result is actually 1.
 * // This is because Visit_ is not overridden for Max,
 * // so the default simply calls Visit on the left and right children.
 * // Count is never called on the children, so their results are not summed:
 * // each Variable just overwrites ret_ with 1.
 * // There can also be cases where setting ret_ is forgotten altogether.
78  *
 * // These traps can be avoided with careful programming,
 * // but it is recommended to use ExprFunctor, which returns values
 * // directly and thus avoids such problems.
82  *
83  * \endcode
84  */
85 class TVM_DLL IRVisitor {
86  public:
87   /*!
88    * \brief recursively visit an IR node
89    */
Visit(const NodeRef & node)90   virtual void Visit(const NodeRef& node) {
91     static const FVisit& f = vtable();
92     if (node.defined()) f(node, this);
93   }
94   /*! \brief destructor */
~IRVisitor()95   virtual ~IRVisitor() {}
96   /*! \brief functor type of visitor */
97   using FVisit = NodeFunctor<void(const ObjectRef&, IRVisitor*)>;
98   /*! \return internal vtable*/
99   static FVisit& vtable();
100   // overloadable visit function.
101   virtual void Visit_(const Variable* op);
102   virtual void Visit_(const LetStmt* op);
103   virtual void Visit_(const AttrStmt* op);
104   virtual void Visit_(const IfThenElse* op);
105   virtual void Visit_(const For* op);
106   virtual void Visit_(const Allocate* op);
107   virtual void Visit_(const Load* op);
108   virtual void Visit_(const Store* op);
109   virtual void Visit_(const Let* op);
110   virtual void Visit_(const Free* op);
111   virtual void Visit_(const Call* op);
112   virtual void Visit_(const Add* op);
113   virtual void Visit_(const Sub* op);
114   virtual void Visit_(const Mul* op);
115   virtual void Visit_(const Div* op);
116   virtual void Visit_(const Mod* op);
117   virtual void Visit_(const FloorDiv* op);
118   virtual void Visit_(const FloorMod* op);
119   virtual void Visit_(const Min* op);
120   virtual void Visit_(const Max* op);
121   virtual void Visit_(const EQ* op);
122   virtual void Visit_(const NE* op);
123   virtual void Visit_(const LT* op);
124   virtual void Visit_(const LE* op);
125   virtual void Visit_(const GT* op);
126   virtual void Visit_(const GE* op);
127   virtual void Visit_(const And* op);
128   virtual void Visit_(const Or* op);
129   virtual void Visit_(const Reduce* op);
130   virtual void Visit_(const Cast* op);
131   virtual void Visit_(const Not* op);
132   virtual void Visit_(const Select* op);
133   virtual void Visit_(const Ramp* op);
134   virtual void Visit_(const Shuffle* op);
135   virtual void Visit_(const Broadcast* op);
136   virtual void Visit_(const AssertStmt* op);
137   virtual void Visit_(const ProducerConsumer* op);
138   virtual void Visit_(const Provide* op);
139   virtual void Visit_(const Realize* op);
140   virtual void Visit_(const Prefetch* op);
141   virtual void Visit_(const Block* op);
142   virtual void Visit_(const Evaluate* op);
143   virtual void Visit_(const IntImm* op);
144   virtual void Visit_(const UIntImm* op);
145   virtual void Visit_(const FloatImm* op);
146   virtual void Visit_(const StringImm* op);
147 };
148 
149 /*!
150  * \brief recursively visit the ir in post DFS order node, apply fvisit
151  * Each node is guaranteed to be visited only once.
152  * \param node The ir to be visited.
153  * \param fvisit The visitor function to be applied.
154  */
155 TVM_DLL void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);
156 
157 }  // namespace ir
158 }  // namespace tvm
159 
160 #endif  // TVM_IR_VISITOR_H_
161