1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tvm/schedule.h
22 * \brief Define a schedule.
23 */
24 // Acknowledgement: Many schedule primitives originate from Halide and Loopy.
25 #ifndef TVM_SCHEDULE_H_
26 #define TVM_SCHEDULE_H_
27
28 #include <string>
29 #include <unordered_map>
30 #include "base.h"
31 #include "expr.h"
32 #include "tensor.h"
33 #include "tensor_intrin.h"
34
35 namespace tvm {
36
// Forward declarations of the node containers that back the reference
// handles declared below; their full definitions appear later in this header.
class StageNode;            // container behind Stage
class ScheduleNode;         // container behind Schedule
class IterVarRelationNode;  // container behind IterVarRelation
class IterVarAttrNode;      // container behind IterVarAttr (extra IterVar attributes)
45
/*!
 * \brief The attachment type of a stage, i.e. how the stage is placed
 *  relative to the rest of the schedule. Values are spelled out explicitly.
 */
enum AttachType : int {
  /*! \brief Attached at its group's root; this is StageNode's default. */
  kGroupRoot = 1,
  /*! \brief Marked for inlining (presumably by Stage::compute_inline). */
  kInline = 2,
  /*! \brief NOTE(review): presumably "inlining already performed" — confirm. */
  kInlinedAlready = 3,
  /*! \brief Attached at an IterVar scope (presumably by Stage::compute_at). */
  kScope = 4,
  /*! \brief NOTE(review): presumably the update part of a scan — confirm. */
  kScanUpdate = 5
};
54
55 /*! \brief Stage, contains scheduling for a stage of computation. */
56 class Stage : public NodeRef {
57 public:
Stage()58 Stage() {}
Stage(ObjectPtr<Object> n)59 explicit Stage(ObjectPtr<Object> n) : NodeRef(n) {}
60 /*!
61 * \brief create a new schedule for op.
62 * \param op The operator in the schedule
63 */
64 explicit Stage(Operation op);
65 /*!
66 * \brief access the internal node container
67 * \return the pointer to the internal node container
68 */
69 inline const StageNode* operator->() const;
70 /*!
71 * \brief access the internal node container
72 * \return the pointer to the internal node container
73 */
74 inline StageNode* operator->();
75 /*!
76 * \brief set the memory scope of the stage
77 * \param scope The memory scope.
78 */
79 TVM_DLL Stage& set_scope(std::string scope); // NOLINT(*)
80 /*!
81 * \brief specify the schedule to be computed at the parent schedule's scope.
82 * \param parent The parent schedule.
83 * \param scope The iteration point to carry the schedule.
84 * \return reference to self.
85 */
86 TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
87 /*!
88 * \brief Compute the function inline.
89 * \return reference to self.
90 */
91 TVM_DLL Stage& compute_inline(); // NOLINT(*)
92 /*!
93 * \brief Compute the function at group root.
94 * \return reference to self.
95 */
96 TVM_DLL Stage& compute_root(); // NOLINT(*)
97 /*!
98 * \brief Bind the IterVar to thread index.
99 *
100 * \param ivar The IterVar to be bound.
101 * \param thread_ivar The thread axis to be bound.
102 * \return reference to self.
103 */
104 TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
105 /*!
106 * \brief Set the predicate to determine whether a store to the array should be performed.
107 * Use this when there are multiple threads performing the same store and we only
108 * need one of them to do the store.
109 *
110 * \note This is a dangerous scheduling primitive that can change behavior of program.
111 * Only do when we are certain that thare are duplicated stores.
112 * \param predicate The condition to be checked.
113 * \return reference to self.
114 */
115 TVM_DLL Stage& set_store_predicate(Expr predicate);
116 /*!
117 * \brief Specify environment threads that launched around the group's scope.
118 * This can only be used in group stage.
119 * \param threads The threads to be launched around the scope.
120 * \note Each thread can only appear in one env_threads.
121 * This is a beta feature.
122 * \return reference to self.
123 */
124 TVM_DLL Stage& env_threads(Array<IterVar> threads);
125 /*!
126 * \brief Split the parent by factor, generate
127 * \param parent The parent iteration domain.
128 * \param factor The split factor of the loop.
129 * \param p_outer The result outer domain
130 * \param p_inner The result inner domain.
131 * \return reference to self.
132 */
133 TVM_DLL Stage& split(IterVar parent, Expr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
134 /*!
135 * \brief Split the iteration with given number of parts.
136 *
137 * \param parent The parent domain.
138 * \param nparts The number of parts in the outer domain.
139 * \param p_outer The result outer domain.
140 * \param p_inner The result inner domain.
141 * \return reference to self.
142 */
143 TVM_DLL Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
144 /*!
145 * \brief Fuse the inner outer domain to the target
146 * \param outer The outer domain to be fused.
147 * \param inner The inner domain to be fused
148 * \param p_target The result target domain.
149 * \return reference to self.
150 */
151 TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
152 /*!
153 * \brief Fuse all the axes together into a single axis.
154 *
155 * \param axes All the axes to be fused.
156 * \param p_target The result target domain.
157 *
158 * \note axes can be an empty array,
159 * in that case, a singleton IterVar is created and
160 * inserted to the outermost loop.
161 * The fuse of empty array is used to support zero-dimension tensors.
162 *
163 * \return reference to self.
164 */
165 TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
166 /*!
167 * \brief Reorder the iteration
168 * \param order The order of iteration variable.
169 * \return reference to self.
170 */
171 TVM_DLL Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
172 /*!
173 * \brief Perform tiling on two dimensions
174 * The final loop order from outmost to inner most are
175 * [x_outer, y_outer, x_inner, y_inner]
176 *
177 * \param x_parent The original x dimension
178 * \param y_parent The original y dimension
179 * \param x_factor The stride factor on x axis
180 * \param y_factor The stride factor on y axis
181 * \param p_x_outer Outer axis of x dimension
182 * \param p_y_outer Outer axis of y dimension
183 * \param p_x_inner Inner axis of x dimension
184 * \param p_y_inner Inner axis of y dimension
185 * \return reference to self.
186 */
187 TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
188 Expr x_factor, Expr y_factor,
189 IterVar* p_x_outer, IterVar* p_y_outer,
190 IterVar* p_x_inner, IterVar* p_y_inner);
191 /*!
192 * \brief Vectorize iteration.
193 * \param var The axis to be vectorized.
194 * \return reference to self.
195 */
196 TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*)
197 /*!
198 * \brief Replace computation of the current stage by tensor intrinsic f.
199 * \param var The axis marks beginning of tensorization.
200 * Every operations inside the axis(include axis itself is tensorized).
201 * \param f The Tensor compute intrinsics.
202 * \return reference to self.
203 */
204 TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
205 /*!
206 * \brief Unroll iteration.
207 * \param var The axis to be unrolled.
208 * \return reference to self.
209 */
210 TVM_DLL Stage& unroll(IterVar var); // NOLINT(*)
211 /*!
212 * \brief Parallelize iteration.
213 * \param var The axis to be parallelized.
214 * \return reference to self.
215 */
216 TVM_DLL Stage& parallel(IterVar var); // NOLINT(*)
217 /*!
218 * \brief Annotate the iteration with pragma
219 *
220 * \param var The axis to be parallelized.
221 * \param pragma_type The pragma type.
222 * \param pragma_value The pragma value
223 *
224 * \return reference to self.
225 */
226 TVM_DLL Stage& pragma(IterVar var,
227 const std::string& pragma_type,
228 const Expr& pragma_value = Expr()); // NOLINT(*)
229 /*!
230 * \brief Fetch data in advance.
231 * \param domain the tensor to be prefetched
232 * \param var the iteration point at which to apply prefetching
233 * \param offset the number of iterations be to fetched in advance
234 * \return reference to self
235 */
236 TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
237 /*!
238 * \brief Set alignment requirement for specific dimension.
239 *
240 * Such that stride[axis] == k * factor + offset for some k.
241 *
242 * \param axis The dimension to be specified for alignment.
243 * \param factor The factor multiple of alignment
244 * \param offset The required offset factor.
245 * \return reference to self
246 */
247 TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
248 /*!
249 * \brief Compute current stage with double buffering.
250 * \return reference to self.
251 */
252 TVM_DLL Stage& double_buffer(); // NOLINT(*)
253 /*!
254 * \brief Schedule for OpenGL fragment shader.
255 * \return reference to self.
256 */
257 Stage& opengl(); // NOLINT(*)
258 /*!
259 * \brief whether the stage has been scheduled.
260 * \return whether the stage has been scheduled.
261 */
262 bool is_scheduled() const;
263 /*!
264 * \brief Get attachment spec of current stage.
265 * If the stage compute at Group root, this function
266 * will traverse the group function to get the
267 * final spec from the group.
268 * \return A stage representing the attach spec of the group.
269 */
270 Stage GetAttachSpec() const;
271 // declare container type
272 using ContainerType = StageNode;
273 };
274
275 /*!
276 * \brief Global schedule container
277 * For operations and all the operations they depend on.
278 * The schedule per Operation is named as stage.
279 */
280 class Schedule : public NodeRef {
281 public:
Schedule()282 Schedule() {}
Schedule(ObjectPtr<Object> n)283 explicit Schedule(ObjectPtr<Object> n) : NodeRef(n) {}
284 /*!
285 * \brief Get a copy of current schedule.
286 * \return The copied schedule.
287 */
288 Schedule copy() const;
289 /*!
290 * \brief Get the stage corresponds to the op
291 * \param op The operation.
292 */
293 TVM_DLL Stage operator[](const Operation& op);
294 /*!
295 * \brief Short hand for getting the stage of tensor's operation.
296 * \param tensor The tensor
297 * \return The stage corresponding to the tensor's op
298 */
299 TVM_DLL Stage operator[](const Tensor& tensor) {
300 return this->operator[](tensor->op);
301 }
302 /*!
303 * \brief Create a new stage group for all intermediate
304 * operations between inputs and outputs.
305 *
306 * \param outputs The output boundary of the group.
307 * \param inputs The input boundary of the group.
308 * \param include_inputs Whether include inputs if they are reachable from outputs.
309 * \return The new grouped stage.
310 */
311 TVM_DLL Stage create_group(const Array<Tensor>& outputs,
312 const Array<Tensor>& inputs,
313 bool include_inputs = false);
314 /*!
315 * \brief create a cache read of original tensor for readers.
316 * This will mutate the body of the readers.
317 * A new stage will be created for the tensor.
318 * \param tensor The tensor cached.
319 * \param scope The scope of the cache.
320 * \param readers The readers to redirect to the tensor.
321 * \return The created tensor.
322 */
323 TVM_DLL Tensor cache_read(const Tensor& tensor,
324 const std::string& scope,
325 const Array<Operation>& readers);
326 /*!
327 * \brief Create a cache write tensor for producing tensor.
328 * The the tensor will take over body of original tensor op.
329 *
330 * This function can be used to do data layout transformation.
331 * If there is a split/fuse/reorder on the data parallel axis of tensor
332 * before cache_write is called. The intermediate cache stores
333 * the data in the layout as the iteration order of leave axis.
334 * The data will be transformed back to the original layout in the original tensor.
335 * User can further call compute_inline to inline the original layout and keep
336 * the data stored in the transformed layout.
337 *
338 * \param tensor The tensors to be produced.
339 * \param scope The scope of the storage.
340 * \return The created tensor.
341 */
342 TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
343 /*!
344 * \brief Create a cache write tensor for producing tensor.
345 * The the tensor will take over body of original tensor op.
346 *
347 * This function can be used to do data layout transformation.
348 * If there is a split/fuse/reorder on the data parallel axis of tensor
349 * before cache_write is called. The intermediate cache stores
350 * the data in the layout as the iteration order of leave axis.
351 * The data will be transformed back to the original layout in the original tensor.
352 * User can further call compute_inline to inline the original layout and keep
353 * the data stored in the transformed layout.
354 *
355 * \param tensor The tensor to be produced.
356 * \param scope The scope of the storage.
357 * \return The created tensor.
358 */
359 TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
360 /*!
361 * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
362 * This will create a new stage that generated the new tensor with axis
363 * as the first dimension. The tensor's body will be rewritten as a reduction
364 * over the factored tensor.
365 *
366 * P. Suriana, A. Adams and S. Kamil. Parallel associative reductions in halide. CGO'17
367 *
368 * \param tensor The tensor to be factored.
369 * \param axis The reduction axis in tensor's schedule to be factored.
370 * \param factor_axis The position where the new axis is placed.
371 * \return The created factored tensors.
372 */
373 TVM_DLL Array<Tensor> rfactor(const Tensor& tensor,
374 const IterVar& axis,
375 int factor_axis = 0);
376 /*!
377 * \brief Normalize the schedule.
378 * This is needed before bound inference.
379 * Insert necessary RebaseNode to make sure all leaf_iter_vars
380 * are in form [0, extent)
381 *
382 * \return A normalized schedule, can be same as current one.
383 */
384 Schedule normalize();
385 /*!
386 * \brief access the internal node container
387 * \return the pointer to the internal node container
388 */
389 inline const ScheduleNode* operator->() const;
390 /*!
391 * \brief access the internal node container
392 * \return the pointer to the internal node container
393 */
394 inline ScheduleNode* operator->();
395 // declare container type
396 using ContainerType = ScheduleNode;
397 };
398
399 /*!
400 * \brief The schedule relation between IterVars
401 * can be Split, Fuse.
402 */
403 class IterVarRelation : public NodeRef {
404 public:
IterVarRelation()405 IterVarRelation() {}
IterVarRelation(ObjectPtr<Object> n)406 explicit IterVarRelation(ObjectPtr<Object> n) : NodeRef(n) {}
407 /*!
408 * \brief access the internal node container
409 * \return the pointer to the internal node container
410 */
411 inline const IterVarRelationNode* operator->() const;
412 };
413
414 /*!
415 * \brief Additional scheduable attributes about IterVar.
416 */
417 class IterVarAttr : public NodeRef {
418 public:
IterVarAttr()419 IterVarAttr() {}
IterVarAttr(ObjectPtr<Object> n)420 explicit IterVarAttr(ObjectPtr<Object> n) : NodeRef(n) {}
421 /*!
422 * \brief access the internal node container
423 * \return the pointer to the internal node container
424 */
425 inline const IterVarAttrNode* operator->() const;
426 };
427
428 /*!
429 * \brief represents a stage.
430 *
431 * relations form a Directed acylic hypergraph in bipartite manner.
432 * With each node is represented by a IterVar,
433 * and each hyper-edge is represented by a IterVarRelation.
434 * The relations connects the IterVars in the graph.
435 *
436 * Besides typical stage that corresponds to operations.
437 * There is also group stage, which groups stages together.
438 * Each stage's group(given by group) represent an constraint,
439 * the stage can only be attached to stages within the group.
440 *
441 * The group stage node can be attached to IterVars as in normal stage.
442 */
443 class StageNode : public Node {
444 public:
445 /*!
446 * \brief The operation of stage, can be different from original op.
447 * If it is null, then this stage is a group stage.
448 */
449 Operation op;
450 /*!
451 * \brief The original operator.
452 * The op field can change during schedule to alternate the dataflow,
453 * while origin_op remains fixed.
454 */
455 Operation origin_op;
456 /*! \brief All the nodes in the iter var */
457 Array<IterVar> all_iter_vars;
458 /*! \brief The current active leaf iter vars in the stage. */
459 Array<IterVar> leaf_iter_vars;
460 /*!
461 * \brief Specify threads to be launched at the stage.
462 * This is only valid for composite ops such as Scan.
463 * \note Experimental primitive: used for thread persistence.
464 */
465 Array<IterVar> env_threads;
466 /*!
467 * \brief The predicate under which store can happen
468 * Use this when there can be duplicated threads doing the same store.
469 * \note Experimental primitive: used by cross thread-reduction.
470 */
471 Expr store_predicate;
472 /*! \brief The relation bwteen of IterVars */
473 Array<IterVarRelation> relations;
474 /*! \brief additional attributes about iter var. */
475 Map<IterVar, IterVarAttr> iter_var_attrs;
476 /*! \brief The attachment type of the schedule */
477 AttachType attach_type{kGroupRoot};
478 /*! \brief The attach point of this schedule. */
479 IterVar attach_ivar;
480 /*! \brief The stage this node attaches to */
481 Stage attach_stage;
482 /*! \brief The thread storage scope level of the stage */
483 std::string scope;
484 /*! \brief Whether this is an output stage */
485 bool is_output{false};
486 /*! \brief Whether this is an OpenGL stage */
487 bool is_opengl{false};
488 /*! \brief Whether apply double buffer optimization to this stage */
489 bool double_buffer{false};
490 /*!
491 * \brief The parent group of the current stage.
492 * The stage cannot be assigned to stages outside the group.
493 */
494 Stage group;
495 /*! \brief Number of direct child stages, only used for group stage.*/
496 int num_child_stages{0};
497
VisitAttrs(AttrVisitor * v)498 void VisitAttrs(AttrVisitor* v) {
499 v->Visit("op", &op);
500 v->Visit("origin_op", &origin_op);
501 v->Visit("all_iter_vars", &all_iter_vars);
502 v->Visit("leaf_iter_vars", &leaf_iter_vars);
503 v->Visit("env_threads", &env_threads);
504 v->Visit("relations", &relations);
505 v->Visit("iter_var_attrs", &iter_var_attrs);
506 v->Visit("attach_type", &attach_type);
507 v->Visit("attach_ivar", &attach_ivar);
508 v->Visit("attach_stage", &attach_stage);
509 v->Visit("scope", &scope);
510 v->Visit("is_output", &is_output);
511 v->Visit("is_opengl", &is_opengl);
512 v->Visit("double_buffer", &double_buffer);
513 v->Visit("group", &group);
514 v->Visit("num_child_stages", &num_child_stages);
515 }
516
517 static constexpr const char* _type_key = "Stage";
518 TVM_DECLARE_NODE_TYPE_INFO(StageNode, Node);
519 };
520
521 /*! \brief node container for schedule */
522 class ScheduleNode : public Node {
523 public:
524 /*! \brief The output operations in original data flow graph */
525 Array<Operation> outputs;
526 /*!
527 * \brief list of all stages for ops.
528 * The stages are sorted in dependency order.
529 */
530 Array<Stage> stages;
531 /*!
532 * \brief List of all stage groups.
533 */
534 Array<Stage> groups;
535 /*! \brief map of original operation to the stages */
536 Map<Operation, Stage> stage_map;
537 /*!
538 * \brief Internal stage map to map internal ops to stages.
539 * This is created on demand and can be invalidated.
540 */
541 std::unordered_map<const Node*, Stage> op2stage_cache_;
542
VisitAttrs(AttrVisitor * v)543 void VisitAttrs(AttrVisitor* v) {
544 v->Visit("outputs", &outputs);
545 v->Visit("stages", &stages);
546 v->Visit("groups", &groups);
547 v->Visit("stage_map", &stage_map);
548 }
549
550 /*! \brief Initialize temp cache. */
551 void InitCache();
552 /*! \brief Invalidate temp cache. */
553 void InvalidateCache();
554
555 /*!
556 * \brief Check if the schedule contains an Operation.
557 * \param op The candidate Operation.
558 * \return true if the schedule has the Operation. Otherwise, false.
559 */
560 TVM_DLL bool Contain(const Operation& op) const;
561
562 /*!
563 * \brief Check if the schedule contains a Tensor.
564 * \param tensor The candidate tensor.
565 * \return true if the schedule has the tensor. Otherwise, false.
566 */
Contain(const Tensor & tensor)567 TVM_DLL bool Contain(const Tensor& tensor) const {
568 return Contain(tensor->op);
569 }
570
571 /*!
572 * \brief Create a schedule for array of ops(and their dependencies).
573 * \param ops The ops to be scheduled.
574 * \return sch The created Schedule.
575 */
576 TVM_DLL static Schedule make(Array<Operation> ops);
577
578 static constexpr const char* _type_key = "Schedule";
579 TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
580 };
581
582 /*!
583 * \brief Create a schedule for array of ops(and their dependencies).
584 * \param ops The ops to be scheduled.
585 * \return sch The created Schedule.
586 */
create_schedule(Array<Operation> ops)587 inline Schedule create_schedule(Array<Operation> ops) {
588 return ScheduleNode::make(ops);
589 }
590
591 /*! \brief node container for IterVar attr */
592 class IterVarAttrNode : public Node {
593 public:
594 /*! \brief The iteration type. */
595 IterVarType iter_type{kDataPar};
596 /*! \brief The thread this iter Var binds, can be null */
597 IterVar bind_thread;
598 /*! \brief List of tensor to be prefetched in this loop */
599 Array<Tensor> prefetch_data;
600 /*! \brief The offset used in each prefetch */
601 Array<Expr> prefetch_offset;
602 /*!
603 * \brief Tensor intrinsic used in tensorization,
604 * when the axis is marked as Tensorized
605 */
606 TensorIntrin tensor_intrin;
607 /*! \brief Alignment factor of buffer dimension */
608 int dim_align_factor{0};
609 /*! \brief Alignment offset of buffer dimension */
610 int dim_align_offset{0};
611 /*!
612 * \brief Additional pragma keys, array of StringImm
613 */
614 Array<Expr> pragma_keys;
615 /*!
616 * \brief Additional values of pragma, if any
617 */
618 Array<Expr> pragma_values;
619
VisitAttrs(AttrVisitor * v)620 void VisitAttrs(AttrVisitor* v) {
621 v->Visit("iter_type", &iter_type);
622 v->Visit("bind_thread", &bind_thread);
623 v->Visit("prefetch_data", &prefetch_data);
624 v->Visit("prefetch_offset", &prefetch_offset);
625 v->Visit("tensor_intrin", &tensor_intrin);
626 v->Visit("dim_align_factor", &dim_align_factor);
627 v->Visit("dim_align_offset", &dim_align_offset);
628 v->Visit("pragma_keys", &pragma_keys);
629 v->Visit("pragma_values", &pragma_values);
630 }
631
632 static constexpr const char* _type_key = "IterVarAttr";
633 TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode, Node);
634 };
635
636 /*! \brief base node of iteration var */
637 class IterVarRelationNode : public Node {
638 public:
639 static constexpr const char* _type_key = "IterVarRelation";
640 TVM_DECLARE_BASE_NODE_INFO(IterVarRelationNode, Node);
641 };
642
643 /*!
644 * \brief Split the parent domain into product of
645 * outer and iter.
646 */
647 class SplitNode : public IterVarRelationNode {
648 public:
649 /*! \brief The parent domain */
650 IterVar parent;
651 /*! \brief The outer domain */
652 IterVar outer;
653 /*! \brief The inner domain */
654 IterVar inner;
655 /*! \brief The split factor */
656 Expr factor;
657 /*! \brief Number of parts, only factor or nparts can be given */
658 Expr nparts;
659
VisitAttrs(AttrVisitor * v)660 void VisitAttrs(AttrVisitor* v) {
661 v->Visit("parent", &parent);
662 v->Visit("outer", &outer);
663 v->Visit("inner", &inner);
664 v->Visit("factor", &factor);
665 v->Visit("nparts", &nparts);
666 }
667
668 static IterVarRelation make(IterVar parent,
669 IterVar outer,
670 IterVar inner,
671 Expr factor,
672 Expr nparts);
673
674 static constexpr const char* _type_key = "Split";
675 TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode);
676 };
677
678 /*!
679 * \brief Fuse two domains into one domain.
680 */
681 class FuseNode : public IterVarRelationNode {
682 public:
683 /*! \brief The outer domain */
684 IterVar outer;
685 /*! \brief The inner domain */
686 IterVar inner;
687 /*! \brief The target domain */
688 IterVar fused;
689
VisitAttrs(AttrVisitor * v)690 void VisitAttrs(AttrVisitor* v) {
691 v->Visit("outer", &outer);
692 v->Visit("inner", &inner);
693 v->Visit("fused", &fused);
694 }
695
696 static IterVarRelation make(
697 IterVar outer, IterVar inner, IterVar fused);
698
699 static constexpr const char* _type_key = "Fuse";
700 TVM_DECLARE_NODE_TYPE_INFO(FuseNode, IterVarRelationNode);
701 };
702
703 /*!
704 * \brief Rebase the iteration to make min to be 0.
705 * This is useful to normalize the Schedule
706 * to make every leaf variable's min to be 0.
707 */
708 class RebaseNode : public IterVarRelationNode {
709 public:
710 /*! \brief The parent domain */
711 IterVar parent;
712 /*! \brief The inner domain */
713 IterVar rebased;
714
VisitAttrs(AttrVisitor * v)715 void VisitAttrs(AttrVisitor* v) {
716 v->Visit("parent", &parent);
717 v->Visit("rebased", &rebased);
718 }
719
720 static IterVarRelation make(IterVar parent, IterVar rebased);
721
722 static constexpr const char* _type_key = "Rebase";
723 TVM_DECLARE_NODE_TYPE_INFO(RebaseNode, IterVarRelationNode);
724 };
725
726
727 /*!
728 * \brief Singleton iterator [0, 1)
729 */
730 class SingletonNode : public IterVarRelationNode {
731 public:
732 /*! \brief The singleton iterator */
733 IterVar iter;
734
VisitAttrs(AttrVisitor * v)735 void VisitAttrs(AttrVisitor* v) {
736 v->Visit("iter", &iter);
737 }
738
739 static IterVarRelation make(IterVar iter);
740
741 static constexpr const char* _type_key = "Singleton";
742 TVM_DECLARE_NODE_TYPE_INFO(SingletonNode, IterVarRelationNode);
743 };
744
745
746 // implementations
747 inline const StageNode* Stage::operator->() const {
748 return static_cast<const StageNode*>(get());
749 }
750 inline StageNode* Stage::operator->() {
751 return static_cast<StageNode*>(get_mutable());
752 }
753
754 inline const ScheduleNode* Schedule::operator->() const {
755 return static_cast<const ScheduleNode*>(get());
756 }
757 inline ScheduleNode* Schedule::operator->() {
758 return static_cast<ScheduleNode*>(get_mutable());
759 }
760
761 inline const IterVarRelationNode* IterVarRelation::operator->() const {
762 return static_cast<const IterVarRelationNode*>(get());
763 }
764
765 inline const IterVarAttrNode* IterVarAttr::operator->() const {
766 return static_cast<const IterVarAttrNode*>(get());
767 }
768 } // namespace tvm
769 #endif // TVM_SCHEDULE_H_
770