1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/te/schedule.h
22  * \brief Define a schedule.
23  */
24 // Acknowledgement: Many schedule primitives originate from Halide and Loopy.
25 #ifndef TVM_TE_SCHEDULE_H_
26 #define TVM_TE_SCHEDULE_H_
27 
28 #include <tvm/support/with.h>
29 #include <tvm/te/tensor.h>
30 #include <tvm/te/tensor_intrin.h>
31 #include <tvm/tir/expr.h>
32 
33 #include <string>
34 #include <unordered_map>
35 
36 namespace tvm {
37 namespace te {
38 // Node container for Stage
39 class StageNode;
40 // Node container for Schedule
41 class ScheduleNode;
42 // Node container for IterVarRelation
43 class IterVarRelationNode;
44 // Attribute of itervar.
45 class IterVarAttrNode;
46 
/*!
 * \brief The attachment type of a stage: where its computation is placed
 *  relative to other stages during lowering.
 */
enum AttachType : int {
  kGroupRoot = 1,       // attached at the root of its group (the default)
  kInline = 2,          // computation is inlined into its consumer(s)
  kInlinedAlready = 3,  // inlining has already been performed
  kScope = 4,           // attached under a specific IterVar scope (compute_at)
  kScanUpdate = 5       // attached as the update stage of a scan op
};
55 
/*!
 * \brief Stage, contains scheduling for a stage of computation.
 *
 *  A Stage is a reference handle to a StageNode; copying a Stage copies
 *  the handle, not the underlying node. All scheduling primitives return
 *  *this to allow chaining.
 */
class Stage : public ObjectRef {
 public:
  /*! \brief Construct an empty (null) stage handle. */
  Stage() {}
  /*!
   * \brief Wrap an existing node into a stage handle.
   * \param n The node container.
   */
  explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
  /*!
   * \brief create a new stage that schedules op.
   * \param op The operator in the schedule
   */
  explicit Stage(Operation op);
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const StageNode* operator->() const;
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline StageNode* operator->();
  /*!
   * \brief set the memory scope of the stage
   * \param scope The memory scope.
   */
  TVM_DLL Stage& set_scope(std::string scope);  // NOLINT(*)
  /*!
   * \brief specify the schedule to be computed at the parent schedule's scope.
   * \param parent The parent schedule.
   * \param scope The iteration point to carry the schedule.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_at(Stage parent, IterVar scope);  // NOLINT(*)
  /*!
   * \brief Compute the function inline.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_inline();  // NOLINT(*)
  /*!
   * \brief Compute the function at group root.
   * \return reference to self.
   */
  TVM_DLL Stage& compute_root();  // NOLINT(*)
  /*!
   * \brief Bind the IterVar to thread index.
   *
   * \param ivar The IterVar to be bound.
   * \param thread_ivar The thread axis to be bound.
   * \return reference to self.
   */
  TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
  /*!
   * \brief Set the predicate to determine whether a store to the array should be performed.
   *  Use this when there are multiple threads performing the same store and we only
   *  need one of them to do the store.
   *
   * \note This is a dangerous scheduling primitive that can change behavior of program.
   *    Only do when we are certain that there are duplicated stores.
   * \param predicate The condition to be checked.
   * \return reference to self.
   */
  TVM_DLL Stage& set_store_predicate(PrimExpr predicate);
  /*!
   * \brief Specify environment threads that launched around the group's scope.
   *  This can only be used in group stage.
   * \param threads The threads to be launched around the scope.
   * \note Each thread can only appear in one env_threads.
   *    This is a beta feature.
   * \return reference to self.
   */
  TVM_DLL Stage& env_threads(Array<IterVar> threads);
  /*!
   * \brief Split the parent by factor, generating an outer and an inner iteration.
   * \param parent The parent iteration domain.
   * \param factor The split factor of the loop.
   * \param p_outer The result outer domain
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
  TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer,
                       IterVar* p_inner);  // NOLINT(*)
  /*!
   * \brief Split the iteration with given number of parts.
   *
   * \param parent The parent domain.
   * \param nparts The number of parts in the outer domain.
   * \param p_outer The result outer domain.
   * \param p_inner The result inner domain.
   * \return reference to self.
   */
  TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
                                 IterVar* p_inner);  // NOLINT(*)
  /*!
   * \brief Fuse the inner outer domain to the target
   * \param outer The outer domain to be fused.
   * \param inner The inner domain to be fused
   * \param p_target The result target domain.
   * \return reference to self.
   */
  TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target);  // NOLINT(*)
  /*!
   * \brief Fuse all the axes together into a single axis.
   *
   * \param axes All the axes to be fused.
   * \param p_target The result target domain.
   *
   * \note axes can be an empty array,
   *       in that case, a singleton IterVar is created and
   *       inserted to the outermost loop.
   *       The fuse of empty array is used to support zero-dimension tensors.
   *
   * \return reference to self.
   */
  TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target);  // NOLINT(*)
  /*!
   * \brief Reorder the iteration
   * \param order The order of iteration variable.
   * \return reference to self.
   */
  TVM_DLL Stage& reorder(const Array<IterVar>& order);  // NOLINT(*)
  /*!
   * \brief Perform tiling on two dimensions
   *  The final loop order from outermost to innermost is
   *  [x_outer, y_outer, x_inner, y_inner]
   *
   * \param x_parent The original x dimension
   * \param y_parent The original y dimension
   * \param x_factor The stride factor on x axis
   * \param y_factor The stride factor on y axis
   * \param p_x_outer Outer axis of x dimension
   * \param p_y_outer Outer axis of y dimension
   * \param p_x_inner Inner axis of x dimension
   * \param p_y_inner Inner axis of y dimension
   * \return reference to self.
   */
  TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent,  // NOLINT(*)
                      PrimExpr x_factor, PrimExpr y_factor, IterVar* p_x_outer, IterVar* p_y_outer,
                      IterVar* p_x_inner, IterVar* p_y_inner);
  /*!
   * \brief Vectorize iteration.
   * \param var The axis to be vectorized.
   * \return reference to self.
   */
  TVM_DLL Stage& vectorize(IterVar var);  // NOLINT(*)
  /*!
   * \brief Replace computation of the current stage by tensor intrinsic f.
   * \param var The axis that marks the beginning of tensorization.
   *  Every operation inside the axis (including the axis itself) is tensorized.
   * \param f The Tensor compute intrinsics.
   * \return reference to self.
   */
  TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f);  // NOLINT(*)
  /*!
   * \brief Unroll iteration.
   * \param var The axis to be unrolled.
   * \return reference to self.
   */
  TVM_DLL Stage& unroll(IterVar var);  // NOLINT(*)
  /*!
   * \brief Parallelize iteration.
   * \param var The axis to be parallelized.
   * \return reference to self.
   */
  TVM_DLL Stage& parallel(IterVar var);  // NOLINT(*)
  /*!
   * \brief Annotate the iteration with pragma
   *
   * \param var The axis to be annotated.
   * \param pragma_type The pragma type.
   * \param pragma_value The pragma value
   *
   * \return reference to self.
   */
  TVM_DLL Stage& pragma(IterVar var, const std::string& pragma_type,
                        const PrimExpr& pragma_value = PrimExpr());  // NOLINT(*)
  /*!
   * \brief Fetch data in advance.
   * \param domain the tensor to be prefetched
   * \param var the iteration point at which to apply prefetching
   * \param offset the number of iterations to be fetched in advance
   * \return reference to self
   */
  TVM_DLL Stage& prefetch(const Tensor& domain, IterVar var, PrimExpr offset);  // NOLINT(*)
  /*!
   * \brief Set alignment requirement for specific dimension.
   *
   *  Such that stride[axis] == k * factor + offset for some k.
   *
   * \param axis The dimension to be specified for alignment.
   * \param factor The factor multiple of alignment
   * \param offset The required offset factor.
   * \return reference to self
   */
  TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset);  // NOLINT(*)
  /*!
   * \brief Compute current stage with double buffering.
   * \return reference to self.
   */
  TVM_DLL Stage& double_buffer();  // NOLINT(*)
  /*!
   * \brief whether the stage has been scheduled.
   * \return whether the stage has been scheduled.
   */
  bool is_scheduled() const;
  /*!
   * \brief Get attachment spec of current stage.
   *  If the stage computes at group root, this function
   *  will traverse the group function to get the
   *  final spec from the group.
   * \return A stage representing the attach spec of the group.
   */
  Stage GetAttachSpec() const;
  // declare container type
  using ContainerType = StageNode;
};
270 
271 /*!
272  * \brief Global schedule container
273  *  For operations and all the operations they depend on.
274  *  The schedule per Operation is named as stage.
275  */
276 class Schedule : public ObjectRef {
277  public:
Schedule()278   Schedule() {}
Schedule(ObjectPtr<Object> n)279   explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
280   /*!
281    * \brief Create a schedule for array of ops(and their dependencies).
282    * \param ops The ops to be scheduled.
283    * \return sch The created Schedule.
284    */
285   TVM_DLL explicit Schedule(Array<Operation> ops);
286   /*!
287    * \brief Get a copy of current schedule.
288    * \return The copied schedule.
289    */
290   Schedule copy() const;
291   /*!
292    * \brief Get the stage corresponds to the op
293    * \param op The operation.
294    */
295   TVM_DLL Stage operator[](const Operation& op);
296   /*!
297    * \brief Short hand for getting the stage of tensor's operation.
298    * \param tensor The tensor
299    * \return The stage corresponding to the tensor's op
300    */
301   TVM_DLL Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); }
302   /*!
303    * \brief Create a new stage group for all intermediate
304    *  operations between inputs and outputs.
305    *
306    * \param outputs The output boundary of the group.
307    * \param inputs The input boundary of the group.
308    * \param include_inputs Whether include inputs if they are reachable from outputs.
309    * \return The new grouped stage.
310    */
311   TVM_DLL Stage create_group(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
312                              bool include_inputs = false);
313   /*!
314    * \brief create a cache read of original tensor for readers.
315    *  This will mutate the body of the readers.
316    *  A new stage will be created for the tensor.
317    * \param tensor The tensor cached.
318    * \param scope The scope of the cache.
319    * \param readers The readers to redirect to the tensor.
320    * \return The created tensor.
321    */
322   TVM_DLL Tensor cache_read(const Tensor& tensor, const std::string& scope,
323                             const Array<Operation>& readers);
324   /*!
325    * \brief Create a cache write tensor for producing tensor.
326    *  The the tensor will take over body of original tensor op.
327    *
328    *  This function can be used to do data layout transformation.
329    *  If there is a split/fuse/reorder on the data parallel axis of tensor
330    *  before cache_write is called. The intermediate cache stores
331    *  the data in the layout as the iteration order of leave axis.
332    *  The data will be transformed back to the original layout in the original tensor.
333    *  User can further call compute_inline to inline the original layout and keep
334    *  the data stored in the transformed layout.
335    *
336    * \param tensor The tensors to be produced.
337    * \param scope The scope of the storage.
338    * \return The created tensor.
339    */
340   TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
341   /*!
342    * \brief Create a cache write tensor for producing tensor.
343    *  The the tensor will take over body of original tensor op.
344    *
345    *  This function can be used to do data layout transformation.
346    *  If there is a split/fuse/reorder on the data parallel axis of tensor
347    *  before cache_write is called. The intermediate cache stores
348    *  the data in the layout as the iteration order of leave axis.
349    *  The data will be transformed back to the original layout in the original tensor.
350    *  User can further call compute_inline to inline the original layout and keep
351    *  the data stored in the transformed layout.
352    *
353    * \param tensor The tensor to be produced.
354    * \param scope The scope of the storage.
355    * \return The created tensor.
356    */
357   TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
358   /*!
359    * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
360    * This will create a new stage that generated the new tensor with axis
361    * as the first dimension. The tensor's body will be rewritten as a reduction
362    * over the factored tensor.
363    *
364    *  P. Suriana, A. Adams and S. Kamil. Parallel associative reductions in halide. CGO'17
365    *
366    * \param tensor The tensor to be factored.
367    * \param axis The reduction axis in tensor's schedule to be factored.
368    * \param factor_axis The position where the new axis is placed.
369    * \return The created factored tensors.
370    */
371   TVM_DLL Array<Tensor> rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis = 0);
372   /*!
373    * \brief Normalize the schedule.
374    *  This is needed before bound inference.
375    *  Insert necessary RebaseNode to make sure all leaf_iter_vars
376    *  are in form [0, extent)
377    *
378    * \return A normalized schedule, can be same as current one.
379    */
380   Schedule normalize();
381   /*!
382    * \brief access the internal node container
383    * \return the pointer to the internal node container
384    */
385   inline const ScheduleNode* operator->() const;
386   /*!
387    * \brief access the internal node container
388    * \return the pointer to the internal node container
389    */
390   inline ScheduleNode* operator->();
391   // declare container type
392   using ContainerType = ScheduleNode;
393 };
394 
395 /*!
396  * \brief The schedule relation between IterVars
397  *  can be Split, Fuse.
398  */
399 class IterVarRelation : public ObjectRef {
400  public:
IterVarRelation()401   IterVarRelation() {}
IterVarRelation(ObjectPtr<Object> n)402   explicit IterVarRelation(ObjectPtr<Object> n) : ObjectRef(n) {}
403   /*!
404    * \brief access the internal node container
405    * \return the pointer to the internal node container
406    */
407   inline const IterVarRelationNode* operator->() const;
408 };
409 
410 /*!
411  * \brief Additional scheduable attributes about IterVar.
412  */
413 class IterVarAttr : public ObjectRef {
414  public:
IterVarAttr()415   IterVarAttr() {}
IterVarAttr(ObjectPtr<Object> n)416   explicit IterVarAttr(ObjectPtr<Object> n) : ObjectRef(n) {}
417   /*!
418    * \brief access the internal node container
419    * \return the pointer to the internal node container
420    */
421   inline const IterVarAttrNode* operator->() const;
422 };
423 
424 /*!
425  * \brief represents a stage.
426  *
427  *  relations form a Directed acylic hypergraph in bipartite manner.
428  *  With each node is represented by a IterVar,
429  *  and each hyper-edge is represented by a IterVarRelation.
430  *  The relations connects the IterVars in the graph.
431  *
432  *  Besides typical stage that corresponds to operations.
433  *  There is also group stage, which groups stages together.
434  *  Each stage's group(given by group) represent an constraint,
435  *  the stage can only be attached to stages within the group.
436  *
437  *  The group stage node can be attached to IterVars as in normal stage.
438  */
439 class StageNode : public Object {
440  public:
441   /*!
442    * \brief The operation of stage, can be different from original op.
443    *  If it is null, then this stage is a group stage.
444    */
445   Operation op;
446   /*!
447    * \brief The original operator.
448    *  The op field can change during schedule to alternate the dataflow,
449    *  while origin_op remains fixed.
450    */
451   Operation origin_op;
452   /*! \brief All the nodes in the iter var */
453   Array<IterVar> all_iter_vars;
454   /*! \brief The current active leaf iter vars in the stage. */
455   Array<IterVar> leaf_iter_vars;
456   /*!
457    * \brief Specify threads to be launched at the stage.
458    *  This is only valid for composite ops such as Scan.
459    * \note Experimental primitive: used for thread persistence.
460    */
461   Array<IterVar> env_threads;
462   /*!
463    * \brief The predicate under which store can happen
464    *  Use this when there can be duplicated threads doing the same store.
465    * \note Experimental primitive: used by cross thread-reduction.
466    */
467   PrimExpr store_predicate;
468   /*! \brief The relation bwteen of IterVars */
469   Array<IterVarRelation> relations;
470   /*! \brief additional attributes about iter var. */
471   Map<IterVar, IterVarAttr> iter_var_attrs;
472   /*! \brief The attachment type of the schedule */
473   AttachType attach_type{kGroupRoot};
474   /*! \brief The attach point of this schedule. */
475   IterVar attach_ivar;
476   /*! \brief The stage this node attaches to */
477   Stage attach_stage;
478   /*! \brief The thread storage scope level of the stage */
479   std::string scope;
480   /*! \brief Whether this is an output stage */
481   bool is_output{false};
482   /*! \brief Whether apply double buffer optimization to this stage */
483   bool double_buffer{false};
484   /*!
485    * \brief The parent group of the current stage.
486    *  The stage cannot be assigned to stages outside the group.
487    */
488   Stage group;
489   /*! \brief Number of direct child stages, only used for group stage.*/
490   int num_child_stages{0};
491 
VisitAttrs(AttrVisitor * v)492   void VisitAttrs(AttrVisitor* v) {
493     v->Visit("op", &op);
494     v->Visit("origin_op", &origin_op);
495     v->Visit("all_iter_vars", &all_iter_vars);
496     v->Visit("leaf_iter_vars", &leaf_iter_vars);
497     v->Visit("env_threads", &env_threads);
498     v->Visit("relations", &relations);
499     v->Visit("iter_var_attrs", &iter_var_attrs);
500     v->Visit("attach_type", &attach_type);
501     v->Visit("attach_ivar", &attach_ivar);
502     v->Visit("attach_stage", &attach_stage);
503     v->Visit("scope", &scope);
504     v->Visit("is_output", &is_output);
505     v->Visit("double_buffer", &double_buffer);
506     v->Visit("group", &group);
507     v->Visit("num_child_stages", &num_child_stages);
508   }
509 
510   static constexpr const char* _type_key = "Stage";
511   TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
512 };
513 
/*! \brief node container for schedule */
class ScheduleNode : public Object {
 public:
  /*! \brief The output operations in original data flow graph */
  Array<Operation> outputs;
  /*!
   * \brief list of all stages for ops.
   * The stages are sorted in dependency order.
   */
  Array<Stage> stages;
  /*!
   * \brief List of all stage groups.
   */
  Array<Stage> groups;
  /*! \brief map of original operation to the stages */
  Map<Operation, Stage> stage_map;
  /*!
   * \brief Internal stage map to map internal ops to stages.
   *  This is created on demand and can be invalidated.
   */
  std::unordered_map<const Object*, Stage> op2stage_cache_;

  // Visit reflectable fields. Note: op2stage_cache_ is a transient cache
  // (see InitCache/InvalidateCache) and is deliberately not visited.
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("outputs", &outputs);
    v->Visit("stages", &stages);
    v->Visit("groups", &groups);
    v->Visit("stage_map", &stage_map);
  }

  /*! \brief Initialize temp cache. */
  void InitCache();
  /*! \brief Invalidate temp cache. */
  void InvalidateCache();

  /*!
   * \brief Check if the schedule contains an Operation.
   * \param op The candidate Operation.
   * \return true if the schedule has the Operation. Otherwise, false.
   */
  TVM_DLL bool Contain(const Operation& op) const;

  /*!
   * \brief Check if the schedule contains a Tensor.
   * \param tensor The candidate tensor.
   * \return true if the schedule has the tensor. Otherwise, false.
   */
  TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); }

  static constexpr const char* _type_key = "Schedule";
  TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
};
565 
566 /*!
567  * \brief Create a schedule for array of ops(and their dependencies).
568  * \param ops The ops to be scheduled.
569  * \return sch The created Schedule.
570  */
create_schedule(Array<Operation> ops)571 inline Schedule create_schedule(Array<Operation> ops) { return Schedule(ops); }
572 
/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Object {
 public:
  /*! \brief The iteration type. */
  IterVarType iter_type{kDataPar};
  /*! \brief The thread this iter Var binds, can be null */
  IterVar bind_thread;
  /*! \brief List of tensor to be prefetched in this loop */
  Array<Tensor> prefetch_data;
  /*! \brief The offset used in each prefetch */
  Array<PrimExpr> prefetch_offset;
  /*!
   * \brief Tensor intrinsic used in tensorization,
   *   when the axis is marked as Tensorized
   */
  TensorIntrin tensor_intrin;
  /*! \brief Alignment factor of buffer dimension */
  int dim_align_factor{0};
  /*! \brief Alignment offset of buffer dimension */
  int dim_align_offset{0};
  /*!
   * \brief Additional pragma keys, array of StringImm
   */
  Array<PrimExpr> pragma_keys;
  /*!
   * \brief Additional values of pragma, if any
   */
  Array<PrimExpr> pragma_values;

  // Visit all reflectable fields (used for serialization/FFI access).
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("iter_type", &iter_type);
    v->Visit("bind_thread", &bind_thread);
    v->Visit("prefetch_data", &prefetch_data);
    v->Visit("prefetch_offset", &prefetch_offset);
    v->Visit("tensor_intrin", &tensor_intrin);
    v->Visit("dim_align_factor", &dim_align_factor);
    v->Visit("dim_align_offset", &dim_align_offset);
    v->Visit("pragma_keys", &pragma_keys);
    v->Visit("pragma_values", &pragma_values);
  }

  static constexpr const char* _type_key = "IterVarAttr";
  TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object);
};
617 
/*! \brief Base node of an iteration var relation (Split/Fuse/Rebase/Singleton). */
class IterVarRelationNode : public Object {
 public:
  static constexpr const char* _type_key = "IterVarRelation";
  TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object);
};
624 
625 /*!
626  * \brief Split the parent domain into product of
627  *  outer and iter.
628  */
629 class SplitNode : public IterVarRelationNode {
630  public:
631   /*! \brief The parent domain */
632   IterVar parent;
633   /*! \brief The outer domain */
634   IterVar outer;
635   /*! \brief The inner domain */
636   IterVar inner;
637   /*! \brief The split factor */
638   PrimExpr factor;
639   /*! \brief Number of parts, only factor or nparts can be given */
640   PrimExpr nparts;
641 
VisitAttrs(AttrVisitor * v)642   void VisitAttrs(AttrVisitor* v) {
643     v->Visit("parent", &parent);
644     v->Visit("outer", &outer);
645     v->Visit("inner", &inner);
646     v->Visit("factor", &factor);
647     v->Visit("nparts", &nparts);
648   }
649 
650   static constexpr const char* _type_key = "Split";
651   TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
652 };
653 
654 /*!
655  * \brief Managed reference to SplitNode
656  * \sa SplitNode
657  */
658 class Split : public IterVarRelation {
659  public:
660   TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts);
661 
662   TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode);
663 };
664 
665 /*!
666  * \brief Fuse two domains into one domain.
667  */
668 class FuseNode : public IterVarRelationNode {
669  public:
670   /*! \brief The outer domain */
671   IterVar outer;
672   /*! \brief The inner domain */
673   IterVar inner;
674   /*! \brief The target domain */
675   IterVar fused;
676 
VisitAttrs(AttrVisitor * v)677   void VisitAttrs(AttrVisitor* v) {
678     v->Visit("outer", &outer);
679     v->Visit("inner", &inner);
680     v->Visit("fused", &fused);
681   }
682 
683   static constexpr const char* _type_key = "Fuse";
684   TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
685 };
686 
687 /*!
688  * \brief Managed reference to FuseNode
689  * \sa FuseNode
690  */
691 class Fuse : public IterVarRelation {
692  public:
693   TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused);
694 
695   TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode);
696 };
697 
698 /*!
699  * \brief Rebase the iteration to make min to be 0.
700  *  This is useful to normalize the Schedule
701  *  to make every leaf variable's min to be 0.
702  */
703 class RebaseNode : public IterVarRelationNode {
704  public:
705   /*! \brief The parent domain */
706   IterVar parent;
707   /*! \brief The inner domain */
708   IterVar rebased;
709 
VisitAttrs(AttrVisitor * v)710   void VisitAttrs(AttrVisitor* v) {
711     v->Visit("parent", &parent);
712     v->Visit("rebased", &rebased);
713   }
714 
715   static constexpr const char* _type_key = "Rebase";
716   TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
717 };
718 
719 /*!
720  * \brief Managed reference to RebaseNode
721  * \sa RebaseNode
722  */
723 class Rebase : public IterVarRelation {
724  public:
725   TVM_DLL Rebase(IterVar parent, IterVar rebased);
726 
727   TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode);
728 };
729 
730 /*!
731  * \brief Singleton iterator [0, 1)
732  */
733 class SingletonNode : public IterVarRelationNode {
734  public:
735   /*! \brief The singleton iterator */
736   IterVar iter;
737 
VisitAttrs(AttrVisitor * v)738   void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); }
739 
740   static constexpr const char* _type_key = "Singleton";
741   TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
742 };
743 
744 /*!
745  * \brief Managed reference to SingletonNode
746  * \sa SingletonNode
747  */
748 class Singleton : public IterVarRelation {
749  public:
750   TVM_DLL explicit Singleton(IterVar iter);
751 
752   TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode);
753 };
754 
/*! \brief Container for specialization conditions. */
class SpecializedConditionNode : public Object {
 public:
  /*!
   * \brief List of conditions in conjunctive joint form (CNF).
   *   Each condition should be a simple expression, e.g., n > 16, m % 8 == 0, etc.,
   *   where n, m are tvm::Var that represents a dimension in the tensor shape.
   */
  Array<PrimExpr> clauses;

  // Visit all reflectable fields (used for serialization/FFI access).
  void VisitAttrs(AttrVisitor* v) { v->Visit("clauses", &clauses); }

  static constexpr const char* _type_key = "SpecializedCondition";
  TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object);
};
770 
771 /*!
772  * \brief Specialized condition to enable op specialization
773  */
774 class SpecializedCondition : public ObjectRef {
775  public:
776   /*!
777    * \brief construct from conditions
778    * \param conditions The clauses in the specialized condition.
779    */
780   TVM_DLL SpecializedCondition(Array<PrimExpr> conditions);  // NOLINT(*)
781 
782   /*!
783    * \brief Get the current specialized condition.
784    * \return the current specialized condition.
785    */
786   TVM_DLL static SpecializedCondition Current();
787 
788   TVM_DEFINE_OBJECT_REF_METHODS(SpecializedCondition, ObjectRef, SpecializedConditionNode);
789   class Internal;
790 
791  private:
792   // enable with syntax.
793   friend class Internal;
794   friend class With<SpecializedCondition>;
795   /*! \brief Push a new specialized condition onto the thread local stack. */
796   TVM_DLL void EnterWithScope();
797   /*! \brief Pop a specialized condition off the thread local context stack. */
798   TVM_DLL void ExitWithScope();
799 };
800 
801 // implementations
802 inline const StageNode* Stage::operator->() const { return static_cast<const StageNode*>(get()); }
803 inline StageNode* Stage::operator->() { return static_cast<StageNode*>(get_mutable()); }
804 
805 inline const ScheduleNode* Schedule::operator->() const {
806   return static_cast<const ScheduleNode*>(get());
807 }
808 inline ScheduleNode* Schedule::operator->() { return static_cast<ScheduleNode*>(get_mutable()); }
809 
810 inline const IterVarRelationNode* IterVarRelation::operator->() const {
811   return static_cast<const IterVarRelationNode*>(get());
812 }
813 
814 inline const IterVarAttrNode* IterVarAttr::operator->() const {
815   return static_cast<const IterVarAttrNode*>(get());
816 }
817 
818 }  // namespace te
819 }  // namespace tvm
820 #endif  // TVM_TE_SCHEDULE_H_
821