1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file auto_scheduler/loop_state.h 22 * \brief The definition of the "state" in the search. 23 * 24 * Each LoopState corresponds to a schedule for its ComputeDAG. 25 * A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to 26 * construct the loop structure. 27 * The loop structure keeps a preview of how the schedule will finally look like after lowering the 28 * current state (e.g. number of iterators, the extent of each iterator, the compute_at locations 29 * ...). 30 * During the schedule search process, the loop structure can provide search policy with necessary 31 * information on how to manipulate the current state. 32 * The transform history is a sequence of `TransformStep` which will finally be mapped to TVM 33 * schedule primitives. The steps are also used for the serialization of a state. 34 * 35 * The LoopState can be seen as a lightweight loop structure IR specifically for schedule search. 36 * We don't use the existing TVM IR but to extend a new structure on it is because: 37 * 1. We want fast incremental change to the loop structures. The search policy needs to get the 38 * immediate loop structures update rather than after TVM lowering; 39 * 2. We want serializable transform history for replay, backtracking, and mutation; 40 * 3. We may create some macro schedule primitives that represent the combination of several 41 * TVM schedule primitives. 42 * 43 * When the search is finished, we will lower the state to TVM IR with TVM's schedule primitives. 44 * Since we share a lot of common objects during search, the transformation is implemented in 45 * copy on write style. All objects are immutable, which is similar to TVM IR. 46 */ 47 48 #ifndef TVM_AUTO_SCHEDULER_LOOP_STATE_H_ 49 #define TVM_AUTO_SCHEDULER_LOOP_STATE_H_ 50 51 #include <dmlc/common.h> 52 #include <tvm/auto_scheduler/transform_step.h> 53 #include <tvm/runtime/container.h> 54 55 #include <functional> 56 #include <unordered_map> 57 #include <utility> 58 #include <vector> 59 60 namespace tvm { 61 namespace auto_scheduler { 62 63 using namespace tvm::tir; 64 65 class ComputeDAG; 66 67 /*! \brief The type of a stage. */ 68 enum class StageKind : int { 69 /*! \brief A placeholder stage. */ 70 kPlaceholder = 0, 71 /*! \brief A compute stage. */ 72 kCompute = 1 73 }; 74 75 /*! \brief The type of compute location. */ 76 enum class ComputeAtKind : int { 77 /*! \brief Compute at root. */ 78 kRoot = 0, 79 /*! \brief Compute inlined. */ 80 kInlined = 1, 81 /*! \brief Compute at some iterator. */ 82 kIter = 2, 83 }; 84 85 /*! \brief Stage-level attributes. */ 86 struct StageAttributes { 87 /*! \brief The maximum steps for the pragma `auto_unroll_max_step`. */ 88 int auto_unroll_max_step; 89 /*! \brief The storage offset for the schedule primitive `storage_align`. */ 90 int storage_offset; 91 }; 92 93 /*! 94 * \brief A op stage in the compute declaration. 95 * Similar to te::Stage in `include/tvm/te/schedule.h`. 96 */ 97 class StageNode : public Object { 98 public: 99 /*! \brief The operator of this stage */ 100 te::Operation op; 101 /*! \brief The iterators in this stage. */ 102 Array<Iterator> iters; 103 /*! \brief The type of this stage. */ 104 StageKind op_type; 105 /*! \brief The compute location of this stage. */ 106 ComputeAtKind compute_at; 107 /*! \brief Other stage-level attributes. */ 108 StageAttributes attrs; 109 VisitAttrs(tvm::AttrVisitor * v)110 void VisitAttrs(tvm::AttrVisitor* v) { 111 v->Visit("op", &op); 112 v->Visit("iters", &iters); 113 v->Visit("op_type", &op_type); 114 v->Visit("compute_at", &compute_at); 115 } 116 117 static constexpr const char* _type_key = "auto_scheduler.Stage"; 118 TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); 119 }; 120 121 /*! 122 * \brief Managed reference to StageNode. 123 * \sa StageNode 124 */ 125 class Stage : public ObjectRef { 126 public: 127 /*! 128 * \brief The constructor. 129 * \param op A `te::Operation`. 130 */ 131 explicit Stage(te::Operation op); 132 /*! 133 * \brief The constructor. 134 * \param op The source operation 135 * \param op_type The stage type of this op. 136 * \param iters The iterators of this op. 137 * \param compute_at The compute at type of this op. 138 * \param attrs Other stage-level attributes. 139 */ 140 Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters, ComputeAtKind compute_at, 141 StageAttributes attrs); 142 143 TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode); 144 TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); 145 }; 146 147 /*! \brief Use stage_id to represent a stage. */ 148 using StageKey = int; 149 /*! \brief Use stage_id and iter_id to represent a iterator. */ 150 using IterKey = std::pair<int, int>; 151 152 /*! 153 * \brief stores the compute_at relation between stages 154 * This stores a bi-directional mapping from stages and iter: 155 * 1. Stage to its attached iterator 156 * 2. Iterator to the stage attached to it 157 * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages 158 * to query the relations 159 */ 160 class AttachMapNode : public Object { 161 public: 162 struct IterKeyHash { operatorIterKeyHash163 std::size_t operator()(const IterKey& k) const { 164 return ::dmlc::HashCombine(std::hash<int>()(k.first), std::hash<int>()(k.second)); 165 } 166 }; 167 168 /*! \brief A Map to store the mapping of stage to its attached iterator. */ 169 std::unordered_map<StageKey, IterKey> stage_to_attach_iter; 170 /*! \brief A Map to store the mapping of iterator to the stages attached to it. */ 171 std::unordered_map<IterKey, std::vector<StageKey>, IterKeyHash> iter_to_attached_stages; 172 173 static constexpr const char* _type_key = "auto_scheduler.AttachMap"; 174 TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); 175 }; 176 177 /*! 178 * \brief Managed reference to AttachMapNode. 179 * \sa AttachMapNode 180 */ 181 class AttachMap : public ObjectRef { 182 public: 183 /*! 184 * \brief Process the stage/iterator mapping after compute at. 185 * \param stage_id The index of the source stage of computed at. 186 * \param target_stage_id The index of stage that this step will compute at to. 187 * \param target_iter_id The index of target iterator in the target stage. 188 */ 189 void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id); 190 191 /*! 192 * \brief Delete the entry of a specific stage. This is a public wrapper of `DeleteStageEntry`. 193 * \param stage_id The index of the stage to be deleted. 194 */ 195 void DeleteStage(int stage_id); 196 197 /*! 198 * \brief Find the relations of original iterators in AttachMap, and update them with the new 199 * iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated. 200 * \param original_iters The original IterKey. 201 * \param new_iters The new IterKey for replacing the old ones. 202 */ 203 void UpdateIters(const std::vector<IterKey>& original_iters, 204 const std::vector<IterKey>& new_iters); 205 206 /*! 207 * \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset 208 * to stage indexes that are larger than the start_id. Used for steps that insert new stages to 209 * ComputeDAG (e.g., CacheRead/CacheWrite step). 210 * \param start_id The index threshold. This function only adds offset for stages 211 * with indices larger then this threshold. 212 * \param offset The index offset to be added to the stage index. 213 * \return The updated AttachMap after applying stage index offset. 214 */ 215 AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const; 216 217 TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); 218 TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); 219 220 private: 221 /*! 222 * \brief Delete the entry of a specific stage. This will remove the items related to this 223 * stage in both `stage_to_attach_iter` and `iter_to_attached_stages` map. 224 * \param pnode A mutable pointer to AttachMapNode. 225 * \param stage_id The index of stage that will be removed from the map. 226 */ 227 static void DeleteStageEntry(AttachMapNode* pnode, int stage_id); 228 }; 229 230 /*! 231 * \brief A state in the search process. 232 * It consists of the current loop structure and a list of transformation steps used to construct 233 * it. 234 * Each State corresponds to a specific schedule for its ComputeDAG. 235 */ 236 class StateNode : public Object { 237 public: 238 /*! \brief Current stages and loop structures. */ 239 Array<Stage> stages; 240 /*! \brief History transformation steps. */ 241 Array<Step> transform_steps; 242 /*! 243 * \brief The attach relations of stages and iterators. This is used to track the compute at 244 * operation. 245 */ 246 AttachMap attach_map; 247 /*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, 248 * meaning the dag of this state is the same as the original ComputeDAG in the SearchTask. 249 * Otherwise, the stored value is the up-to-date ComputeDAG for this state, meaning some steps 250 * (e.g., CacheReadStep/CacheWriteStep) have modified the ComputeDAG. 251 */ 252 Optional<ObjectRef> current_compute_dag; 253 /*! 254 * \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all 255 * tile sizes of the state is filled. Only concrete state can be apply to TVM schedule. 256 */ 257 bool concrete; 258 VisitAttrs(tvm::AttrVisitor * v)259 void VisitAttrs(tvm::AttrVisitor* v) { 260 v->Visit("stages", &stages); 261 v->Visit("transform_steps", &transform_steps); 262 v->Visit("concrete", &concrete); 263 } 264 265 static constexpr const char* _type_key = "auto_scheduler.State"; 266 TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object); 267 }; 268 269 /*! 270 * \brief Managed reference to StateNode. 271 * \sa StateNode 272 */ 273 class State : public ObjectRef { 274 public: 275 /*! 276 * \brief The constructor. 277 * \param ops `te::Operation`s for a compute declaration. 278 */ 279 explicit State(const Array<te::Operation>& ops); 280 281 /*! 282 * \brief Pretty-print the state to a human readable string. 283 * \param delete_trivial_loop True for skipping the trivial loops. 284 * (undefined or extent == 1, default set to True) 285 * \return The human readable string. 286 */ 287 String ToStr(bool delete_trivial_loop = true) const; 288 289 /********** Step APIs working on a single stage **********/ 290 /*! 291 * \brief The schedule primitive corresponding to `te::Stage::bind`. 292 * \param stage_id The index of the stage to be binded. 293 * \param it The iterator to be binded. 294 * \param thread_type The thread type. 295 * \return The new iterator after binding. 296 */ 297 TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); 298 /*! 299 * \brief The schedule primitive corresponding to `te::Stage::parallel`. 300 * \param stage_id The index of the stage to be paralleled. 301 * \param it The iterator to be paralleled. 302 * \return The new iterator after parallel. 303 */ 304 TVM_DLL Iterator parallel(int stage_id, const Iterator& it); 305 /*! 306 * \brief The schedule primitive corresponding to `te::Stage::unroll`. 307 * \param stage_id The index of the stage to be unrolled. 308 * \param it The iterator to be unrolled. 309 * \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be 310 * skipped. 311 * \return The new iterator after unroll. 312 */ 313 TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); 314 /*! 315 * \brief The schedule primitive corresponding to `te::Stage::vectorize`. 316 * \param stage_id The index of the stage to be vectorized. 317 * \param it The iterator to be vectorized. 318 * \return The new iterator after vectorization. 319 */ 320 TVM_DLL Iterator vectorize(int stage_id, const Iterator& it); 321 /*! 322 * \brief The schedule primitive corresponding to `te::Stage::fuse`. 323 * \param stage_id The index of the stage to be fused. 324 * \param iters The iterators to be fused. 325 * \return The iterator result after fuse. 326 * \note If the iterators to be fused have stages attached at them(by compute_at), the fused 327 * result will become the new attach point. 328 */ 329 TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters); 330 /*! 331 * \brief The schedule primitive corresponding to `te.Stage.pragma`. 332 * \param stage_id The index of the stage to add pragma. 333 * \param it The iterator to add pragma. 334 * \param pragma_type The pragma string. 335 */ 336 TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type); 337 /*! 338 * \brief The schedule primitive corresponding to `te::Stage::reorder`. 339 * \param stage_id The index of the stage to be reordered. 340 * \param order The expected iterator order. 341 */ 342 TVM_DLL void reorder(int stage_id, const Array<Iterator>& order); 343 /*! 344 * \brief The schedule primitive corresponding to `te::Stage::split`. 345 * \param stage_id The index of the stage to be split. 346 * \param it The iterator to be split. 347 * \param lengths The multiple split factors. Can be None to be filled by search policy. 348 * \param inner_to_outer Whether the factors go from inner to outer, or from outer to inner. 349 * \return The new iterator after splitting. 350 * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner 351 * most iterator of split results will become the new attach point. 352 */ 353 TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it, 354 const Array<Optional<Integer>>& lengths, 355 bool inner_to_outer = true); 356 /*! 357 * \brief The schedule primitive similar to split, but uses split factors from previous steps. 358 * \param stage_id The index of the stage to be split. 359 * \param it The iterator to be split. 360 * \param src_step_id The index of the split step to be followed in the history. 361 * \param n_split The number of split level. 362 * \return The split new Iterators. 363 */ 364 TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id, 365 int n_split); 366 /*! 367 * \brief The schedule primitive similar to split, but uses split factors from 368 * fused previous steps. 369 * \param stage_id The index of the stage to be split. 370 * \param it The iterator to be split. 371 * \param src_step_ids The indices of the split steps to be followed in the history. 372 * \param level Use the length in this split level. 373 * \param factor_or_nparts True to use `factor` for split from inner to outer, 374 False to use `nparts` for split from outer to inner. 375 * \return The split new Iterators. 376 */ 377 TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it, 378 const Array<Integer>& src_step_ids, int level, 379 bool factor_or_nparts); 380 /*! 381 * \brief The schedule primitive corresponding to `te.Stage.storage_align`. 382 * \param stage_id The index of the stage to be aligned. 383 * \param it The iterator to be aligned. 384 * \param factor The factor in alignment specification. 385 * \param offset The offset in the alignment specification. 386 */ 387 TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset); 388 389 /********** Step APIs working on multiple stages **********/ 390 /*! 391 * \brief The schedule primitive corresponding to `te::Stage::compute_at`. 392 * \param stage_id The index of the source stage of computed at. 393 * \param target_stage_id The index of stage that this step will compute at to. 394 * \param target_iter The indiex of the target iterator in the target stage. 395 * \note After compute_at, we need careful dependency analysis to compute the accurate bound 396 * information. However, it is relatively expensive and complicated, so we just fill "None" as 397 * bound for the newly created iterators. 398 * Call ComputeDAG::InferBound on the updated state if you need the complete bound information. 399 */ 400 TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); 401 /*! 402 * \brief The schedule primitive corresponding to `te::Stage::compute_inline`. 403 * \param stage_id The index of the stage to be marked compute inlined. 404 */ 405 TVM_DLL void compute_inline(int stage_id); 406 /*! 407 * \brief The schedule primitive corresponding to `te::Stage::compute_root`. 408 * \param stage_id The index of the stage to be marked compute at root. 409 * \note After compute_root, we need careful dependency analysis to compute the accurate bound 410 * information. However, it is relatively expensive and complicated, so we just fill "None" as 411 * bound for the newly created iterators. 412 * Call ComputeDAG::InferBound on the updated state if you need the complete bound information. 413 */ 414 TVM_DLL void compute_root(int stage_id); 415 416 /********** Step APIs adding new stages **********/ 417 /*! 418 * \brief The schedule primitive corresponding to `te::Schedule::cache_read`. 419 * \param stage_id The index of the stage to be cache_read. 420 * \param scope_name The scope name of the newly added stage. 421 * \param reader_stage_ids The indices of reader stages. 422 * \param dag The original ComputeDAG of this state. 423 * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the 424 * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`. 425 */ 426 TVM_DLL int cache_read(int stage_id, const String& scope_name, 427 const Array<Integer>& reader_stage_ids, const ComputeDAG& dag); 428 /*! 429 * \brief The schedule primitive corresponding to `te::Schedule::cache_write`. 430 * \param stage_id The index of the stage to be cache_write. 431 * \param scope_name The scope name of the newly added stage. 432 * \param dag The original ComputeDAG of this state. 433 * \note Cache write step will add an extra stage to the original ComputeDAG (in the front of the 434 * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`. 435 * This step will cache write all output tensors of the target stage. 436 */ 437 TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag); 438 /*! 439 * \brief The schedule primitive corresponding to `te::Schedule::rfactor`. 440 * \param stage_id The index of the iterator to be factored. 441 * \param it The iterator to be factored. 442 * \param factor_iter_id The position where the new iterator is placed. 443 * \param dag The original ComputeDAG of this state. 444 * \note Rfactor step will add an extra stage to the original ComputeDAG (in the front of the 445 * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`. 446 */ 447 TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag); 448 449 TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); 450 TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); 451 }; 452 453 } // namespace auto_scheduler 454 } // namespace tvm 455 456 // Hash and equal function for State 457 namespace std { 458 459 /*! 460 * \brief The equal_to function for auto_scheduler::State. 461 * This function checks the equality by looking at the lowered string format of states. 462 * If two states with different transform history have the same lowered string format, 463 * they will be considered being equal. 464 */ 465 template <> 466 struct equal_to<::tvm::auto_scheduler::State> { 467 bool operator()(const ::tvm::auto_scheduler::State& lhs, 468 const ::tvm::auto_scheduler::State& rhs) const { 469 return lhs.ToStr() == rhs.ToStr(); 470 } 471 }; 472 473 /*! \brief The hash function for auto_scheduler::State. */ 474 template <> 475 struct hash<::tvm::auto_scheduler::State> { 476 std::size_t operator()(const ::tvm::auto_scheduler::State& state) const { 477 return tvm::runtime::ObjectHash()(state.ToStr()); 478 } 479 }; 480 481 } // namespace std 482 483 #endif // TVM_AUTO_SCHEDULER_LOOP_STATE_H_ 484