1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/tir/transform.h
22  * \brief TIR specific transformation passes.
23  */
24 #ifndef TVM_TIR_TRANSFORM_H_
25 #define TVM_TIR_TRANSFORM_H_
26 
27 #include <tvm/ir/transform.h>
28 #include <tvm/tir/expr.h>
29 #include <tvm/tir/function.h>
30 
31 #include <string>
32 
33 namespace tvm {
34 namespace tir {
35 namespace transform {
36 
37 using tvm::transform::Pass;
38 using tvm::transform::PassContext;
39 using tvm::transform::PassContextNode;
40 using tvm::transform::PassInfo;
41 using tvm::transform::PassInfoNode;
42 using tvm::transform::PassNode;
43 using tvm::transform::Sequential;
44 
45 /*!
46  * \brief Create a function pass that optimizes PrimFuncs.
47  *
48  * \param pass_func The packed function that implements the optimization.
49  * \param opt_level The optimization level of the function pass.
50  * \param name The name of the function pass.
51  * \param required The list of passes that the function pass depends on.
52  *
53  * \return The created function pass.
54  */
55 TVM_DLL Pass CreatePrimFuncPass(
56     const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
57     int opt_level, String name, tvm::Array<String> required);
58 
59 /*!
60  * \brief Inject prefetch instructions into the stmt.
61  *
62  * \return The pass.
63  */
64 TVM_DLL Pass InjectPrefetch();
65 
66 // TODO(tvm-team): consolidate configs to the PassContext
67 /*!
68  * \brief Flatten the multi-dimensional read/write
69  *  to single-dimensional Load/Store.
70  *
71  * \param cache_line_size The size of the CPU cache line.
72  * \param create_bound_attribute Whether to create bound attributes (defaults to false).
73  *
74  * \return The pass.
75  */
76 TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false);
77 
78 /*!
79  * \brief Inject copy intrinsics with optional pad.
80  *
81  * \param pragma_key The pragma key used as the hint for copy.
82  * \param fintrin The function with the following signature:
83  *
84  *   Stmt fintrin(Buffer src,
85  *                Buffer dst,
86  *                Array<Expr> pad_before,
87  *                Array<Expr> pad_after,
88  *                Expr pad_value)
89  * \return The pass.
90  */
91 TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin);
92 
93 /*!
94  * \brief Detect and insert sync points to the co-processor.
95  *
96  * \return The pass.
97  */
98 TVM_DLL Pass CoProcSync();
99 
100 /*!
101  * \brief Lift common attrs with attr_key to the outer scope.
102  *
103  * \param attr_key The attribute key to be checked.
104  * \return The pass.
105  */
106 TVM_DLL Pass LiftAttrScope(String attr_key);
107 
108 /*!
109  * \brief Partition loops in the stmt.
110  *
111  * \return The pass.
112  */
113 TVM_DLL Pass LoopPartition();
114 
115 /*!
116  * \brief Lower vectorization loops.
117  *
118  * \param enable_vectorize Whether vectorization is enabled (defaults to true).
119  *
120  * \return The pass.
121  */
122 TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);
123 
124 /*!
125  * \brief Inject virtual thread loops.
126  *
127  * \return The pass.
128  */
129 TVM_DLL Pass InjectVirtualThread();
130 
131 /*!
132  * \brief Inject double buffer statements.
133  *
134  * \return The pass.
135  */
136 TVM_DLL Pass InjectDoubleBuffer();
137 
138 /*!
139  * \brief Rewrite storage allocation pattern.
140  *  Moves the allocation to the outermost possible scope.
141  *  Tries to share space between allocations to make
142  *  a static allocation plan when possible.
143  *
144  * \return The pass.
145  */
146 TVM_DLL Pass StorageRewrite();
147 
148 /*!
149  * \brief Unroll the constant loop marked by unroll.
150  * This pass also automatically attaches pragma unroll tag to loops which meet the standard.
151  *
152  * \return The pass.
153  */
154 TVM_DLL Pass UnrollLoop();
155 
156 /*!
157  * \brief Remove no-ops from the Stmt.
158  *
159  * \return The pass.
160  */
161 TVM_DLL Pass RemoveNoOp();
162 
163 /*!
164  * \brief Detect and rewrite unsafe selects that contain memory access.
165  *
166  * \return The pass.
167  */
168 TVM_DLL Pass RewriteUnsafeSelect();
169 
170 /*!
171  * \brief Run arithmetic simplifications on the statements and expressions.
172  *
173  * \return The pass.
174  */
175 TVM_DLL Pass Simplify();
176 
177 /*!
178  * \brief Instrument bound checkers.
179  *
180  * \return The pass.
181  */
182 TVM_DLL Pass InstrumentBoundCheckers();
183 
184 /*!
185  * \brief Transform the high-level PrimFunc to a low-level version
186  *        that can be used as an API function.
187  *
188  *
189  *  The main task of this function is to create code to:
190  *   - Map the values in the api_args to the Vars that are required by the body.
191  *   - Insert assertions to check type/value of the passed arguments.
192  *
193  * \param num_unpacked_args Number of arguments that
194  *         are processed in plain form instead of packed form.
195  *
196  * \note
197  *  The function signature has two cases
198  *
199  *  let num_packed_args = len(api_args) - num_unpacked_args;
200  *
201  *  if num_packed_args is zero:
202  *     f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
203  *
204  *  if num_packed_args is not zero:
205  *       f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
206  *         api_arg_k, api_arg_k+1, ... api_arg_n,
207  *         TVMValue* out_ret_val, int* out_ret_tcode)
208  *
209  *       where n == len(api_args), k == num_packed_args
210  *
211  * \return The pass.
212  */
213 TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
214 
215 /*!
216  * \brief Remap the thread axis.
217  *
218  *  This can be used to get an equivalent program which uses
219  *  threadIdx.y in place of threadIdx.x by passing
220  *  {"threadIdx.x": thread_axis("threadIdx.y")}
221  *
222  * \param axis_map The map from thread axis name to the IterVar that replaces it.
223  * \return The pass.
224  */
225 TVM_DLL Pass RemapThreadAxis(Map<String, IterVar> axis_map);
226 
227 /*!
228  * \brief Lower custom datatypes.
229  *
230  * See tvm::datatypes::Registry for more information on adding custom datatypes.
231  *
232  * \return The pass.
233  */
234 TVM_DLL Pass LowerCustomDatatypes();
235 
236 /*!
237  * \brief Decorate the body of each function as a device function.
238  *
239  * \return The pass.
240  */
241 TVM_DLL Pass DecorateDeviceScope();
242 
243 /*!
244  * \brief Split the function into a host function and device functions.
245  *
246  * \return The pass.
247  */
248 TVM_DLL Pass SplitHostDevice();
249 
250 /*!
251  * \brief Skip assert stmt.
252  *
253  * \return The pass.
254  */
255 TVM_DLL Pass SkipAssert();
256 
257 /*!
258  * \brief Insert sync between parallel read/write of shared buffers.
259  *
260  * \param storage_scope The storage scope considered.
261  * \return The pass.
262  */
263 TVM_DLL Pass ThreadSync(String storage_scope);
264 
265 /*!
266  * \brief Lower cross thread allreduce.
267  *
268  * \return The pass.
269  */
270 TVM_DLL Pass LowerThreadAllreduce();
271 
272 /*!
273  * \brief Infer the TensorCore fragment information using tensor intrinsics.
274  *
275  * \return The pass.
276  */
277 TVM_DLL Pass InferFragment();
278 
279 /*!
280  * \brief Lower builtin intrinsics.
281  * \return The pass.
282  */
283 TVM_DLL Pass LowerTVMBuiltin();
284 
285 /*!
286  * \brief Lower the target specific function intrinsics in each of the functions.
287  *
288  * \return The pass.
289  */
290 TVM_DLL Pass LowerIntrin();
291 
292 /*!
293  * \brief Lower warp memory access to low-level device related function calls.
294  * \return The pass.
295  */
296 TVM_DLL Pass LowerWarpMemory();
297 
298 /*!
299  * \brief Lower attached storage access information on device.
300  *
301  * \note Run this pass after all storage access analysis finishes.
302  *
303  * \return The pass.
304  */
305 TVM_DLL Pass LowerDeviceStorageAccessInfo();
306 
307 /*!
308  * \brief Combine context calls in the host function.
309  *
310  * \return The pass.
311  */
312 TVM_DLL Pass CombineContextCall();
313 
314 /*!
315  * \brief Narrow down PrimExpr datatype in stmt to target_bits.
316  *
317  * \param target_bits The target bits.
318  *
319  * \note Run this pass after storage flatten.
320  * \return The pass.
321  */
322 TVM_DLL Pass NarrowDataType(int target_bits);
323 
324 /*!
325  * \brief Legalize bf16-typed Ops. Add a cast to fp32
326  *   before Ops, then add a cast back to bf16.
327  * \return The pass.
328  */
329 TVM_DLL Pass BF16Legalize();
330 
331 /*!
332  * \brief Rewrite the pointer content type of arguments,
333  *  as well as Alloc internal to the function to use
334  *  the most frequently accessed type for load/store
335  *  to avoid pointer casting in backend when possible.
336  *
337  * \return The pass.
338  */
339 TVM_DLL Pass PointerValueTypeRewrite();
340 
341 /*!
342  * \brief Hoist loop-invariant IfThenElse nodes to
343  * outside the eligible loops.
344  *
345  * \return The pass.
346  */
347 TVM_DLL Pass HoistIfThenElse();
348 
349 }  // namespace transform
350 }  // namespace tir
351 }  // namespace tvm
352 
353 #endif  // TVM_TIR_TRANSFORM_H_
354