1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/tir/analysis.h
 * \brief Analysis utilities and passes for TIR.
23  */
24 #ifndef TVM_TIR_ANALYSIS_H_
25 #define TVM_TIR_ANALYSIS_H_
26 
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>

#include <functional>
#include <string>
35 
36 namespace tvm {
37 namespace tir {
38 
39 /*!
40  * \brief Compare two expressions recursively and check if they are equal
41  *        to each other without var remapping.
42  *
43  *  This function does not remap variable bindings, it will not
44  *  return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless x.same_as(y).
45  *
46  *  Use StructuralEqual for such cases.
47  *
48  *  Due to the restriction of not remapping variables, this function can run
49  *  faster than StructuralEqual and can be used as a utility function during arithmetic
50  *  simplifications.
51  *
52  * \sa StructuralEqual
53  */
54 struct ExprDeepEqual {
55  public:
56   TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
57 };
58 
59 /*!
60  * \brief Find undefined vars in the statement.
61  * \param stmt The function to be checked.
62  * \param defs The vars that is defined.
63  * \return Array of undefined vars.
64  */
65 TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
66 
67 /*!
68  * \brief Find undefined vars in the expression.
69  * \param expr The expression to be checked.
70  * \return Array of undefined vars.
71  */
72 TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
73 
74 /*!
75  * \brief Analyze the side effect
76  * \param expr The expression to be checked.
77  *
78  * \return CallEffectKind, can be kPure, kReadState or kUpdateState
79  */
80 TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
81 
82 /*!
83  * \brief Whether e expression used any var in variable set..
84  * \param expr The expression to be checked.
85  * \param vset_contains The check function to see if var is in the vset.
86  * \return Whether e uses vset.
87  */
88 TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
89 
90 /*!
91  * \brief Whether e expression used var.
92  * \param expr The expression to be checked.
93  * \param var The variable.
94  * \return Whether e uses v.
95  */
ExprUseVar(const PrimExpr & expr,const Var & var)96 inline bool ExprUseVar(const PrimExpr& expr, const Var& var) {
97   return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; });
98 }
99 
100 /*!
101  * \brief Verifies whether the IR stmt or Expr is in SSA form.
102  *  That is: each Var is defined and assigned once(in Let/For)
103  *
104  * \param func The function to be verified.
105  * \return Whether IR is in SSA form.
106  *
107  * \note All passes in TIR consume and produce SSA form.
108  */
109 TVM_DLL bool VerifySSA(const PrimFunc& func);
110 
111 /*!
112  * \brief Verify if memory accesses are legal for a specific target device type.
113  *
114  *  In the case that tgt is cuda, if not all workload is bound with
115  *  threads, CPU code is generated that tries to access GPU memory,
116  *  which is illegal. This pass performs verification for this case.
117  *
118  * \param func The function to be verified.
119  * \return Success of memory verification.
120  */
121 TVM_DLL bool VerifyMemory(const PrimFunc& func);
122 
123 /*!
124  * \brief Verify the correctness of a GPU code
125  *        It will check the whether the amount of memory usage or the number of threads
126  *        in a block exceeds the limit
127  * \param func The function to be checked
128  * \param constraints The dict to specify constraints to check.
129  *        Possible keys are
130  *
131  *        "max_local_memory_per_block": Total amount of local memory per block (in bytes).
132  *        "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
133  *        "max_threads_per_block": Maximum number of threads per block.
134  *        "max_thread_x": Maximum length of threadIdx.x.
135  *        "max_thread_y": Maximum length of threadIdx.y.
136  *        "max_thread_z": Maximum length of threadIdx.z.
137  *
138  *        If one key is missing in this argument, the pass won't check for that item.
139  * \return valid Whether it is a valid GPU code
140  *
141  */
142 TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
143 
144 // Pass variants of verification analysis
145 // directly throws RuntimeError when verification fails.
146 namespace transform {
147 
148 using tvm::transform::Pass;
149 using tvm::transform::PassContext;
150 
151 /*!
152  * \brief Pass variant of VerifySSA.
153  *
154  * \returns The pass.
155  * \sa tvm::tir::VerifySSA
156  */
157 TVM_DLL Pass VerifySSA();
158 
159 /*!
160  * \brief Pass variant of VerifyMemory.
161  *
162  * \returns The pass.
163  * \sa tvm::tir::VerifyMemory
164  */
165 TVM_DLL Pass VerifyMemory();
166 
167 /*!
168  * \brief Pass variant of VerifyGPUCode.
169  *
170  * \param constraints The dict to specify constraints to check.
171  *
172  * \returns The pass.
173  * \sa tvm::tir::VerifyGPUCode
174  */
175 TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
176 
177 }  // namespace transform
178 }  // namespace tir
179 }  // namespace tvm
180 #endif  // TVM_TIR_ANALYSIS_H_
181