1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file auto_scheduler/loop_state.h
22  * \brief The definition of the "state" in the search.
23  *
24  * Each LoopState corresponds to a schedule for its ComputeDAG.
25  * A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to
26  * construct the loop structure.
 * The loop structure keeps a preview of what the schedule will finally look like after lowering
 * the current state (e.g., the number of iterators, the extent of each iterator, the compute_at
 * locations, ...).
 * During the schedule search process, the loop structure can provide the search policy with the
 * necessary information on how to manipulate the current state.
32  * The transform history is a sequence of `TransformStep` which will finally be mapped to TVM
33  * schedule primitives. The steps are also used for the serialization of a state.
34  *
35  * The LoopState can be seen as a lightweight loop structure IR specifically for schedule search.
 * We don't use the existing TVM IR, but instead build a new structure on top of it, because:
 * 1. We want fast incremental changes to the loop structures. The search policy needs to see the
 * updated loop structures immediately rather than only after TVM lowering;
39  * 2. We want serializable transform history for replay, backtracking, and mutation;
40  * 3. We may create some macro schedule primitives that represent the combination of several
41  * TVM schedule primitives.
42  *
43  * When the search is finished, we will lower the state to TVM IR with TVM's schedule primitives.
 * Since we share a lot of common objects during search, the transformation is implemented in a
 * copy-on-write style. All objects are immutable, which is similar to TVM IR.
46  */
47 
48 #ifndef TVM_AUTO_SCHEDULER_LOOP_STATE_H_
49 #define TVM_AUTO_SCHEDULER_LOOP_STATE_H_
50 
51 #include <dmlc/common.h>
52 #include <tvm/auto_scheduler/transform_step.h>
53 #include <tvm/runtime/container.h>
54 
55 #include <functional>
56 #include <unordered_map>
57 #include <utility>
58 #include <vector>
59 
60 namespace tvm {
61 namespace auto_scheduler {
62 
63 using namespace tvm::tir;
64 
65 class ComputeDAG;
66 
/*!
 * \brief The kind of a stage in the compute declaration.
 */
enum class StageKind : int {
  kPlaceholder = 0,  ///< A placeholder stage.
  kCompute = 1,      ///< A compute stage.
};
74 
/*!
 * \brief Where a stage is computed relative to the rest of the schedule.
 */
enum class ComputeAtKind : int {
  kRoot = 0,     ///< Computed at the root.
  kInlined = 1,  ///< Computed inline.
  kIter = 2,     ///< Computed at some iterator of another stage.
};
84 
/*!
 * \brief Stage-level attributes.
 * \note Both fields default to 0 so that a default-constructed StageAttributes never carries
 * indeterminate values. The struct remains an aggregate (C++14 and later), so brace
 * initialization such as `StageAttributes{max_step, offset}` keeps working.
 */
struct StageAttributes {
  /*! \brief The maximum steps for the pragma `auto_unroll_max_step`. */
  int auto_unroll_max_step = 0;
  /*! \brief The storage offset for the schedule primitive `storage_align`. */
  int storage_offset = 0;
};
92 
/*!
 * \brief An op stage in the compute declaration.
 * Similar to te::Stage in `include/tvm/te/schedule.h`.
 */
class StageNode : public Object {
 public:
  /*! \brief The operator of this stage. */
  te::Operation op;
  /*! \brief The iterators in this stage. */
  Array<Iterator> iters;
  /*! \brief The type of this stage. */
  StageKind op_type;
  /*! \brief The compute location of this stage. */
  ComputeAtKind compute_at;
  /*! \brief Other stage-level attributes. */
  StageAttributes attrs;

  /*!
   * \brief Visit the reflected fields of this node.
   * \param v The attribute visitor.
   * \note `attrs` is not passed to the visitor.
   * NOTE(review): presumably this means `attrs` is not reflected/serialized with the stage —
   * confirm that this is intended.
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("op", &op);
    v->Visit("iters", &iters);
    v->Visit("op_type", &op_type);
    v->Visit("compute_at", &compute_at);
  }

  static constexpr const char* _type_key = "auto_scheduler.Stage";
  TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
};
120 
/*!
 * \brief Managed reference to StageNode.
 * \sa StageNode
 */
class Stage : public ObjectRef {
 public:
  /*!
   * \brief The constructor.
   * \param op The `te::Operation` this stage is created from.
   */
  explicit Stage(te::Operation op);
  /*!
   * \brief The constructor.
   * \param op The source operation.
   * \param op_type The stage type of this op.
   * \param iters The iterators of this op.
   * \param compute_at The compute_at kind of this op.
   * \param attrs Other stage-level attributes.
   */
  Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters, ComputeAtKind compute_at,
        StageAttributes attrs);

  TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode);
};
146 
/*! \brief A stage is identified by its stage_id. */
using StageKey = int;
/*! \brief An iterator is identified by the pair (stage_id, iter_id). */
using IterKey = std::pair<StageKey, int>;
151 
152 /*!
153  * \brief stores the compute_at relation between stages
154  * This stores a bi-directional mapping from stages and iter:
155  * 1. Stage to its attached iterator
156  * 2. Iterator to the stage attached to it
157  * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages
158  * to query the relations
159  */
160 class AttachMapNode : public Object {
161  public:
162   struct IterKeyHash {
operatorIterKeyHash163     std::size_t operator()(const IterKey& k) const {
164       return ::dmlc::HashCombine(std::hash<int>()(k.first), std::hash<int>()(k.second));
165     }
166   };
167 
168   /*! \brief A Map to store the mapping of stage to its attached iterator. */
169   std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
170   /*! \brief A Map to store the mapping of iterator to the stages attached to it. */
171   std::unordered_map<IterKey, std::vector<StageKey>, IterKeyHash> iter_to_attached_stages;
172 
173   static constexpr const char* _type_key = "auto_scheduler.AttachMap";
174   TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object);
175 };
176 
/*!
 * \brief Managed reference to AttachMapNode.
 * \sa AttachMapNode
 */
class AttachMap : public ObjectRef {
 public:
  /*!
   * \brief Update the stage/iterator mapping after a compute_at step.
   * \param stage_id The index of the stage being computed at the target.
   * \param target_stage_id The index of the stage to compute at.
   * \param target_iter_id The index of the target iterator in the target stage.
   */
  void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);

  /*!
   * \brief Delete the entry of a specific stage. This is a public wrapper of `DeleteStageEntry`.
   * \param stage_id The index of the stage to be deleted.
   */
  void DeleteStage(int stage_id);

  /*!
   * \brief Find the relations of the original iterators in the AttachMap and update them to use
   * the new iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated.
   * \param original_iters The original IterKeys.
   * \param new_iters The new IterKeys that replace the original ones.
   * NOTE(review): presumably the two vectors are matched element-wise and must have the same
   * length — confirm against the implementation.
   */
  void UpdateIters(const std::vector<IterKey>& original_iters,
                   const std::vector<IterKey>& new_iters);

  /*!
   * \brief Traverse the `stage_to_attach_iter` and `iter_to_attached_stages` maps and add an
   * offset to the stage indices that are larger than `start_id`. Used for steps that insert new
   * stages into the ComputeDAG (e.g., CacheRead/CacheWrite step).
   * \param start_id The index threshold. This function only adds the offset to stages with
   * indices larger than this threshold.
   * \param offset The index offset to be added to the stage index. Defaults to 1.
   * \return The updated AttachMap after applying the stage index offset.
   * \note This is a const method: the offset map is returned, not applied in place.
   */
  AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const;

  TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode);

 private:
  /*!
   * \brief Delete the entry of a specific stage. This removes the items related to this stage
   * from both the `stage_to_attach_iter` and `iter_to_attached_stages` maps.
   * \param pnode A mutable pointer to AttachMapNode.
   * \param stage_id The index of the stage that will be removed from the map.
   */
  static void DeleteStageEntry(AttachMapNode* pnode, int stage_id);
};
229 
/*!
 * \brief A state in the search process.
 * It consists of the current loop structure and the list of transformation steps used to
 * construct it.
 * Each State corresponds to a specific schedule for its ComputeDAG.
 */
class StateNode : public Object {
 public:
  /*! \brief Current stages and loop structures. */
  Array<Stage> stages;
  /*! \brief History of transformation steps. */
  Array<Step> transform_steps;
  /*!
   * \brief The attach relations of stages and iterators. This is used to track the compute_at
   * operation.
   */
  AttachMap attach_map;
  /*!
   * \brief The up-to-date ComputeDAG of this state. The default value is NullOpt, meaning the
   * DAG of this state is the same as the original ComputeDAG in the SearchTask.
   * Otherwise, the stored value is the up-to-date ComputeDAG for this state, meaning some steps
   * (e.g., CacheReadStep/CacheWriteStep) have modified the ComputeDAG.
   * \note Stored as Optional<ObjectRef> rather than as a ComputeDAG; presumably because
   * ComputeDAG is only forward-declared in this header (TODO confirm).
   */
  Optional<ObjectRef> current_compute_dag;
  /*!
   * \brief Indicates whether all tile sizes of this state are filled. A concrete state has no
   * unfilled tile sizes. Only a concrete state can be applied to a TVM schedule.
   */
  bool concrete;

  /*!
   * \brief Visit the reflected fields of this node.
   * \param v The attribute visitor.
   * \note `attach_map` and `current_compute_dag` are not passed to the visitor.
   */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("stages", &stages);
    v->Visit("transform_steps", &transform_steps);
    v->Visit("concrete", &concrete);
  }

  static constexpr const char* _type_key = "auto_scheduler.State";
  TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);
};
268 
/*!
 * \brief Managed reference to StateNode.
 * \sa StateNode
 */
class State : public ObjectRef {
 public:
  /*!
   * \brief The constructor.
   * \param ops `te::Operation`s for a compute declaration.
   */
  explicit State(const Array<te::Operation>& ops);

  /*!
   * \brief Pretty-print the state to a human readable string.
   * \param delete_trivial_loop True to skip trivial loops (undefined or extent == 1).
   * Defaults to true.
   * \return The human readable string.
   */
  String ToStr(bool delete_trivial_loop = true) const;

  /********** Step APIs working on a single stage **********/
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::bind`.
   * \param stage_id The index of the stage to be bound.
   * \param it The iterator to be bound.
   * \param thread_type The thread type.
   * \return The new iterator after binding.
   */
  TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::parallel`.
   * \param stage_id The index of the stage to be parallelized.
   * \param it The iterator to be parallelized.
   * \return The new iterator after parallelization.
   */
  TVM_DLL Iterator parallel(int stage_id, const Iterator& it);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::unroll`.
   * \param stage_id The index of the stage to be unrolled.
   * \param it The iterator to be unrolled.
   * \param max_unroll The max unroll limit. An iterator with an extent larger than this limit
   * will be skipped. Defaults to -1.
   * NOTE(review): -1 presumably means "no limit" — confirm against the implementation.
   * \return The new iterator after unrolling.
   */
  TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::vectorize`.
   * \param stage_id The index of the stage to be vectorized.
   * \param it The iterator to be vectorized.
   * \return The new iterator after vectorization.
   */
  TVM_DLL Iterator vectorize(int stage_id, const Iterator& it);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::fuse`.
   * \param stage_id The index of the stage to be fused.
   * \param iters The iterators to be fused.
   * \return The iterator resulting from the fuse.
   * \note If the iterators to be fused have stages attached to them (by compute_at), the fused
   * result will become the new attach point.
   */
  TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::pragma`.
   * \param stage_id The index of the stage to add the pragma to.
   * \param it The iterator to add the pragma to.
   * \param pragma_type The pragma string.
   */
  TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::reorder`.
   * \param stage_id The index of the stage to be reordered.
   * \param order The expected iterator order.
   */
  TVM_DLL void reorder(int stage_id, const Array<Iterator>& order);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::split`.
   * \param stage_id The index of the stage to be split.
   * \param it The iterator to be split.
   * \param lengths The multiple split factors. An entry can be None, to be filled in later by the
   * search policy.
   * \param inner_to_outer Whether the factors go from inner to outer, or from outer to inner.
   * \return The new iterators after splitting.
   * \note If we split an iterator that has stages attached to it (by compute_at), the innermost
   * iterator of the split results will become the new attach point.
   */
  TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                const Array<Optional<Integer>>& lengths,
                                bool inner_to_outer = true);
  /*!
   * \brief The schedule primitive similar to split, but using split factors from previous steps.
   * \param stage_id The index of the stage to be split.
   * \param it The iterator to be split.
   * \param src_step_id The index of the split step to be followed in the history.
   * \param n_split The number of split levels.
   * \return The new iterators after splitting.
   */
  TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
                                       int n_split);
  /*!
   * \brief The schedule primitive similar to split, but using split factors from fused previous
   * steps.
   * \param stage_id The index of the stage to be split.
   * \param it The iterator to be split.
   * \param src_step_ids The indices of the split steps to be followed in the history.
   * \param level The split level whose length is used.
   * \param factor_or_nparts True to use `factor` for a split from inner to outer,
   * False to use `nparts` for a split from outer to inner.
   * \return The new iterators after splitting.
   */
  TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
                                             const Array<Integer>& src_step_ids, int level,
                                             bool factor_or_nparts);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::storage_align`.
   * \param stage_id The index of the stage to be aligned.
   * \param it The iterator to be aligned.
   * \param factor The factor in the alignment specification.
   * \param offset The offset in the alignment specification.
   */
  TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset);

  /********** Step APIs working on multiple stages **********/
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::compute_at`.
   * \param stage_id The index of the stage to be computed at the target.
   * \param target_stage_id The index of the stage to compute at.
   * \param target_iter The target iterator in the target stage.
   * \note After compute_at, we need careful dependency analysis to compute the accurate bound
   * information. However, it is relatively expensive and complicated, so we just fill "None" as
   * the bound for the newly created iterators.
   * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
   */
  TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::compute_inline`.
   * \param stage_id The index of the stage to be marked compute inlined.
   */
  TVM_DLL void compute_inline(int stage_id);
  /*!
   * \brief The schedule primitive corresponding to `te::Stage::compute_root`.
   * \param stage_id The index of the stage to be marked compute at root.
   * \note After compute_root, we need careful dependency analysis to compute the accurate bound
   * information. However, it is relatively expensive and complicated, so we just fill "None" as
   * the bound for the newly created iterators.
   * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
   */
  TVM_DLL void compute_root(int stage_id);

  /********** Step APIs adding new stages **********/
  /*!
   * \brief The schedule primitive corresponding to `te::Schedule::cache_read`.
   * \param stage_id The index of the stage to be cache_read.
   * \param scope_name The scope name of the newly added stage.
   * \param reader_stage_ids The indices of reader stages.
   * \param dag The original ComputeDAG of this state.
   * \return An int. NOTE(review): presumably the index of the newly added stage — confirm
   * against the implementation.
   * \note The cache read step adds an extra stage to the original ComputeDAG (after the target
   * stage); an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
   */
  TVM_DLL int cache_read(int stage_id, const String& scope_name,
                         const Array<Integer>& reader_stage_ids, const ComputeDAG& dag);
  /*!
   * \brief The schedule primitive corresponding to `te::Schedule::cache_write`.
   * \param stage_id The index of the stage to be cache_write.
   * \param scope_name The scope name of the newly added stage.
   * \param dag The original ComputeDAG of this state.
   * \return An int. NOTE(review): presumably the index of the newly added stage — confirm
   * against the implementation.
   * \note The cache write step adds an extra stage to the original ComputeDAG (in front of the
   * target stage); an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
   * This step will cache write all output tensors of the target stage.
   */
  TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
  /*!
   * \brief The schedule primitive corresponding to `te::Schedule::rfactor`.
   * \param stage_id The index of the stage to be factored.
   * \param it The iterator to be factored.
   * \param factor_iter_id The position where the new iterator is placed.
   * \param dag The original ComputeDAG of this state.
   * \return An int. NOTE(review): presumably the index of the newly added stage — confirm
   * against the implementation.
   * \note The rfactor step adds an extra stage to the original ComputeDAG (in front of the
   * target stage); an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
   */
  TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag);

  TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
};
452 
453 }  // namespace auto_scheduler
454 }  // namespace tvm
455 
456 // Hash and equal function for State
457 namespace std {
458 
459 /*!
460  * \brief The equal_to function for auto_scheduler::State.
461  * This function checks the equality by looking at the lowered string format of states.
462  * If two states with different transform history have the same lowered string format,
463  * they will be considered being equal.
464  */
465 template <>
466 struct equal_to<::tvm::auto_scheduler::State> {
467   bool operator()(const ::tvm::auto_scheduler::State& lhs,
468                   const ::tvm::auto_scheduler::State& rhs) const {
469     return lhs.ToStr() == rhs.ToStr();
470   }
471 };
472 
473 /*! \brief The hash function for auto_scheduler::State. */
474 template <>
475 struct hash<::tvm::auto_scheduler::State> {
476   std::size_t operator()(const ::tvm::auto_scheduler::State& state) const {
477     return tvm::runtime::ObjectHash()(state.ToStr());
478   }
479 };
480 
481 }  // namespace std
482 
483 #endif  // TVM_AUTO_SCHEDULER_LOOP_STATE_H_
484