/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/schedule.h
 * \brief Define a schedule.
 */
// Acknowledgement: Many schedule primitives originate from Halide and Loopy.
#ifndef TVM_SCHEDULE_H_
#define TVM_SCHEDULE_H_

#include <string>
#include <unordered_map>
#include "base.h"
#include "expr.h"
#include "tensor.h"
#include "tensor_intrin.h"

namespace tvm {

// Node container for Stage
class StageNode;
// Node container for Schedule
class ScheduleNode;
// Node container for IterVarRelation
class IterVarRelationNode;
// Attribute of itervar.
class IterVarAttrNode;

/*! \brief the attachment type */
enum AttachType : int {
  kGroupRoot = 1,
  kInline = 2,
  kInlinedAlready = 3,
  kScope = 4,
  kScanUpdate = 5
};

/*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public NodeRef {
 public:
  Stage() {}
  explicit Stage(ObjectPtr<Object> n) : NodeRef(n) {}
  /*!
   * \brief create a new schedule for op.
   * \param op The operator in the schedule
   */
  explicit Stage(Operation op);
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const StageNode* operator->() const;
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline StageNode* operator->();
  /*!
   * \brief set the memory scope of the stage
   * \param scope The memory scope.
   */
  TVM_DLL Stage& set_scope(std::string scope);  // NOLINT(*)
  /*!
   * \brief specify the schedule to be computed at the parent schedule's scope.
   * \param parent The parent schedule.
   * \param scope The iteration point to carry the schedule.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_at(Stage parent, IterVar scope);   // NOLINT(*)
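  //
  // A minimal usage sketch (assumptions, not part of this header: `s` is a
  // Schedule, `B` and `C` are Tensors created elsewhere, and `yo` is a leaf
  // IterVar of C's stage, e.g. produced by a prior split):
  //
  //   s[B].compute_at(s[C], yo);   // B is now computed inside C's `yo` loop
  //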
  /*!
   * \brief Compute the function inline.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_inline();   // NOLINT(*)
  /*!
   * \brief Compute the function at group root.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_root();  // NOLINT(*)
  /*!
   * \brief Bind the IterVar to thread index.
   *
   * \param ivar The IterVar to be bound.
   * \param thread_ivar The thread axis to be bound.
   * \return reference to self.
   */
  TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
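  //
  // A minimal usage sketch (assumptions: `bx`/`tx` are GPU thread IterVars
  // created elsewhere, e.g. with thread tags "blockIdx.x"/"threadIdx.x", and
  // `xo`/`xi` are leaf IterVars of C's stage from a prior split):
  //
  //   s[C].bind(xo, bx);   // outer loop runs over GPU blocks
  //   s[C].bind(xi, tx);   // inner loop runs over GPU threads
  //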
  /*!
   * \brief Set the predicate to determine whether a store to the array should be performed.
   *  Use this when there are multiple threads performing the same store and we only
   *  need one of them to do the store.
   *
   * \note This is a dangerous scheduling primitive that can change the behavior of the program.
   *    Only use it when we are certain that there are duplicated stores.
   * \param predicate The condition to be checked.
   * \return reference to self.
   */
  TVM_DLL Stage& set_store_predicate(Expr predicate);
  /*!
   * \brief Specify environment threads that are launched around the group's scope.
   *  This can only be used in a group stage.
   * \param threads The threads to be launched around the scope.
   * \note Each thread can only appear in one env_threads.
   *    This is a beta feature.
   * \return reference to self.
   */
  TVM_DLL Stage& env_threads(Array<IterVar> threads);
  /*!
   * \brief Split the parent by factor, generating an outer and an inner iteration domain.
   * \param parent The parent iteration domain.
   * \param factor The split factor of the loop.
   * \param p_outer The result outer domain.
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
  TVM_DLL Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner);  // NOLINT(*)
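  //
  // A minimal usage sketch (assuming `i` is an axis of C's operation obtained
  // elsewhere): split `i` by a factor of 32.
  //
  //   IterVar xo, xi;
  //   s[C].split(i, 32, &xo, &xi);   // i -> xo (outer) and xi (inner, extent 32)
  //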
  /*!
   * \brief Split the iteration with given number of parts.
   *
   * \param parent The parent domain.
   * \param nparts The number of parts in the outer domain.
   * \param p_outer The result outer domain.
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
  TVM_DLL Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner);   // NOLINT(*)
  /*!
   * \brief Fuse the outer and inner domains into the target domain.
   * \param outer The outer domain to be fused.
   * \param inner The inner domain to be fused.
   * \param p_target The result target domain.
   * \return reference to self.
   */
  TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target);  // NOLINT(*)
  /*!
   * \brief Fuse all the axes together into a single axis.
   *
   * \param axes All the axes to be fused.
   * \param p_target The result target domain.
   *
   * \note axes can be an empty array,
   *       in that case, a singleton IterVar is created and
   *       inserted into the outermost loop.
   *       Fusing an empty array is used to support zero-dimensional tensors.
   *
   * \return reference to self.
   */
  TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target);  // NOLINT(*)
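  //
  // A minimal usage sketch (assuming `i` and `j` are adjacent leaf axes of
  // C's stage): fuse them into one loop of extent extent(i) * extent(j).
  //
  //   IterVar ij;
  //   s[C].fuse(i, j, &ij);
  //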
  /*!
   * \brief Reorder the iteration.
   * \param order The order of iteration variables.
   * \return reference to self.
   */
  TVM_DLL Stage& reorder(const Array<IterVar>& order);   // NOLINT(*)
  /*!
   * \brief Perform tiling on two dimensions.
   *  The final loop order from outermost to innermost is
   *  [x_outer, y_outer, x_inner, y_inner].
   *
   * \param x_parent The original x dimension
   * \param y_parent The original y dimension
   * \param x_factor The stride factor on x axis
   * \param y_factor The stride factor on y axis
   * \param p_x_outer Outer axis of x dimension
   * \param p_y_outer Outer axis of y dimension
   * \param p_x_inner Inner axis of x dimension
   * \param p_y_inner Inner axis of y dimension
   * \return reference to self.
   */
  TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent,   // NOLINT(*)
                      Expr x_factor, Expr y_factor,
                      IterVar* p_x_outer, IterVar* p_y_outer,
                      IterVar* p_x_inner, IterVar* p_y_inner);
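  //
  // A minimal usage sketch (assuming `i` and `j` are axes of C's operation):
  // a 32x32 tiling, equivalent to two splits followed by a reorder.
  //
  //   IterVar io, jo, ii, ji;
  //   s[C].tile(i, j, 32, 32, &io, &jo, &ii, &ji);
  //   // resulting loop nest, outermost to innermost: io, jo, ii, ji
  //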
  /*!
   * \brief Vectorize iteration.
   * \param var The axis to be vectorized.
   * \return reference to self.
   */
  TVM_DLL Stage& vectorize(IterVar var);   // NOLINT(*)
  /*!
   * \brief Replace computation of the current stage by tensor intrinsic f.
   * \param var The axis that marks the beginning of tensorization.
   *  Every operation inside the axis (including the axis itself) is tensorized.
   * \param f The tensor compute intrinsic.
   * \return reference to self.
   */
  TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f);   // NOLINT(*)
  /*!
   * \brief Unroll iteration.
   * \param var The axis to be unrolled.
   * \return reference to self.
   */
  TVM_DLL Stage& unroll(IterVar var);   // NOLINT(*)
  /*!
   * \brief Parallelize iteration.
   * \param var The axis to be parallelized.
   * \return reference to self.
   */
  TVM_DLL Stage& parallel(IterVar var);   // NOLINT(*)
  /*!
   * \brief Annotate the iteration with a pragma.
   *
   * \param var The axis to be annotated.
   * \param pragma_type The pragma type.
   * \param pragma_value The pragma value.
   *
   * \return reference to self.
   */
  TVM_DLL Stage& pragma(IterVar var,
                        const std::string& pragma_type,
                        const Expr& pragma_value = Expr());   // NOLINT(*)
  /*!
   * \brief Fetch data in advance.
   * \param domain The tensor to be prefetched.
   * \param var The iteration point at which to apply prefetching.
   * \param offset The number of iterations to be fetched in advance.
   * \return reference to self
   */
  TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
  /*!
   * \brief Set alignment requirement for specific dimension,
   *  such that stride[axis] == k * factor + offset for some k.
   *
   * \param axis The dimension to be specified for alignment.
   * \param factor The factor multiple of alignment.
   * \param offset The required offset factor.
   * \return reference to self
   */
  TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
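  //
  // An illustrative (not prescribed) setting, assuming `AA` is a cache stage
  // and `k` one of its axes: storage_align(k, 16, 8) constrains stride[k] to
  // 16 * n + 8 for some integer n, a padding pattern commonly used to avoid
  // shared memory bank conflicts on GPUs.
  //
  //   s[AA].storage_align(k, 16, 8);
  //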
  /*!
   * \brief Compute current stage with double buffering.
   * \return reference to self.
   */
  TVM_DLL Stage& double_buffer();   // NOLINT(*)
  /*!
   * \brief Schedule for OpenGL fragment shader.
   * \return reference to self.
   */
  Stage& opengl(); // NOLINT(*)
  /*!
   * \brief whether the stage has been scheduled.
   * \return whether the stage has been scheduled.
   */
  bool is_scheduled() const;
  /*!
   * \brief Get attachment spec of the current stage.
   *  If the stage is computed at group root, this function
   *  will traverse the group to get the final spec from the group.
   * \return A stage representing the attach spec of the group.
   */
  Stage GetAttachSpec() const;
  // declare container type
  using ContainerType = StageNode;
};

/*!
 * \brief Global schedule container
 *  For operations and all the operations they depend on.
 *  The schedule for each Operation is called a stage.
 */
class Schedule : public NodeRef {
 public:
  Schedule() {}
  explicit Schedule(ObjectPtr<Object> n) : NodeRef(n) {}
  /*!
   * \brief Get a copy of the current schedule.
   * \return The copied schedule.
   */
  Schedule copy() const;
  /*!
   * \brief Get the stage that corresponds to the op.
   * \param op The operation.
   */
  TVM_DLL Stage operator[](const Operation& op);
  /*!
   * \brief Shorthand for getting the stage of a tensor's operation.
   * \param tensor The tensor.
   * \return The stage corresponding to the tensor's op.
   */
  TVM_DLL Stage operator[](const Tensor& tensor) {
    return this->operator[](tensor->op);
  }
  /*!
   * \brief Create a new stage group for all intermediate
   *  operations between inputs and outputs.
   *
   * \param outputs The output boundary of the group.
   * \param inputs The input boundary of the group.
   * \param include_inputs Whether to include inputs if they are reachable from outputs.
   * \return The new grouped stage.
   */
  TVM_DLL Stage create_group(const Array<Tensor>& outputs,
                             const Array<Tensor>& inputs,
                             bool include_inputs = false);
  /*!
   * \brief Create a cache read of the original tensor for readers.
   *  This will mutate the body of the readers.
   *  A new stage will be created for the tensor.
   * \param tensor The tensor to be cached.
   * \param scope The scope of the cache.
   * \param readers The readers to redirect to the cache.
   * \return The created tensor.
   */
  TVM_DLL Tensor cache_read(const Tensor& tensor,
                            const std::string& scope,
                            const Array<Operation>& readers);
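  //
  // A minimal usage sketch (assuming `A` is a Tensor read by C's operation):
  //
  //   Tensor AA = s.cache_read(A, "shared", {C->op});   // C now reads from AA
  //   // the new stage can then be placed, e.g. s[AA].compute_at(s[C], xo);
  //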
  /*!
   * \brief Create cache write tensors for producing the given tensors.
   *  The cache tensors will take over the body of the original tensor ops.
   *
   *  This function can be used to do data layout transformation.
   *  If there is a split/fuse/reorder on the data parallel axis of the tensor
   *  before cache_write is called, the intermediate cache stores
   *  the data in the layout given by the iteration order of the leaf axes.
   *  The data will be transformed back to the original layout in the original tensor.
   *  The user can further call compute_inline to inline the original layout and keep
   *  the data stored in the transformed layout.
   *
   * \param tensor The tensors to be produced.
   * \param scope The scope of the storage.
   * \return The created tensors.
   */
  TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
  /*!
   * \brief Create a cache write tensor for producing the given tensor.
   *  The cache tensor will take over the body of the original tensor op.
   *
   *  This function can be used to do data layout transformation.
   *  If there is a split/fuse/reorder on the data parallel axis of the tensor
   *  before cache_write is called, the intermediate cache stores
   *  the data in the layout given by the iteration order of the leaf axes.
   *  The data will be transformed back to the original layout in the original tensor.
   *  The user can further call compute_inline to inline the original layout and keep
   *  the data stored in the transformed layout.
   *
   * \param tensor The tensor to be produced.
   * \param scope The scope of the storage.
   * \return The created tensor.
   */
  TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
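  //
  // A minimal usage sketch of the layout-transform pattern described above
  // (assuming `C` is a compute Tensor whose data parallel axis was already
  // split/reordered on this schedule):
  //
  //   Tensor CC = s.cache_write(C, "local");  // CC is computed in the new leaf-axis layout
  //   s[C].compute_inline();                  // optional, as described above: keep only the
  //                                           // transformed layout
  //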
  /*!
   * \brief Factor a reduction axis in the tensor's schedule to be an explicit axis.
   *  This will create a new stage that generates the new tensor with axis
   *  as the first dimension. The tensor's body will be rewritten as a reduction
   *  over the factored tensor.
   *
   *  P. Suriana, A. Adams and S. Kamil. Parallel associative reductions in Halide. CGO'17
   *
   * \param tensor The tensor to be factored.
   * \param axis The reduction axis in the tensor's schedule to be factored.
   * \param factor_axis The position where the new axis is placed.
   * \return The created factored tensors.
   */
  TVM_DLL Array<Tensor> rfactor(const Tensor& tensor,
                                const IterVar& axis,
                                int factor_axis = 0);
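  //
  // A minimal usage sketch (assuming `B` is a reduction Tensor whose reduction
  // axis `k` was split into `ko`/`ki` on this schedule):
  //
  //   Array<Tensor> BF = s.rfactor(B, ki);   // BF[0] carries ki as an explicit axis
  //   // B's body is rewritten as a reduction over BF[0]
  //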
  /*!
   * \brief Normalize the schedule.
   *  This is needed before bound inference.
   *  Insert the necessary RebaseNodes to make sure all leaf_iter_vars
   *  are in the form [0, extent).
   *
   * \return A normalized schedule, which can be the same as the current one.
   */
  Schedule normalize();
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const ScheduleNode* operator->() const;
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline ScheduleNode* operator->();
  // declare container type
  using ContainerType = ScheduleNode;
};

/*!
 * \brief The schedule relation between IterVars,
 *  which can be Split, Fuse, Rebase, or Singleton.
 */
class IterVarRelation : public NodeRef {
 public:
  IterVarRelation() {}
  explicit IterVarRelation(ObjectPtr<Object> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IterVarRelationNode* operator->() const;
};

/*!
 * \brief Additional schedulable attributes about IterVar.
 */
class IterVarAttr : public NodeRef {
 public:
  IterVarAttr() {}
  explicit IterVarAttr(ObjectPtr<Object> n) : NodeRef(n) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const IterVarAttrNode* operator->() const;
};

/*!
 * \brief represents a stage.
 *
 *  The relations form a directed acyclic hypergraph in a bipartite manner.
 *  Each node is represented by an IterVar,
 *  and each hyper-edge is represented by an IterVarRelation.
 *  The relations connect the IterVars in the graph.
 *
 *  Besides the typical stages that correspond to operations,
 *  there are also group stages, which group stages together.
 *  Each stage's group (given by the group field) represents a constraint:
 *  the stage can only be attached to stages within the group.
 *
 *  The group stage node can be attached to IterVars as in a normal stage.
 */
class StageNode : public Node {
 public:
  /*!
   * \brief The operation of the stage, can be different from the original op.
   *  If it is null, then this stage is a group stage.
   */
  Operation op;
  /*!
   * \brief The original operator.
   *  The op field can change during schedule to alternate the dataflow,
   *  while origin_op remains fixed.
   */
  Operation origin_op;
  /*! \brief All the nodes in the iter var */
  Array<IterVar> all_iter_vars;
  /*! \brief The current active leaf iter vars in the stage. */
  Array<IterVar> leaf_iter_vars;
  /*!
   * \brief Specify threads to be launched at the stage.
   *  This is only valid for composite ops such as Scan.
   * \note Experimental primitive: used for thread persistence.
   */
  Array<IterVar> env_threads;
  /*!
   * \brief The predicate under which store can happen.
   *  Use this when there can be duplicated threads doing the same store.
   * \note Experimental primitive: used by cross thread-reduction.
   */
  Expr store_predicate;
  /*! \brief The relations between IterVars */
  Array<IterVarRelation> relations;
  /*! \brief additional attributes about iter var. */
  Map<IterVar, IterVarAttr> iter_var_attrs;
  /*! \brief The attachment type of the schedule */
  AttachType attach_type{kGroupRoot};
  /*! \brief The attach point of this schedule. */
  IterVar attach_ivar;
  /*! \brief The stage this node attaches to */
  Stage attach_stage;
  /*! \brief The thread storage scope level of the stage */
  std::string scope;
  /*! \brief Whether this is an output stage */
  bool is_output{false};
  /*! \brief Whether this is an OpenGL stage */
  bool is_opengl{false};
  /*! \brief Whether to apply the double buffer optimization to this stage */
  bool double_buffer{false};
  /*!
   * \brief The parent group of the current stage.
   *  The stage cannot be assigned to stages outside the group.
   */
  Stage group;
  /*! \brief Number of direct child stages, only used for group stage.*/
  int num_child_stages{0};

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("op", &op);
    v->Visit("origin_op", &origin_op);
    v->Visit("all_iter_vars", &all_iter_vars);
    v->Visit("leaf_iter_vars", &leaf_iter_vars);
    v->Visit("env_threads", &env_threads);
    v->Visit("relations", &relations);
    v->Visit("iter_var_attrs", &iter_var_attrs);
    v->Visit("attach_type", &attach_type);
    v->Visit("attach_ivar", &attach_ivar);
    v->Visit("attach_stage", &attach_stage);
    v->Visit("scope", &scope);
    v->Visit("is_output", &is_output);
    v->Visit("is_opengl", &is_opengl);
    v->Visit("double_buffer", &double_buffer);
    v->Visit("group", &group);
    v->Visit("num_child_stages", &num_child_stages);
  }

  static constexpr const char* _type_key = "Stage";
  TVM_DECLARE_NODE_TYPE_INFO(StageNode, Node);
};

/*! \brief node container for schedule */
class ScheduleNode : public Node {
 public:
  /*! \brief The output operations in the original data flow graph */
  Array<Operation> outputs;
  /*!
   * \brief List of all stages for ops.
   *  The stages are sorted in dependency order.
   */
  Array<Stage> stages;
  /*!
   * \brief List of all stage groups.
   */
  Array<Stage> groups;
  /*! \brief map of original operation to the stages */
  Map<Operation, Stage> stage_map;
  /*!
   * \brief Internal stage map to map internal ops to stages.
   *  This is created on demand and can be invalidated.
   */
  std::unordered_map<const Node*, Stage> op2stage_cache_;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("outputs", &outputs);
    v->Visit("stages", &stages);
    v->Visit("groups", &groups);
    v->Visit("stage_map", &stage_map);
  }

  /*! \brief Initialize temp cache. */
  void InitCache();
  /*! \brief Invalidate temp cache. */
  void InvalidateCache();

  /*!
   * \brief Check if the schedule contains an Operation.
   * \param op The candidate Operation.
   * \return true if the schedule has the Operation. Otherwise, false.
   */
  TVM_DLL bool Contain(const Operation& op) const;

  /*!
   * \brief Check if the schedule contains a Tensor.
   * \param tensor The candidate tensor.
   * \return true if the schedule has the tensor. Otherwise, false.
   */
  TVM_DLL bool Contain(const Tensor& tensor) const {
    return Contain(tensor->op);
  }

  /*!
   * \brief Create a schedule for an array of ops (and their dependencies).
   * \param ops The ops to be scheduled.
   * \return sch The created Schedule.
   */
  TVM_DLL static Schedule make(Array<Operation> ops);

  static constexpr const char* _type_key = "Schedule";
  TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
};

/*!
 * \brief Create a schedule for an array of ops (and their dependencies).
 * \param ops The ops to be scheduled.
 * \return sch The created Schedule.
 */
inline Schedule create_schedule(Array<Operation> ops) {
  return ScheduleNode::make(ops);
}
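
// A minimal end-to-end sketch. The helpers placeholder/compute/ComputeOpNode
// are assumed from other TVM headers (e.g. tvm/operation.h) and the names are
// illustrative only:
//
//   Var n("n");
//   Tensor A = placeholder({n}, Float(32), "A");
//   Tensor B = compute(A->shape, [&](Var i) { return A(i) + 1.0f; }, "B");
//   Schedule s = create_schedule({B->op});
//   IterVar i = B->op.as<ComputeOpNode>()->axis[0];
//   IterVar xo, xi;
//   s[B].split(i, 32, &xo, &xi);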

/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Node {
 public:
  /*! \brief The iteration type. */
  IterVarType iter_type{kDataPar};
  /*! \brief The thread this IterVar binds to, can be null */
  IterVar bind_thread;
  /*! \brief List of tensors to be prefetched in this loop */
  Array<Tensor> prefetch_data;
  /*! \brief The offset used in each prefetch */
  Array<Expr> prefetch_offset;
  /*!
   * \brief Tensor intrinsic used in tensorization,
   *   when the axis is marked as Tensorized
   */
  TensorIntrin tensor_intrin;
  /*! \brief Alignment factor of the buffer dimension */
  int dim_align_factor{0};
  /*! \brief Alignment offset of the buffer dimension */
  int dim_align_offset{0};
  /*!
   * \brief Additional pragma keys, array of StringImm
   */
  Array<Expr> pragma_keys;
  /*!
   * \brief Additional values of pragma, if any
   */
  Array<Expr> pragma_values;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("iter_type", &iter_type);
    v->Visit("bind_thread", &bind_thread);
    v->Visit("prefetch_data", &prefetch_data);
    v->Visit("prefetch_offset", &prefetch_offset);
    v->Visit("tensor_intrin", &tensor_intrin);
    v->Visit("dim_align_factor", &dim_align_factor);
    v->Visit("dim_align_offset", &dim_align_offset);
    v->Visit("pragma_keys", &pragma_keys);
    v->Visit("pragma_values", &pragma_values);
  }

  static constexpr const char* _type_key = "IterVarAttr";
  TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode, Node);
};

/*! \brief base node of iteration var relations */
class IterVarRelationNode : public Node {
 public:
  static constexpr const char* _type_key = "IterVarRelation";
  TVM_DECLARE_BASE_NODE_INFO(IterVarRelationNode, Node);
};

/*!
 * \brief Split the parent domain into the product of
 *  outer and inner.
 */
class SplitNode : public IterVarRelationNode {
 public:
  /*! \brief The parent domain */
  IterVar parent;
  /*! \brief The outer domain */
  IterVar outer;
  /*! \brief The inner domain */
  IterVar inner;
  /*! \brief The split factor */
  Expr factor;
  /*! \brief Number of parts, only factor or nparts can be given */
  Expr nparts;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("parent", &parent);
    v->Visit("outer", &outer);
    v->Visit("inner", &inner);
    v->Visit("factor", &factor);
    v->Visit("nparts", &nparts);
  }

  static IterVarRelation make(IterVar parent,
                              IterVar outer,
                              IterVar inner,
                              Expr factor,
                              Expr nparts);

  static constexpr const char* _type_key = "Split";
  TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode);
};

/*!
 * \brief Fuse two domains into one domain.
 */
class FuseNode : public IterVarRelationNode {
 public:
  /*! \brief The outer domain */
  IterVar outer;
  /*! \brief The inner domain */
  IterVar inner;
  /*! \brief The target domain */
  IterVar fused;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("outer", &outer);
    v->Visit("inner", &inner);
    v->Visit("fused", &fused);
  }

  static IterVarRelation make(
      IterVar outer, IterVar inner, IterVar fused);

  static constexpr const char* _type_key = "Fuse";
  TVM_DECLARE_NODE_TYPE_INFO(FuseNode, IterVarRelationNode);
};

/*!
 * \brief Rebase the iteration to make min to be 0.
 *  This is useful to normalize the Schedule
 *  to make every leaf variable's min to be 0.
 */
class RebaseNode : public IterVarRelationNode {
 public:
  /*! \brief The parent domain */
  IterVar parent;
  /*! \brief The rebased domain */
  IterVar rebased;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("parent", &parent);
    v->Visit("rebased", &rebased);
  }

  static IterVarRelation make(IterVar parent, IterVar rebased);

  static constexpr const char* _type_key = "Rebase";
  TVM_DECLARE_NODE_TYPE_INFO(RebaseNode, IterVarRelationNode);
};


/*!
 * \brief Singleton iterator [0, 1)
 */
class SingletonNode : public IterVarRelationNode {
 public:
  /*! \brief The singleton iterator */
  IterVar iter;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("iter", &iter);
  }

  static IterVarRelation make(IterVar iter);

  static constexpr const char* _type_key = "Singleton";
  TVM_DECLARE_NODE_TYPE_INFO(SingletonNode, IterVarRelationNode);
};


// implementations
inline const StageNode* Stage::operator->() const {
  return static_cast<const StageNode*>(get());
}
inline StageNode* Stage::operator->() {
  return static_cast<StageNode*>(get_mutable());
}

inline const ScheduleNode* Schedule::operator->() const {
  return static_cast<const ScheduleNode*>(get());
}
inline ScheduleNode* Schedule::operator->() {
  return static_cast<ScheduleNode*>(get_mutable());
}

inline const IterVarRelationNode* IterVarRelation::operator->() const {
  return static_cast<const IterVarRelationNode*>(get());
}

inline const IterVarAttrNode* IterVarAttr::operator->() const {
  return static_cast<const IterVarAttrNode*>(get());
}
}  // namespace tvm
#endif  // TVM_SCHEDULE_H_