1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tvm/tir/analysis.h
 * \brief Analysis utilities and passes for TIR.
23 */
24 #ifndef TVM_TIR_ANALYSIS_H_
25 #define TVM_TIR_ANALYSIS_H_
26
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>

#include <functional>
#include <string>
35
36 namespace tvm {
37 namespace tir {
38
39 /*!
40 * \brief Compare two expressions recursively and check if they are equal
41 * to each other without var remapping.
42 *
43 * This function does not remap variable bindings, it will not
44 * return true for (let x = 1 in x + 1) vs (let y = 1 in y + 1), unless x.same_as(y).
45 *
46 * Use StructuralEqual for such cases.
47 *
48 * Due to the restriction of not remapping variables, this function can run
49 * faster than StructuralEqual and can be used as a utility function during arithmetic
50 * simplifications.
51 *
52 * \sa StructuralEqual
53 */
54 struct ExprDeepEqual {
55 public:
56 TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
57 };
58
59 /*!
60 * \brief Find undefined vars in the statement.
61 * \param stmt The function to be checked.
62 * \param defs The vars that is defined.
63 * \return Array of undefined vars.
64 */
65 TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
66
67 /*!
68 * \brief Find undefined vars in the expression.
69 * \param expr The expression to be checked.
70 * \return Array of undefined vars.
71 */
72 TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
73
74 /*!
75 * \brief Analyze the side effect
76 * \param expr The expression to be checked.
77 *
78 * \return CallEffectKind, can be kPure, kReadState or kUpdateState
79 */
80 TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
81
82 /*!
83 * \brief Whether e expression used any var in variable set..
84 * \param expr The expression to be checked.
85 * \param vset_contains The check function to see if var is in the vset.
86 * \return Whether e uses vset.
87 */
88 TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
89
90 /*!
91 * \brief Whether e expression used var.
92 * \param expr The expression to be checked.
93 * \param var The variable.
94 * \return Whether e uses v.
95 */
ExprUseVar(const PrimExpr & expr,const Var & var)96 inline bool ExprUseVar(const PrimExpr& expr, const Var& var) {
97 return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; });
98 }
99
100 /*!
101 * \brief Verifies whether the IR stmt or Expr is in SSA form.
102 * That is: each Var is defined and assigned once(in Let/For)
103 *
104 * \param func The function to be verified.
105 * \return Whether IR is in SSA form.
106 *
107 * \note All passes in TIR consume and produce SSA form.
108 */
109 TVM_DLL bool VerifySSA(const PrimFunc& func);
110
111 /*!
112 * \brief Verify if memory accesses are legal for a specific target device type.
113 *
114 * In the case that tgt is cuda, if not all workload is bound with
115 * threads, CPU code is generated that tries to access GPU memory,
116 * which is illegal. This pass performs verification for this case.
117 *
118 * \param func The function to be verified.
119 * \return Success of memory verification.
120 */
121 TVM_DLL bool VerifyMemory(const PrimFunc& func);
122
123 /*!
124 * \brief Verify the correctness of a GPU code
125 * It will check the whether the amount of memory usage or the number of threads
126 * in a block exceeds the limit
127 * \param func The function to be checked
128 * \param constraints The dict to specify constraints to check.
129 * Possible keys are
130 *
131 * "max_local_memory_per_block": Total amount of local memory per block (in bytes).
132 * "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
133 * "max_threads_per_block": Maximum number of threads per block.
134 * "max_thread_x": Maximum length of threadIdx.x.
135 * "max_thread_y": Maximum length of threadIdx.y.
136 * "max_thread_z": Maximum length of threadIdx.z.
137 *
138 * If one key is missing in this argument, the pass won't check for that item.
139 * \return valid Whether it is a valid GPU code
140 *
141 */
142 TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
143
144 // Pass variants of verification analysis
145 // directly throws RuntimeError when verification fails.
146 namespace transform {
147
148 using tvm::transform::Pass;
149 using tvm::transform::PassContext;
150
151 /*!
152 * \brief Pass variant of VerifySSA.
153 *
154 * \returns The pass.
155 * \sa tvm::tir::VerifySSA
156 */
157 TVM_DLL Pass VerifySSA();
158
159 /*!
160 * \brief Pass variant of VerifyMemory.
161 *
162 * \returns The pass.
163 * \sa tvm::tir::VerifyMemory
164 */
165 TVM_DLL Pass VerifyMemory();
166
167 /*!
168 * \brief Pass variant of VerifyGPUCode.
169 *
170 * \param constraints The dict to specify constraints to check.
171 *
172 * \returns The pass.
173 * \sa tvm::tir::VerifyGPUCode
174 */
175 TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
176
177 } // namespace transform
178 } // namespace tir
179 } // namespace tvm
180 #endif // TVM_TIR_ANALYSIS_H_
181