1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file tvm/tir/transform.h 22 * \brief TIR specific transformation passes. 23 */ 24 #ifndef TVM_TIR_TRANSFORM_H_ 25 #define TVM_TIR_TRANSFORM_H_ 26 27 #include <tvm/ir/transform.h> 28 #include <tvm/tir/expr.h> 29 #include <tvm/tir/function.h> 30 31 #include <string> 32 33 namespace tvm { 34 namespace tir { 35 namespace transform { 36 37 using tvm::transform::Pass; 38 using tvm::transform::PassContext; 39 using tvm::transform::PassContextNode; 40 using tvm::transform::PassInfo; 41 using tvm::transform::PassInfoNode; 42 using tvm::transform::PassNode; 43 using tvm::transform::Sequential; 44 45 /* 46 * \brief Create a function pass that optimizes PrimFuncs. 47 * 48 * \param pass_func The packed function that contains the optimization. 49 * \param opt_level The optimization level of the function pass. 50 * \param name The name of the function pass. 51 * \param required The list of the passes that the function pass is dependent on. 52 * 53 * \return The created function pass. 54 */ 55 TVM_DLL Pass CreatePrimFuncPass( 56 const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func, 57 int opt_level, String name, tvm::Array<String> required); 58 59 /*! 60 * \brief Inject prefetch instructions into stmt. 61 * 62 * \return The pass. 63 */ 64 TVM_DLL Pass InjectPrefetch(); 65 66 // TODO(tvm-team): consolidate configs to the PassContext 67 /*! 68 * \brief Flatten the multi-dimensional read/write 69 * to single dimensional Load/Store 70 * 71 * \param cache_line_size The size of CPU cache line. 72 * \param create_bound_attribute Whether to create bound attributes. 73 * 74 * \return The Pass 75 */ 76 TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false); 77 78 /*! 79 * \brief Inject copy intrinsics with optional pad. 80 * 81 * \param pragma_key The pragma key for hint of copy. 82 * \param fintrin The function with signature 83 * 84 * Stmt fintrin(Buffer src, 85 * Buffer dst, 86 * Array<Expr> pad_before, 87 * Array<Expr> pad_after, 88 * Expr pad_value) 89 * \return The pass. 90 */ 91 TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin); 92 93 /*! 94 * \brief Detect and insert sync points to co-processor. 95 * 96 * \return The pass. 97 */ 98 TVM_DLL Pass CoProcSync(); 99 100 /*! 101 * \brief Lift common attrs with attr_key to outer scope. 102 * 103 * \param attr_key The attribute key to be checked. 104 * \return The pass. 105 */ 106 TVM_DLL Pass LiftAttrScope(String attr_key); 107 108 /*! 109 * \brief partition loops in the stmt. 110 * 111 * \return The pass. 112 */ 113 TVM_DLL Pass LoopPartition(); 114 115 /*! 116 * \brief Lower vectorization loops. 117 * 118 * \param enable_vectorize Whether vectorization is enabled. 119 * 120 * \return The pass. 121 */ 122 TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true); 123 124 /*! 125 * \brief Inject virtual thread loops. 126 * 127 * \return The pass. 128 */ 129 TVM_DLL Pass InjectVirtualThread(); 130 131 /*! 132 * \brief Inject double buffer statements. 133 * 134 * \return The pass. 135 */ 136 TVM_DLL Pass InjectDoubleBuffer(); 137 138 /*! 139 * \brief Rewrite storage allocation pattern. 140 * Moves the allocation to outer most possible scope. 141 * Trying to share space between allocations to make 142 * a static allocation plan when possible. 143 * 144 * \return The pass. 145 */ 146 TVM_DLL Pass StorageRewrite(); 147 148 /*! 149 * \brief unroll the constant loop marked by unroll. 150 * This pass also automatically attach pragma unroll tag to loops which meets the standard. 151 * 152 * \return The pass. 153 */ 154 TVM_DLL Pass UnrollLoop(); 155 156 /*! 157 * \brief Remove No Op from the Stmt. 158 * 159 * \return The pass. 160 */ 161 TVM_DLL Pass RemoveNoOp(); 162 163 /*! 164 * \brief Detect and rewrite unsafe select that contains memory access. 165 * 166 * \return The pass. 167 */ 168 TVM_DLL Pass RewriteUnsafeSelect(); 169 170 /*! 171 * \brief Run arithmetic simplifications on the statements and expressions. 172 * 173 * \return The pass. 174 */ 175 TVM_DLL Pass Simplify(); 176 177 /*! 178 * \brief Instruments bound checkers. 179 * 180 * \return The pass. 181 */ 182 TVM_DLL Pass InstrumentBoundCheckers(); 183 184 /*! 185 * \brief Transform the high-level PrimFunc to a low-level version 186 * that can be used as an API function. 187 * 188 * 189 * The main task of this function is to create code to : 190 * - Map the values in the api_args to Var that is required by body. 191 * - Insert assertions to check type/value of the passed arguments. 192 * 193 * \param num_unpacked_args Number of arguments that 194 * are processed in plain form instead of packed form. 195 * 196 * \note 197 * The function signature have two cases 198 * 199 * let num_packed_args = len(api_args) - num_unpacked_args; 200 * 201 * if num_packed_args is zero: 202 * f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args) 203 * 204 * if num_packed_args is not zero: 205 * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, 206 * api_arg_k, api_arg_k+1, ... api_arg_n, 207 * TVMValue* out_ret_val, int* out_ret_tcode) 208 * 209 * where n == len(api_args), k == num_packed_args 210 * 211 * \return The pass. 212 */ 213 TVM_DLL Pass MakePackedAPI(int num_unpacked_args); 214 215 /*! 216 * \brief Remap the thread axis 217 * 218 * This can be used to get equivalent program which uses 219 * threadIdx.y in place of threadIdx.x by passing 220 * {"threadIdx.x": thread_axis("threadIdx.y")} 221 * 222 * 223 * \return The pass. 224 */ 225 TVM_DLL Pass RemapThreadAxis(Map<String, IterVar> axis_map); 226 227 /*! 228 * \brief Lower custom datatypes. 229 * 230 * See tvm::datatypes::Registry for more information on adding custom datatypes. 231 * 232 * \return The pass. 233 */ 234 TVM_DLL Pass LowerCustomDatatypes(); 235 236 /*! 237 * \brief Decorate all the function's body as device function. 238 * 239 * \return The pass. 240 */ 241 TVM_DLL Pass DecorateDeviceScope(); 242 243 /*! 244 * \brief Split the function into a host function and device functions. 245 * 246 * \return The pass. 247 */ 248 TVM_DLL Pass SplitHostDevice(); 249 250 /*! 251 * \brief skip assert stmt. 252 * 253 * \return The pass. 254 */ 255 TVM_DLL Pass SkipAssert(); 256 257 /*! 258 * \brief Insert sync between parallel read/write of shared buffers. 259 * 260 * \param storage_scope The storage scope considered. 261 * \return The pass. 262 */ 263 TVM_DLL Pass ThreadSync(String storage_scope); 264 265 /*! 266 * \brief Lower cross thread alleduce. 267 * 268 * \return The pass. 269 */ 270 TVM_DLL Pass LowerThreadAllreduce(); 271 272 /*! 273 * \brief Infer the TensorCore fragment infomation using tensor intrinsics 274 * 275 * \return The pass. 276 */ 277 TVM_DLL Pass InferFragment(); 278 279 /*! 280 * \brief Lower builtin intrinsics. 281 * \return The pass. 282 */ 283 TVM_DLL Pass LowerTVMBuiltin(); 284 285 /*! 286 * \brief Lower the target specific function intrinsics in each of the function. 287 * 288 * \return The pass. 289 */ 290 TVM_DLL Pass LowerIntrin(); 291 292 /*! 293 * \brief Lower warp memory access to low-level device related function calls. 294 * \return The pass. 295 */ 296 TVM_DLL Pass LowerWarpMemory(); 297 298 /*! 299 * \brief Lower attached storage access information on device. 300 * 301 * \note Run this pass after all storage access analysis finish. 302 * 303 * \return The pass. 304 */ 305 TVM_DLL Pass LowerDeviceStorageAccessInfo(); 306 307 /*! 308 * \brief Combine context calls in the host function. 309 * 310 * \return The pass. 311 */ 312 TVM_DLL Pass CombineContextCall(); 313 314 /*! 315 * \brief Narrow down PrimExpr datatype in stmt to target_bits. 316 * 317 * \param target_bits The target bits 318 * 319 * \note Run this pass after storage flatten. 320 * \return The pass. 321 */ 322 TVM_DLL Pass NarrowDataType(int target_bits); 323 324 /*! 325 * \brief Legalize bf16 typed Ops. Add a cast to fp32 326 * before Ops, then add a cast back to bf16. 327 * \return The pass. 328 */ 329 TVM_DLL Pass BF16Legalize(); 330 331 /*! 332 * \brief Rewrite the pointer content type of arguments, 333 * as well as Alloc internal to the function to use 334 * the most frequently accessed type for load/store 335 * to avoid pointer casting in backend when possible. 336 * 337 * \return The pass. 338 */ 339 TVM_DLL Pass PointerValueTypeRewrite(); 340 341 /*! 342 * \brief Hoist loop-invariant IfThenElse nodes to 343 * outside the elligible loops. 344 * 345 * \return The pass. 346 */ 347 TVM_DLL Pass HoistIfThenElse(); 348 349 } // namespace transform 350 } // namespace tir 351 } // namespace tvm 352 353 #endif // TVM_TIR_TRANSFORM_H_ 354