1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file auto_scheduler/loop_state.cc
 * \brief A lightweight IR (intermediate representation) for loop structures.
 * See auto_scheduler/loop_state.h for more explanation.
24  */
25 
26 #include <tvm/auto_scheduler/compute_dag.h>
27 #include <tvm/auto_scheduler/loop_state.h>
28 #include <tvm/auto_scheduler/transform_step.h>
29 #include <tvm/runtime/registry.h>
30 #include <tvm/te/operation.h>
31 
32 #include <utility>
33 
34 #include "utils.h"
35 
36 namespace tvm {
37 namespace auto_scheduler {
38 
// Register the object/node types of the loop-state IR with TVM's type registry.
TVM_REGISTER_OBJECT_TYPE(StepNode);
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(StateNode);
TVM_REGISTER_NODE_TYPE(IteratorNode);
43 
44 /********** Iterator **********/
Iterator(String name,Range range,IteratorKind iter_kind,IteratorAnnotation annotation,const std::vector<Iterator> * orig_iters)45 Iterator::Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation,
46                    const std::vector<Iterator>* orig_iters) {
47   auto node = make_object<IteratorNode>();
48   node->name = std::move(name);
49   node->range = std::move(range);
50   node->iter_kind = iter_kind;
51   node->annotation = annotation;
52   if (orig_iters != nullptr) {
53     node->orig_iters = *orig_iters;
54   }
55   data_ = std::move(node);
56 }
57 
58 /********** Stage **********/
Stage(te::Operation op)59 Stage::Stage(te::Operation op) {
60   auto node = make_object<StageNode>();
61   if (op->IsInstance<te::ComputeOpNode>()) {
62     node->op_type = StageKind::kCompute;
63     auto* pop = op.as<te::ComputeOpNode>();
64     for (const auto& axis : pop->axis) {
65       node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom,
66                                      IteratorKind::kSpatial, IteratorAnnotation::kNone));
67     }
68     for (const auto& axis : pop->reduce_axis) {
69       node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom,
70                                      IteratorKind::kReduction, IteratorAnnotation::kNone));
71     }
72   } else if (op->IsInstance<te::PlaceholderOpNode>()) {
73     node->op_type = StageKind::kPlaceholder;
74   } else {
75     LOG(FATAL) << "Unsupported operator type" << op->_type_key;
76   }
77 
78   node->compute_at = ComputeAtKind::kRoot;
79   node->op = std::move(op);
80   node->attrs.auto_unroll_max_step = 0;
81   node->attrs.storage_offset = 0;
82   data_ = std::move(node);
83 }
84 
Stage(te::Operation op,StageKind op_type,const Array<Iterator> & iters,ComputeAtKind compute_at,StageAttributes attrs)85 Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters,
86              ComputeAtKind compute_at, StageAttributes attrs) {
87   auto node = make_object<StageNode>();
88   node->op = std::move(op);
89   node->op_type = op_type;
90   node->iters = iters;
91   node->compute_at = compute_at;
92   node->attrs = attrs;
93   data_ = std::move(node);
94 }
95 
96 /********** AttachMap **********/
SetComputeAtIter(int stage_id,int target_stage_id,int target_iter_id)97 void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) {
98   AttachMapNode* pnode = CopyOnWrite();
99 
100   // Delete the current entry of this stage
101   DeleteStageEntry(pnode, stage_id);
102 
103   // Store the new stage/iterator relations to map
104   IterKey iter_key(target_stage_id, target_iter_id);
105   pnode->stage_to_attach_iter[stage_id] = iter_key;
106   pnode->iter_to_attached_stages[iter_key].push_back(stage_id);
107 }
108 
DeleteStage(int stage_id)109 void AttachMap::DeleteStage(int stage_id) {
110   AttachMapNode* pnode = CopyOnWrite();
111   // Delete the original stage entry
112   DeleteStageEntry(pnode, stage_id);
113 }
114 
UpdateIters(const std::vector<IterKey> & original_iters,const std::vector<IterKey> & new_iters)115 void AttachMap::UpdateIters(const std::vector<IterKey>& original_iters,
116                             const std::vector<IterKey>& new_iters) {
117   CHECK_EQ(original_iters.size(), new_iters.size());
118   AttachMapNode* pnode = CopyOnWrite();
119   std::unordered_map<IterKey, std::vector<StageKey>> new_iter_to_attached_stages;
120   for (size_t i = 0; i < original_iters.size(); ++i) {
121     auto entry = pnode->iter_to_attached_stages.find(original_iters[i]);
122     // We get <IterKey, std::vector<StageKey>> from this map
123     if (entry == pnode->iter_to_attached_stages.end()) {
124       // Skip if this iterator does not have any attach relations
125       continue;
126     }
127 
128     // Update the attaching target of an stage to the new iter in `stage_to_attach_iter`
129     for (const auto& s : entry->second) {
130       pnode->stage_to_attach_iter[s] = new_iters[i];
131     }
132 
133     // Remove the original iterator relation from `iter_to_attached_stages` and add the new
134     // iterator to it
135     std::vector<int> attached_stages = std::move(entry->second);
136     pnode->iter_to_attached_stages.erase(entry);
137     new_iter_to_attached_stages[new_iters[i]] = std::move(attached_stages);
138   }
139 
140   // Update new entries
141   for (auto& it : new_iter_to_attached_stages) {
142     pnode->iter_to_attached_stages[it.first] = std::move(it.second);
143   }
144 }
145 
DeleteStageEntry(AttachMapNode * pnode,int stage_id)146 void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) {
147   auto old_entry = pnode->stage_to_attach_iter.find(stage_id);
148   // We get <StageKey, IterKey> from this map
149   if (old_entry != pnode->stage_to_attach_iter.end()) {
150     // Delete the stage in `iter_to_attached_stages`, if the corresponding iterator does not have
151     // any attached stage, delete this iterm too
152     auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second);
153     // We get <IterKey, std::vector<StageKey>> from this map
154     FindAndDeleteItem(&entry2->second, stage_id);
155     if (entry2->second.size() == 0) {
156       pnode->iter_to_attached_stages.erase(entry2);
157     }
158     // Delete the stage in `stage_to_attach_iter`
159     pnode->stage_to_attach_iter.erase(old_entry);
160   }
161 }
162 
ApplyStageIdOffset(int start_id,int offset) const163 AttachMap AttachMap::ApplyStageIdOffset(int start_id, int offset) const {
164   AttachMap map = AttachMap(make_object<AttachMapNode>());
165   auto pmap = map.CopyOnWrite();
166   for (const auto& x : operator->()->stage_to_attach_iter) {
167     auto key = x.first;
168     if (key >= start_id) {
169       key += offset;
170     }
171     auto value = x.second;
172     if (value.first >= start_id) {
173       value.first += offset;
174     }
175     pmap->stage_to_attach_iter.insert(std::make_pair(key, value));
176   }
177   for (const auto& x : operator->()->iter_to_attached_stages) {
178     auto key = x.first;
179     if (key.first >= start_id) {
180       key.first += offset;
181     }
182     auto value = x.second;
183     for (auto& i : value) {
184       if (i >= start_id) {
185         i += offset;
186       }
187     }
188     pmap->iter_to_attached_stages.insert(std::make_pair(key, value));
189   }
190   return map;
191 }
192 
193 /********** State **********/
State(const Array<te::Operation> & ops)194 State::State(const Array<te::Operation>& ops) {
195   auto node = make_object<StateNode>();
196   for (const auto& op : ops) {
197     node->stages.push_back(Stage(op));
198   }
199   node->attach_map = AttachMap(make_object<AttachMapNode>());
200   node->concrete = true;
201   data_ = std::move(node);
202 }
203 
204 /********** Schedule primitives apis for state **********/
/*!
 * \brief Record and apply an AnnotationStep that binds `it` to a thread/block axis.
 * \param thread_type Must lie in the enum range kVThread..kThreadZ; anything else is fatal.
 * \return The annotated iterator produced by applying the step.
 */
Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) {
  const Stage& stage = operator->()->stages[stage_id];
  const bool is_thread_binding = thread_type >= IteratorAnnotation::kVThread &&
                                 thread_type <= IteratorAnnotation::kThreadZ;
  if (!is_thread_binding) {
    LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, "
               << "kThreadX, kThreadY, kBlockZ, kThreadZ";
  }
  const auto iter_id = GetIndex(stage->iters, it);
  AnnotationStep step(stage_id, iter_id, thread_type);
  CopyOnWrite()->transform_steps.push_back(step);
  return step->ApplyToState(this);
}
215 
/*!
 * \brief Record and apply an AnnotationStep marking `it` as parallel.
 * \return The annotated iterator.
 */
Iterator State::parallel(int stage_id, const Iterator& it) {
  const Stage& stage = operator->()->stages[stage_id];
  const auto iter_id = GetIndex(stage->iters, it);
  AnnotationStep step(stage_id, iter_id, IteratorAnnotation::kParallel);
  CopyOnWrite()->transform_steps.push_back(step);
  return step->ApplyToState(this);
}
223 
/*!
 * \brief Record and apply an AnnotationStep marking `it` as unrolled.
 * \param max_unroll -1 disables the limit. Otherwise, if `it` has a constant extent greater
 *        than `max_unroll`, nothing is recorded and `it` is returned unchanged.
 * \return The annotated iterator, or `it` itself when unrolling was skipped.
 */
Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) {
  const Stage& stage = operator->()->stages[stage_id];

  const bool limit_enabled = max_unroll != -1;
  if (limit_enabled && it->range.defined()) {
    const auto* extent = it->range->extent.as<IntImmNode>();
    if (extent != nullptr && extent->value > max_unroll) {
      return it;
    }
  }

  const auto iter_id = GetIndex(stage->iters, it);
  AnnotationStep step(stage_id, iter_id, IteratorAnnotation::kUnroll);
  CopyOnWrite()->transform_steps.push_back(step);
  return step->ApplyToState(this);
}
241 
/*!
 * \brief Record and apply an AnnotationStep marking `it` as vectorized.
 * \return The annotated iterator.
 */
Iterator State::vectorize(int stage_id, const Iterator& it) {
  const Stage& stage = operator->()->stages[stage_id];
  const auto iter_id = GetIndex(stage->iters, it);
  AnnotationStep step(stage_id, iter_id, IteratorAnnotation::kVectorize);
  CopyOnWrite()->transform_steps.push_back(step);
  return step->ApplyToState(this);
}
249 
fuse(int stage_id,const Array<Iterator> & iters)250 Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
251   const Stage& stage = operator->()->stages[stage_id];
252   Array<Integer> indices;
253   GetIndices(stage->iters, iters, &indices);
254   FuseStep step = FuseStep(stage_id, indices);
255   CopyOnWrite()->transform_steps.push_back(step);
256   return step->ApplyToState(this);
257 }
258 
pragma(int stage_id,const Iterator & it,const String & pragma_type)259 void State::pragma(int stage_id, const Iterator& it, const String& pragma_type) {
260   const Stage& stage = operator->()->stages[stage_id];
261   PragmaStep step = PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type);
262   CopyOnWrite()->transform_steps.push_back(step);
263   return step->ApplyToState(this);
264 }
265 
reorder(int stage_id,const Array<Iterator> & order)266 void State::reorder(int stage_id, const Array<Iterator>& order) {
267   const Stage& stage = operator->()->stages[stage_id];
268   CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
269                                               << "should be specified";
270   Array<Integer> after_ids;
271   GetIndices(stage->iters, order, &after_ids);
272   ReorderStep step = ReorderStep(stage_id, after_ids);
273   CopyOnWrite()->transform_steps.push_back(step);
274   step->ApplyToState(this);
275 }
276 
/*!
 * \brief Record and apply a SplitStep on iterator `it`.
 * The extent of `it` is recorded in the step when its range is defined; otherwise an
 * undefined PrimExpr is passed.
 * \return The iterators produced by the split.
 */
Array<Iterator> State::split(int stage_id, const Iterator& it,
                             const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
  const Stage& stage = operator->()->stages[stage_id];
  PrimExpr extent;
  if (it->range.defined()) {
    extent = it->range->extent;
  }
  const auto iter_id = GetIndex(stage->iters, it);
  SplitStep step(stage_id, iter_id, extent, lengths, inner_to_outer);
  CopyOnWrite()->transform_steps.push_back(step);
  return step->ApplyToState(this);
}
286 
follow_split(int stage_id,const Iterator & it,int src_step_id,int n_split)287 Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
288                                     int n_split) {
289   const Stage& stage = operator->()->stages[stage_id];
290   FollowSplitStep step =
291       FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split);
292   CopyOnWrite()->transform_steps.push_back(step);
293   return step->ApplyToState(this);
294 }
295 
/*!
 * \brief Record and apply a FollowFusedSplitStep derived from the steps in `src_step_ids`.
 * \return The iterators produced by the split.
 */
Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
                                          const Array<Integer>& src_step_ids, int level,
                                          bool factor_or_nparts) {
  const Stage& stage = operator->()->stages[stage_id];
  const auto iter_id = GetIndex(stage->iters, it);
  FollowFusedSplitStep step(stage_id, iter_id, src_step_ids, level, factor_or_nparts);
  CopyOnWrite()->transform_steps.push_back(step);
  return step->ApplyToState(this);
}
305 
storage_align(int stage_id,const Iterator & it,int factor,int offset)306 void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) {
307   const Stage& stage = operator->()->stages[stage_id];
308   StorageAlignStep step = StorageAlignStep(stage_id, GetIndex(stage->iters, it), factor, offset);
309   CopyOnWrite()->transform_steps.push_back(step);
310   return step->ApplyToState(this);
311 }
312 
compute_at(int stage_id,int target_stage_id,const Iterator & target_iter)313 void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
314   const Stage& target_stage = operator->()->stages[target_stage_id];
315   ComputeAtStep step =
316       ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter));
317   CopyOnWrite()->transform_steps.push_back(step);
318   step->ApplyToState(this);
319 }
320 
compute_inline(int stage_id)321 void State::compute_inline(int stage_id) {
322   ComputeInlineStep step = ComputeInlineStep(stage_id);
323   CopyOnWrite()->transform_steps.push_back(step);
324   step->ApplyToState(this);
325 }
326 
compute_root(int stage_id)327 void State::compute_root(int stage_id) {
328   ComputeRootStep step = ComputeRootStep(stage_id);
329   CopyOnWrite()->transform_steps.push_back(step);
330   step->ApplyToState(this);
331 }
332 
cache_read(int stage_id,const String & scope_name,const Array<Integer> & reader_stage_ids,const ComputeDAG & dag)333 int State::cache_read(int stage_id, const String& scope_name,
334                       const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
335   CacheReadStep step = CacheReadStep(stage_id, scope_name, reader_stage_ids);
336   CopyOnWrite()->transform_steps.push_back(step);
337   return step->ApplyToState(this, dag);
338 }
339 
cache_write(int stage_id,const String & scope_name,const ComputeDAG & dag)340 int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag) {
341   CacheWriteStep step = CacheWriteStep(stage_id, scope_name);
342   CopyOnWrite()->transform_steps.push_back(step);
343   return step->ApplyToState(this, dag);
344 }
345 
rfactor(int stage_id,const Iterator & it,int factor_iter_id,const ComputeDAG & dag)346 int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) {
347   const Stage& stage = operator->()->stages[stage_id];
348   RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it), factor_iter_id);
349   CopyOnWrite()->transform_steps.push_back(step);
350   return step->ApplyToState(this, dag);
351 }
352 
353 // Print stage to ostream
PrintStage(std::ostream * os,int stage_id,const State & state,size_t base_indent,bool delete_trivial_loop)354 void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent,
355                 bool delete_trivial_loop) {
356   const Stage& stage = state->stages[stage_id];
357 
358   if (stage->attrs.auto_unroll_max_step != 0) {
359     for (size_t j = 0; j < base_indent; ++j) {
360       *os << " ";
361     }
362     *os << stage->op->name << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n";
363   }
364   if (stage->attrs.storage_offset != 0) {
365     for (size_t j = 0; j < base_indent; ++j) {
366       *os << " ";
367     }
368     *os << stage->op->name << " storage_offset: " << stage->attrs.storage_offset << "\n";
369   }
370 
371   size_t indent = 0;
372   for (size_t i = 0; i < stage->iters.size(); ++i) {
373     const Iterator& iter = stage->iters[i];
374 
375     if (!(delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent))) {
376       for (size_t j = 0; j < base_indent + indent; ++j) {
377         *os << " ";
378       }
379       *os << IteratorAnnotationString[static_cast<int>(iter->annotation)] << " ";
380       if (iter->range.defined()) {
381         *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")";
382       } else {
383         *os << iter->name << " (None)";
384       }
385       *os << "\n";
386 
387       indent += 2;
388     }
389 
390     if (state.defined()) {
391       IterKey iter_key(stage_id, i);
392       auto pair = state->attach_map->iter_to_attached_stages.find(iter_key);
393       if (pair != state->attach_map->iter_to_attached_stages.end()) {
394         // Print the attached stage
395         for (const auto& attach_stage_id : pair->second) {
396           PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop);
397         }
398       }
399     }
400   }
401 
402   for (size_t j = 0; j < base_indent + indent; ++j) {
403     *os << " ";
404   }
405   *os << stage->op->name << " = ...\n";
406 }
407 
408 // Print state to ostream
PrintState(std::ostream * os,const State & state,bool delete_trivial_loop)409 void PrintState(std::ostream* os, const State& state, bool delete_trivial_loop) {
410   // Gather placeholders
411   Array<String> placeholders;
412   for (const auto& stage : state->stages) {
413     if (stage->op_type == StageKind::kPlaceholder) {
414       placeholders.push_back(stage->op->name);
415     }
416   }
417 
418   *os << "Placeholder: ";
419   for (size_t i = 0; i < placeholders.size(); ++i) {
420     *os << placeholders[i];
421     if (i != placeholders.size() - 1) {
422       *os << ", ";
423     }
424   }
425   *os << "\n";
426 
427   // Print all stages
428   for (size_t i = 0; i < state->stages.size(); ++i) {
429     const Stage& stage = state->stages[i];
430     if (stage->op_type == StageKind::kPlaceholder) {
431       continue;
432     } else if (stage->op_type == StageKind::kCompute) {
433       if (stage->compute_at == ComputeAtKind::kRoot) {
434         PrintStage(os, i, state, 0, delete_trivial_loop);
435       }
436     } else {
437       LOG(FATAL) << "Invalid op type";
438     }
439   }
440 }
441 
ToStr(bool delete_trivial_loop) const442 String State::ToStr(bool delete_trivial_loop) const {
443   std::ostringstream os;
444   PrintState(&os, (*this), delete_trivial_loop);
445   return os.str();
446 }
447 
448 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
__anon42f1274b0102(const ObjectRef& ref, ReprPrinter* p) 449     .set_dispatch<StateNode>([](const ObjectRef& ref, ReprPrinter* p) {
450       PrintState(&p->stream, tvm::Downcast<State>(ref), true);
451     });
452 
453 /********** State interface API for ffi **********/
454 TVM_REGISTER_GLOBAL("auto_scheduler.StateBind")
__anon42f1274b0202(State state, int stage_id, const Iterator& it, int thread_type) 455     .set_body_typed([](State state, int stage_id, const Iterator& it, int thread_type) {
456       const auto& res = state.bind(stage_id, it, IteratorAnnotation(thread_type));
457       return Array<ObjectRef>{state, res};
458     });
459 
460 TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel")
__anon42f1274b0302(State state, int stage_id, const Iterator& it) 461     .set_body_typed([](State state, int stage_id, const Iterator& it) {
462       const auto& res = state.parallel(stage_id, it);
463       return Array<ObjectRef>{state, res};
464     });
465 
466 TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll")
__anon42f1274b0402(State state, int stage_id, const Iterator& it, int max_unroll) 467     .set_body_typed([](State state, int stage_id, const Iterator& it, int max_unroll) {
468       const auto& res = state.unroll(stage_id, it, max_unroll);
469       return Array<ObjectRef>{state, res};
470     });
471 
472 TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize")
__anon42f1274b0502(State state, int stage_id, const Iterator& it) 473     .set_body_typed([](State state, int stage_id, const Iterator& it) {
474       const auto& res = state.vectorize(stage_id, it);
475       return Array<ObjectRef>{state, res};
476     });
477 
478 TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse")
__anon42f1274b0602(State state, int stage_id, const Array<Iterator>& iters) 479     .set_body_typed([](State state, int stage_id, const Array<Iterator>& iters) {
480       const auto& res = state.fuse(stage_id, iters);
481       return Array<ObjectRef>{state, res};
482     });
483 
484 TVM_REGISTER_GLOBAL("auto_scheduler.StatePragma")
__anon42f1274b0702(State state, int stage_id, const Iterator& it, const String& pragma_type) 485     .set_body_typed([](State state, int stage_id, const Iterator& it, const String& pragma_type) {
486       state.pragma(stage_id, it, pragma_type);
487       return state;
488     });
489 
490 TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder")
__anon42f1274b0802(State state, int stage_id, const Array<Iterator>& order) 491     .set_body_typed([](State state, int stage_id, const Array<Iterator>& order) {
492       state.reorder(stage_id, order);
493       return state;
494     });
495 
496 TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
497     .set_body_typed([](State state, int stage_id, const Iterator& it,
__anon42f1274b0902(State state, int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths, bool inner_to_outer) 498                        const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
499       const auto& res = state.split(stage_id, it, lengths, inner_to_outer);
500       return Array<ObjectRef>{state, res};
501     });
502 
503 TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
504     .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id,
__anon42f1274b0a02(State state, int stage_id, const Iterator& it, int src_step_id, int n_split) 505                        int n_split) {
506       const auto& res = state.follow_split(stage_id, it, src_step_id, n_split);
507       return Array<ObjectRef>{state, Array<Iterator>(res)};
508     });
509 
510 TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
511     .set_body_typed([](State state, int stage_id, const Iterator& it,
__anon42f1274b0b02(State state, int stage_id, const Iterator& it, const Array<Integer>& src_step_ids, int level, bool factor_or_nparts) 512                        const Array<Integer>& src_step_ids, int level, bool factor_or_nparts) {
513       const auto& res =
514           state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts);
515       return Array<ObjectRef>{state, Array<Iterator>(res)};
516     });
517 
518 TVM_REGISTER_GLOBAL("auto_scheduler.StateStorageAlign")
__anon42f1274b0c02(State state, int stage_id, const Iterator& it, int factor, int offset) 519     .set_body_typed([](State state, int stage_id, const Iterator& it, int factor, int offset) {
520       state.storage_align(stage_id, it, factor, offset);
521       return state;
522     });
523 
524 TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
525     .set_body_typed([](State state, int stage_id, int target_stage_id,
__anon42f1274b0d02(State state, int stage_id, int target_stage_id, const Iterator& target_iter) 526                        const Iterator& target_iter) {
527       state.compute_at(stage_id, target_stage_id, target_iter);
528       return state;
529     });
530 
531 TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeInline")
__anon42f1274b0e02(State state, int stage_id) 532     .set_body_typed([](State state, int stage_id) {
533       state.compute_inline(stage_id);
534       return state;
535     });
536 
537 TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot")
__anon42f1274b0f02(State state, int stage_id) 538     .set_body_typed([](State state, int stage_id) {
539       state.compute_root(stage_id);
540       return state;
541     });
542 
543 TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheRead")
544     .set_body_typed([](State state, int stage_id, const String& scope_name,
__anon42f1274b1002(State state, int stage_id, const String& scope_name, const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) 545                        const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
546       int res = state.cache_read(stage_id, scope_name, reader_stage_ids, dag);
547       return Array<ObjectRef>{state, Integer(res)};
548     });
549 
550 TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite")
551     .set_body_typed([](State state, int stage_id, const String& scope_name,
__anon42f1274b1102(State state, int stage_id, const String& scope_name, const ComputeDAG& task_dag) 552                        const ComputeDAG& task_dag) {
553       int res = state.cache_write(stage_id, scope_name, task_dag);
554       return Array<ObjectRef>{state, Integer(res)};
555     });
556 
557 TVM_REGISTER_GLOBAL("auto_scheduler.StateRfactor")
558     .set_body_typed([](State state, int stage_id, const Iterator& it, int factor_iter_id,
__anon42f1274b1202(State state, int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) 559                        const ComputeDAG& dag) {
560       int res = state.rfactor(stage_id, it, factor_iter_id, dag);
561       return Array<ObjectRef>{state, Integer(res)};
562     });
563 
// FFI equality check between two states, delegating to std::equal_to<State>.
TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual")
    .set_body_typed([](State lhs, State rhs) {
      const std::equal_to<State> states_equal;
      return states_equal(lhs, rhs);
    });
567 
568 }  // namespace auto_scheduler
569 }  // namespace tvm
570