1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tvm/te/schedule.h
22 * \brief Define a schedule.
23 */
24 // Acknowledgement: Many schedule primitives originate from Halide and Loopy.
25 #ifndef TVM_TE_SCHEDULE_H_
26 #define TVM_TE_SCHEDULE_H_
27
28 #include <tvm/support/with.h>
29 #include <tvm/te/tensor.h>
30 #include <tvm/te/tensor_intrin.h>
31 #include <tvm/tir/expr.h>
32
33 #include <string>
34 #include <unordered_map>
35
36 namespace tvm {
37 namespace te {
38 // Node container for Stage
39 class StageNode;
40 // Node container for Schedule
41 class ScheduleNode;
42 // Node container for IterVarRelation
43 class IterVarRelationNode;
44 // Attribute of itervar.
45 class IterVarAttrNode;
46
47 /*! \brief the attachment type */
48 enum AttachType : int {
49 kGroupRoot = 1,
50 kInline = 2,
51 kInlinedAlready = 3,
52 kScope = 4,
53 kScanUpdate = 5
54 };
55
56 /*! \brief Stage, contains scheduling for a stage of computation. */
57 class Stage : public ObjectRef {
58 public:
Stage()59 Stage() {}
Stage(ObjectPtr<Object> n)60 explicit Stage(ObjectPtr<Object> n) : ObjectRef(n) {}
61 /*!
62 * \brief create a new schedule for op.
63 * \param op The operator in the schedule
64 */
65 explicit Stage(Operation op);
66 /*!
67 * \brief access the internal node container
68 * \return the pointer to the internal node container
69 */
70 inline const StageNode* operator->() const;
71 /*!
72 * \brief access the internal node container
73 * \return the pointer to the internal node container
74 */
75 inline StageNode* operator->();
76 /*!
77 * \brief set the memory scope of the stage
78 * \param scope The memory scope.
79 */
80 TVM_DLL Stage& set_scope(std::string scope); // NOLINT(*)
81 /*!
82 * \brief specify the schedule to be computed at the parent schedule's scope.
83 * \param parent The parent schedule.
84 * \param scope The iteration point to carry the schedule.
85 * \return reference to self.
86 */
87 TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
88 /*!
89 * \brief Compute the function inline.
90 * \return reference to self.
91 */
92 TVM_DLL Stage& compute_inline(); // NOLINT(*)
93 /*!
94 * \brief Compute the function at group root.
95 * \return reference to self.
96 */
97 TVM_DLL Stage& compute_root(); // NOLINT(*)
98 /*!
99 * \brief Bind the IterVar to thread index.
100 *
101 * \param ivar The IterVar to be bound.
102 * \param thread_ivar The thread axis to be bound.
103 * \return reference to self.
104 */
105 TVM_DLL Stage& bind(IterVar ivar, IterVar thread_ivar);
106 /*!
107 * \brief Set the predicate to determine whether a store to the array should be performed.
108 * Use this when there are multiple threads performing the same store and we only
109 * need one of them to do the store.
110 *
111 * \note This is a dangerous scheduling primitive that can change behavior of program.
112 * Only do when we are certain that thare are duplicated stores.
113 * \param predicate The condition to be checked.
114 * \return reference to self.
115 */
116 TVM_DLL Stage& set_store_predicate(PrimExpr predicate);
117 /*!
118 * \brief Specify environment threads that launched around the group's scope.
119 * This can only be used in group stage.
120 * \param threads The threads to be launched around the scope.
121 * \note Each thread can only appear in one env_threads.
122 * This is a beta feature.
123 * \return reference to self.
124 */
125 TVM_DLL Stage& env_threads(Array<IterVar> threads);
126 /*!
127 * \brief Split the parent by factor, generate
128 * \param parent The parent iteration domain.
129 * \param factor The split factor of the loop.
130 * \param p_outer The result outer domain
131 * \param p_inner The result inner domain.
132 * \return reference to self.
133 */
134 TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer,
135 IterVar* p_inner); // NOLINT(*)
136 /*!
137 * \brief Split the iteration with given number of parts.
138 *
139 * \param parent The parent domain.
140 * \param nparts The number of parts in the outer domain.
141 * \param p_outer The result outer domain.
142 * \param p_inner The result inner domain.
143 * \return reference to self.
144 */
145 TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer,
146 IterVar* p_inner); // NOLINT(*)
147 /*!
148 * \brief Fuse the inner outer domain to the target
149 * \param outer The outer domain to be fused.
150 * \param inner The inner domain to be fused
151 * \param p_target The result target domain.
152 * \return reference to self.
153 */
154 TVM_DLL Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
155 /*!
156 * \brief Fuse all the axes together into a single axis.
157 *
158 * \param axes All the axes to be fused.
159 * \param p_target The result target domain.
160 *
161 * \note axes can be an empty array,
162 * in that case, a singleton IterVar is created and
163 * inserted to the outermost loop.
164 * The fuse of empty array is used to support zero-dimension tensors.
165 *
166 * \return reference to self.
167 */
168 TVM_DLL Stage& fuse(const Array<IterVar>& axes, IterVar* p_target); // NOLINT(*)
169 /*!
170 * \brief Reorder the iteration
171 * \param order The order of iteration variable.
172 * \return reference to self.
173 */
174 TVM_DLL Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
175 /*!
176 * \brief Perform tiling on two dimensions
177 * The final loop order from outmost to inner most are
178 * [x_outer, y_outer, x_inner, y_inner]
179 *
180 * \param x_parent The original x dimension
181 * \param y_parent The original y dimension
182 * \param x_factor The stride factor on x axis
183 * \param y_factor The stride factor on y axis
184 * \param p_x_outer Outer axis of x dimension
185 * \param p_y_outer Outer axis of y dimension
186 * \param p_x_inner Inner axis of x dimension
187 * \param p_y_inner Inner axis of y dimension
188 * \return reference to self.
189 */
190 TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
191 PrimExpr x_factor, PrimExpr y_factor, IterVar* p_x_outer, IterVar* p_y_outer,
192 IterVar* p_x_inner, IterVar* p_y_inner);
193 /*!
194 * \brief Vectorize iteration.
195 * \param var The axis to be vectorized.
196 * \return reference to self.
197 */
198 TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*)
199 /*!
200 * \brief Replace computation of the current stage by tensor intrinsic f.
201 * \param var The axis marks beginning of tensorization.
202 * Every operations inside the axis(include axis itself is tensorized).
203 * \param f The Tensor compute intrinsics.
204 * \return reference to self.
205 */
206 TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
207 /*!
208 * \brief Unroll iteration.
209 * \param var The axis to be unrolled.
210 * \return reference to self.
211 */
212 TVM_DLL Stage& unroll(IterVar var); // NOLINT(*)
213 /*!
214 * \brief Parallelize iteration.
215 * \param var The axis to be parallelized.
216 * \return reference to self.
217 */
218 TVM_DLL Stage& parallel(IterVar var); // NOLINT(*)
219 /*!
220 * \brief Annotate the iteration with pragma
221 *
222 * \param var The axis to be parallelized.
223 * \param pragma_type The pragma type.
224 * \param pragma_value The pragma value
225 *
226 * \return reference to self.
227 */
228 TVM_DLL Stage& pragma(IterVar var, const std::string& pragma_type,
229 const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*)
230 /*!
231 * \brief Fetch data in advance.
232 * \param domain the tensor to be prefetched
233 * \param var the iteration point at which to apply prefetching
234 * \param offset the number of iterations be to fetched in advance
235 * \return reference to self
236 */
237 TVM_DLL Stage& prefetch(const Tensor& domain, IterVar var, PrimExpr offset); // NOLINT(*)
238 /*!
239 * \brief Set alignment requirement for specific dimension.
240 *
241 * Such that stride[axis] == k * factor + offset for some k.
242 *
243 * \param axis The dimension to be specified for alignment.
244 * \param factor The factor multiple of alignment
245 * \param offset The required offset factor.
246 * \return reference to self
247 */
248 TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); // NOLINT(*)
249 /*!
250 * \brief Compute current stage with double buffering.
251 * \return reference to self.
252 */
253 TVM_DLL Stage& double_buffer(); // NOLINT(*)
254 /*!
255 * \brief whether the stage has been scheduled.
256 * \return whether the stage has been scheduled.
257 */
258 bool is_scheduled() const;
259 /*!
260 * \brief Get attachment spec of current stage.
261 * If the stage compute at Group root, this function
262 * will traverse the group function to get the
263 * final spec from the group.
264 * \return A stage representing the attach spec of the group.
265 */
266 Stage GetAttachSpec() const;
267 // declare container type
268 using ContainerType = StageNode;
269 };
270
271 /*!
272 * \brief Global schedule container
273 * For operations and all the operations they depend on.
274 * The schedule per Operation is named as stage.
275 */
276 class Schedule : public ObjectRef {
277 public:
Schedule()278 Schedule() {}
Schedule(ObjectPtr<Object> n)279 explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
280 /*!
281 * \brief Create a schedule for array of ops(and their dependencies).
282 * \param ops The ops to be scheduled.
283 * \return sch The created Schedule.
284 */
285 TVM_DLL explicit Schedule(Array<Operation> ops);
286 /*!
287 * \brief Get a copy of current schedule.
288 * \return The copied schedule.
289 */
290 Schedule copy() const;
291 /*!
292 * \brief Get the stage corresponds to the op
293 * \param op The operation.
294 */
295 TVM_DLL Stage operator[](const Operation& op);
296 /*!
297 * \brief Short hand for getting the stage of tensor's operation.
298 * \param tensor The tensor
299 * \return The stage corresponding to the tensor's op
300 */
301 TVM_DLL Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); }
302 /*!
303 * \brief Create a new stage group for all intermediate
304 * operations between inputs and outputs.
305 *
306 * \param outputs The output boundary of the group.
307 * \param inputs The input boundary of the group.
308 * \param include_inputs Whether include inputs if they are reachable from outputs.
309 * \return The new grouped stage.
310 */
311 TVM_DLL Stage create_group(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
312 bool include_inputs = false);
313 /*!
314 * \brief create a cache read of original tensor for readers.
315 * This will mutate the body of the readers.
316 * A new stage will be created for the tensor.
317 * \param tensor The tensor cached.
318 * \param scope The scope of the cache.
319 * \param readers The readers to redirect to the tensor.
320 * \return The created tensor.
321 */
322 TVM_DLL Tensor cache_read(const Tensor& tensor, const std::string& scope,
323 const Array<Operation>& readers);
324 /*!
325 * \brief Create a cache write tensor for producing tensor.
326 * The the tensor will take over body of original tensor op.
327 *
328 * This function can be used to do data layout transformation.
329 * If there is a split/fuse/reorder on the data parallel axis of tensor
330 * before cache_write is called. The intermediate cache stores
331 * the data in the layout as the iteration order of leave axis.
332 * The data will be transformed back to the original layout in the original tensor.
333 * User can further call compute_inline to inline the original layout and keep
334 * the data stored in the transformed layout.
335 *
336 * \param tensor The tensors to be produced.
337 * \param scope The scope of the storage.
338 * \return The created tensor.
339 */
340 TVM_DLL Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
341 /*!
342 * \brief Create a cache write tensor for producing tensor.
343 * The the tensor will take over body of original tensor op.
344 *
345 * This function can be used to do data layout transformation.
346 * If there is a split/fuse/reorder on the data parallel axis of tensor
347 * before cache_write is called. The intermediate cache stores
348 * the data in the layout as the iteration order of leave axis.
349 * The data will be transformed back to the original layout in the original tensor.
350 * User can further call compute_inline to inline the original layout and keep
351 * the data stored in the transformed layout.
352 *
353 * \param tensor The tensor to be produced.
354 * \param scope The scope of the storage.
355 * \return The created tensor.
356 */
357 TVM_DLL Tensor cache_write(const Tensor& tensor, const std::string& scope);
358 /*!
359 * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
360 * This will create a new stage that generated the new tensor with axis
361 * as the first dimension. The tensor's body will be rewritten as a reduction
362 * over the factored tensor.
363 *
364 * P. Suriana, A. Adams and S. Kamil. Parallel associative reductions in halide. CGO'17
365 *
366 * \param tensor The tensor to be factored.
367 * \param axis The reduction axis in tensor's schedule to be factored.
368 * \param factor_axis The position where the new axis is placed.
369 * \return The created factored tensors.
370 */
371 TVM_DLL Array<Tensor> rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis = 0);
372 /*!
373 * \brief Normalize the schedule.
374 * This is needed before bound inference.
375 * Insert necessary RebaseNode to make sure all leaf_iter_vars
376 * are in form [0, extent)
377 *
378 * \return A normalized schedule, can be same as current one.
379 */
380 Schedule normalize();
381 /*!
382 * \brief access the internal node container
383 * \return the pointer to the internal node container
384 */
385 inline const ScheduleNode* operator->() const;
386 /*!
387 * \brief access the internal node container
388 * \return the pointer to the internal node container
389 */
390 inline ScheduleNode* operator->();
391 // declare container type
392 using ContainerType = ScheduleNode;
393 };
394
395 /*!
396 * \brief The schedule relation between IterVars
397 * can be Split, Fuse.
398 */
399 class IterVarRelation : public ObjectRef {
400 public:
IterVarRelation()401 IterVarRelation() {}
IterVarRelation(ObjectPtr<Object> n)402 explicit IterVarRelation(ObjectPtr<Object> n) : ObjectRef(n) {}
403 /*!
404 * \brief access the internal node container
405 * \return the pointer to the internal node container
406 */
407 inline const IterVarRelationNode* operator->() const;
408 };
409
410 /*!
 * \brief Additional schedulable attributes about IterVar.
412 */
413 class IterVarAttr : public ObjectRef {
414 public:
IterVarAttr()415 IterVarAttr() {}
IterVarAttr(ObjectPtr<Object> n)416 explicit IterVarAttr(ObjectPtr<Object> n) : ObjectRef(n) {}
417 /*!
418 * \brief access the internal node container
419 * \return the pointer to the internal node container
420 */
421 inline const IterVarAttrNode* operator->() const;
422 };
423
424 /*!
425 * \brief represents a stage.
426 *
 * relations form a directed acyclic hypergraph in bipartite manner.
428 * With each node is represented by a IterVar,
429 * and each hyper-edge is represented by a IterVarRelation.
430 * The relations connects the IterVars in the graph.
431 *
432 * Besides typical stage that corresponds to operations.
433 * There is also group stage, which groups stages together.
434 * Each stage's group(given by group) represent an constraint,
435 * the stage can only be attached to stages within the group.
436 *
437 * The group stage node can be attached to IterVars as in normal stage.
438 */
439 class StageNode : public Object {
440 public:
441 /*!
442 * \brief The operation of stage, can be different from original op.
443 * If it is null, then this stage is a group stage.
444 */
445 Operation op;
446 /*!
447 * \brief The original operator.
448 * The op field can change during schedule to alternate the dataflow,
449 * while origin_op remains fixed.
450 */
451 Operation origin_op;
452 /*! \brief All the nodes in the iter var */
453 Array<IterVar> all_iter_vars;
454 /*! \brief The current active leaf iter vars in the stage. */
455 Array<IterVar> leaf_iter_vars;
456 /*!
457 * \brief Specify threads to be launched at the stage.
458 * This is only valid for composite ops such as Scan.
459 * \note Experimental primitive: used for thread persistence.
460 */
461 Array<IterVar> env_threads;
462 /*!
463 * \brief The predicate under which store can happen
464 * Use this when there can be duplicated threads doing the same store.
465 * \note Experimental primitive: used by cross thread-reduction.
466 */
467 PrimExpr store_predicate;
468 /*! \brief The relation bwteen of IterVars */
469 Array<IterVarRelation> relations;
470 /*! \brief additional attributes about iter var. */
471 Map<IterVar, IterVarAttr> iter_var_attrs;
472 /*! \brief The attachment type of the schedule */
473 AttachType attach_type{kGroupRoot};
474 /*! \brief The attach point of this schedule. */
475 IterVar attach_ivar;
476 /*! \brief The stage this node attaches to */
477 Stage attach_stage;
478 /*! \brief The thread storage scope level of the stage */
479 std::string scope;
480 /*! \brief Whether this is an output stage */
481 bool is_output{false};
482 /*! \brief Whether apply double buffer optimization to this stage */
483 bool double_buffer{false};
484 /*!
485 * \brief The parent group of the current stage.
486 * The stage cannot be assigned to stages outside the group.
487 */
488 Stage group;
489 /*! \brief Number of direct child stages, only used for group stage.*/
490 int num_child_stages{0};
491
VisitAttrs(AttrVisitor * v)492 void VisitAttrs(AttrVisitor* v) {
493 v->Visit("op", &op);
494 v->Visit("origin_op", &origin_op);
495 v->Visit("all_iter_vars", &all_iter_vars);
496 v->Visit("leaf_iter_vars", &leaf_iter_vars);
497 v->Visit("env_threads", &env_threads);
498 v->Visit("relations", &relations);
499 v->Visit("iter_var_attrs", &iter_var_attrs);
500 v->Visit("attach_type", &attach_type);
501 v->Visit("attach_ivar", &attach_ivar);
502 v->Visit("attach_stage", &attach_stage);
503 v->Visit("scope", &scope);
504 v->Visit("is_output", &is_output);
505 v->Visit("double_buffer", &double_buffer);
506 v->Visit("group", &group);
507 v->Visit("num_child_stages", &num_child_stages);
508 }
509
510 static constexpr const char* _type_key = "Stage";
511 TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
512 };
513
514 /*! \brief node container for schedule */
515 class ScheduleNode : public Object {
516 public:
517 /*! \brief The output operations in original data flow graph */
518 Array<Operation> outputs;
519 /*!
520 * \brief list of all stages for ops.
521 * The stages are sorted in dependency order.
522 */
523 Array<Stage> stages;
524 /*!
525 * \brief List of all stage groups.
526 */
527 Array<Stage> groups;
528 /*! \brief map of original operation to the stages */
529 Map<Operation, Stage> stage_map;
530 /*!
531 * \brief Internal stage map to map internal ops to stages.
532 * This is created on demand and can be invalidated.
533 */
534 std::unordered_map<const Object*, Stage> op2stage_cache_;
535
VisitAttrs(AttrVisitor * v)536 void VisitAttrs(AttrVisitor* v) {
537 v->Visit("outputs", &outputs);
538 v->Visit("stages", &stages);
539 v->Visit("groups", &groups);
540 v->Visit("stage_map", &stage_map);
541 }
542
543 /*! \brief Initialize temp cache. */
544 void InitCache();
545 /*! \brief Invalidate temp cache. */
546 void InvalidateCache();
547
548 /*!
549 * \brief Check if the schedule contains an Operation.
550 * \param op The candidate Operation.
551 * \return true if the schedule has the Operation. Otherwise, false.
552 */
553 TVM_DLL bool Contain(const Operation& op) const;
554
555 /*!
556 * \brief Check if the schedule contains a Tensor.
557 * \param tensor The candidate tensor.
558 * \return true if the schedule has the tensor. Otherwise, false.
559 */
Contain(const Tensor & tensor)560 TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); }
561
562 static constexpr const char* _type_key = "Schedule";
563 TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
564 };
565
566 /*!
567 * \brief Create a schedule for array of ops(and their dependencies).
568 * \param ops The ops to be scheduled.
569 * \return sch The created Schedule.
570 */
create_schedule(Array<Operation> ops)571 inline Schedule create_schedule(Array<Operation> ops) { return Schedule(ops); }
572
573 /*! \brief node container for IterVar attr */
574 class IterVarAttrNode : public Object {
575 public:
576 /*! \brief The iteration type. */
577 IterVarType iter_type{kDataPar};
578 /*! \brief The thread this iter Var binds, can be null */
579 IterVar bind_thread;
580 /*! \brief List of tensor to be prefetched in this loop */
581 Array<Tensor> prefetch_data;
582 /*! \brief The offset used in each prefetch */
583 Array<PrimExpr> prefetch_offset;
584 /*!
585 * \brief Tensor intrinsic used in tensorization,
586 * when the axis is marked as Tensorized
587 */
588 TensorIntrin tensor_intrin;
589 /*! \brief Alignment factor of buffer dimension */
590 int dim_align_factor{0};
591 /*! \brief Alignment offset of buffer dimension */
592 int dim_align_offset{0};
593 /*!
594 * \brief Additional pragma keys, array of StringImm
595 */
596 Array<PrimExpr> pragma_keys;
597 /*!
598 * \brief Additional values of pragma, if any
599 */
600 Array<PrimExpr> pragma_values;
601
VisitAttrs(AttrVisitor * v)602 void VisitAttrs(AttrVisitor* v) {
603 v->Visit("iter_type", &iter_type);
604 v->Visit("bind_thread", &bind_thread);
605 v->Visit("prefetch_data", &prefetch_data);
606 v->Visit("prefetch_offset", &prefetch_offset);
607 v->Visit("tensor_intrin", &tensor_intrin);
608 v->Visit("dim_align_factor", &dim_align_factor);
609 v->Visit("dim_align_offset", &dim_align_offset);
610 v->Visit("pragma_keys", &pragma_keys);
611 v->Visit("pragma_values", &pragma_values);
612 }
613
614 static constexpr const char* _type_key = "IterVarAttr";
615 TVM_DECLARE_FINAL_OBJECT_INFO(IterVarAttrNode, Object);
616 };
617
618 /*! \brief base node of iteration var */
class IterVarRelationNode : public Object {
 public:
  static constexpr const char* _type_key = "IterVarRelation";
  // Base class: concrete relations (Split/Fuse/Rebase/Singleton) derive from it.
  TVM_DECLARE_BASE_OBJECT_INFO(IterVarRelationNode, Object);
};
624
625 /*!
626 * \brief Split the parent domain into product of
627 * outer and iter.
628 */
629 class SplitNode : public IterVarRelationNode {
630 public:
631 /*! \brief The parent domain */
632 IterVar parent;
633 /*! \brief The outer domain */
634 IterVar outer;
635 /*! \brief The inner domain */
636 IterVar inner;
637 /*! \brief The split factor */
638 PrimExpr factor;
639 /*! \brief Number of parts, only factor or nparts can be given */
640 PrimExpr nparts;
641
VisitAttrs(AttrVisitor * v)642 void VisitAttrs(AttrVisitor* v) {
643 v->Visit("parent", &parent);
644 v->Visit("outer", &outer);
645 v->Visit("inner", &inner);
646 v->Visit("factor", &factor);
647 v->Visit("nparts", &nparts);
648 }
649
650 static constexpr const char* _type_key = "Split";
651 TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
652 };
653
654 /*!
655 * \brief Managed reference to SplitNode
656 * \sa SplitNode
657 */
class Split : public IterVarRelation {
 public:
  /*!
   * \brief Construct a Split relation.
   * \param parent The parent domain being split.
   * \param outer The resulting outer domain.
   * \param inner The resulting inner domain.
   * \param factor The split factor.
   * \param nparts The number of parts; only one of factor or nparts is given.
   */
  TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts);

  TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode);
};
664
665 /*!
666 * \brief Fuse two domains into one domain.
667 */
668 class FuseNode : public IterVarRelationNode {
669 public:
670 /*! \brief The outer domain */
671 IterVar outer;
672 /*! \brief The inner domain */
673 IterVar inner;
674 /*! \brief The target domain */
675 IterVar fused;
676
VisitAttrs(AttrVisitor * v)677 void VisitAttrs(AttrVisitor* v) {
678 v->Visit("outer", &outer);
679 v->Visit("inner", &inner);
680 v->Visit("fused", &fused);
681 }
682
683 static constexpr const char* _type_key = "Fuse";
684 TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
685 };
686
687 /*!
688 * \brief Managed reference to FuseNode
689 * \sa FuseNode
690 */
class Fuse : public IterVarRelation {
 public:
  /*!
   * \brief Construct a Fuse relation.
   * \param outer The outer domain to be fused.
   * \param inner The inner domain to be fused.
   * \param fused The resulting fused domain.
   */
  TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused);

  TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode);
};
697
698 /*!
699 * \brief Rebase the iteration to make min to be 0.
700 * This is useful to normalize the Schedule
701 * to make every leaf variable's min to be 0.
702 */
703 class RebaseNode : public IterVarRelationNode {
704 public:
705 /*! \brief The parent domain */
706 IterVar parent;
707 /*! \brief The inner domain */
708 IterVar rebased;
709
VisitAttrs(AttrVisitor * v)710 void VisitAttrs(AttrVisitor* v) {
711 v->Visit("parent", &parent);
712 v->Visit("rebased", &rebased);
713 }
714
715 static constexpr const char* _type_key = "Rebase";
716 TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
717 };
718
719 /*!
720 * \brief Managed reference to RebaseNode
721 * \sa RebaseNode
722 */
class Rebase : public IterVarRelation {
 public:
  /*!
   * \brief Construct a Rebase relation.
   * \param parent The parent domain.
   * \param rebased The rebased domain (min normalized to 0).
   */
  TVM_DLL Rebase(IterVar parent, IterVar rebased);

  TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode);
};
729
730 /*!
731 * \brief Singleton iterator [0, 1)
732 */
733 class SingletonNode : public IterVarRelationNode {
734 public:
735 /*! \brief The singleton iterator */
736 IterVar iter;
737
VisitAttrs(AttrVisitor * v)738 void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); }
739
740 static constexpr const char* _type_key = "Singleton";
741 TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
742 };
743
744 /*!
745 * \brief Managed reference to SingletonNode
746 * \sa SingletonNode
747 */
class Singleton : public IterVarRelation {
 public:
  /*!
   * \brief Construct a Singleton relation.
   * \param iter The singleton iterator over [0, 1).
   */
  TVM_DLL explicit Singleton(IterVar iter);

  TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode);
};
754
755 /*! \brief Container for specialization conditions. */
756 class SpecializedConditionNode : public Object {
757 public:
758 /*!
759 * \brief List of conditions in conjunctive joint form (CNF).
760 * Each condition should be a simple expression, e.g., n > 16, m % 8 == 0, etc.,
761 * where n, m are tvm::Var that represents a dimension in the tensor shape.
762 */
763 Array<PrimExpr> clauses;
764
VisitAttrs(AttrVisitor * v)765 void VisitAttrs(AttrVisitor* v) { v->Visit("clauses", &clauses); }
766
767 static constexpr const char* _type_key = "SpecializedCondition";
768 TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object);
769 };
770
771 /*!
772 * \brief Specialized condition to enable op specialization
773 */
class SpecializedCondition : public ObjectRef {
 public:
  /*!
   * \brief construct from conditions
   * \param conditions The clauses in the specialized condition.
   */
  TVM_DLL SpecializedCondition(Array<PrimExpr> conditions);  // NOLINT(*)

  /*!
   * \brief Get the current specialized condition.
   * \return the current specialized condition.
   */
  TVM_DLL static SpecializedCondition Current();

  TVM_DEFINE_OBJECT_REF_METHODS(SpecializedCondition, ObjectRef, SpecializedConditionNode);
  class Internal;

 private:
  // enable with syntax: scoping is driven by With<SpecializedCondition>,
  // which calls EnterWithScope/ExitWithScope via friendship.
  friend class Internal;
  friend class With<SpecializedCondition>;
  /*! \brief Push a new specialized condition onto the thread local stack. */
  TVM_DLL void EnterWithScope();
  /*! \brief Pop a specialized condition off the thread local context stack. */
  TVM_DLL void ExitWithScope();
};
800
// implementations
// operator-> accessors: const versions use get(); mutable versions use
// get_mutable(), which requires a unique (copy-on-write safe) reference.
inline const StageNode* Stage::operator->() const { return static_cast<const StageNode*>(get()); }
inline StageNode* Stage::operator->() { return static_cast<StageNode*>(get_mutable()); }

inline const ScheduleNode* Schedule::operator->() const {
  return static_cast<const ScheduleNode*>(get());
}
inline ScheduleNode* Schedule::operator->() { return static_cast<ScheduleNode*>(get_mutable()); }

inline const IterVarRelationNode* IterVarRelation::operator->() const {
  return static_cast<const IterVarRelationNode*>(get());
}

inline const IterVarAttrNode* IterVarAttr::operator->() const {
  return static_cast<const IterVarAttrNode*>(get());
}
817
818 } // namespace te
819 } // namespace tvm
820 #endif // TVM_TE_SCHEDULE_H_
821