1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file auto_scheduler/loop_state.cc
 * \brief A lightweight IR (intermediate representation) for loop structures.
 * See auto_scheduler/loop_state.h for more explanation.
24 */
25
26 #include <tvm/auto_scheduler/compute_dag.h>
27 #include <tvm/auto_scheduler/loop_state.h>
28 #include <tvm/auto_scheduler/transform_step.h>
29 #include <tvm/runtime/registry.h>
30 #include <tvm/te/operation.h>
31
32 #include <utility>
33
34 #include "utils.h"
35
36 namespace tvm {
37 namespace auto_scheduler {
38
// Register StepNode as an object type and StageNode / StateNode / IteratorNode as node types.
// NOTE: these are static registrations; keep their order unchanged, since reordering them would
// change static-initialization order.
TVM_REGISTER_OBJECT_TYPE(StepNode);
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(StateNode);
TVM_REGISTER_NODE_TYPE(IteratorNode);
43
44 /********** Iterator **********/
Iterator(String name,Range range,IteratorKind iter_kind,IteratorAnnotation annotation,const std::vector<Iterator> * orig_iters)45 Iterator::Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation,
46 const std::vector<Iterator>* orig_iters) {
47 auto node = make_object<IteratorNode>();
48 node->name = std::move(name);
49 node->range = std::move(range);
50 node->iter_kind = iter_kind;
51 node->annotation = annotation;
52 if (orig_iters != nullptr) {
53 node->orig_iters = *orig_iters;
54 }
55 data_ = std::move(node);
56 }
57
58 /********** Stage **********/
Stage(te::Operation op)59 Stage::Stage(te::Operation op) {
60 auto node = make_object<StageNode>();
61 if (op->IsInstance<te::ComputeOpNode>()) {
62 node->op_type = StageKind::kCompute;
63 auto* pop = op.as<te::ComputeOpNode>();
64 for (const auto& axis : pop->axis) {
65 node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom,
66 IteratorKind::kSpatial, IteratorAnnotation::kNone));
67 }
68 for (const auto& axis : pop->reduce_axis) {
69 node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom,
70 IteratorKind::kReduction, IteratorAnnotation::kNone));
71 }
72 } else if (op->IsInstance<te::PlaceholderOpNode>()) {
73 node->op_type = StageKind::kPlaceholder;
74 } else {
75 LOG(FATAL) << "Unsupported operator type" << op->_type_key;
76 }
77
78 node->compute_at = ComputeAtKind::kRoot;
79 node->op = std::move(op);
80 node->attrs.auto_unroll_max_step = 0;
81 node->attrs.storage_offset = 0;
82 data_ = std::move(node);
83 }
84
Stage(te::Operation op,StageKind op_type,const Array<Iterator> & iters,ComputeAtKind compute_at,StageAttributes attrs)85 Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters,
86 ComputeAtKind compute_at, StageAttributes attrs) {
87 auto node = make_object<StageNode>();
88 node->op = std::move(op);
89 node->op_type = op_type;
90 node->iters = iters;
91 node->compute_at = compute_at;
92 node->attrs = attrs;
93 data_ = std::move(node);
94 }
95
96 /********** AttachMap **********/
SetComputeAtIter(int stage_id,int target_stage_id,int target_iter_id)97 void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) {
98 AttachMapNode* pnode = CopyOnWrite();
99
100 // Delete the current entry of this stage
101 DeleteStageEntry(pnode, stage_id);
102
103 // Store the new stage/iterator relations to map
104 IterKey iter_key(target_stage_id, target_iter_id);
105 pnode->stage_to_attach_iter[stage_id] = iter_key;
106 pnode->iter_to_attached_stages[iter_key].push_back(stage_id);
107 }
108
DeleteStage(int stage_id)109 void AttachMap::DeleteStage(int stage_id) {
110 AttachMapNode* pnode = CopyOnWrite();
111 // Delete the original stage entry
112 DeleteStageEntry(pnode, stage_id);
113 }
114
UpdateIters(const std::vector<IterKey> & original_iters,const std::vector<IterKey> & new_iters)115 void AttachMap::UpdateIters(const std::vector<IterKey>& original_iters,
116 const std::vector<IterKey>& new_iters) {
117 CHECK_EQ(original_iters.size(), new_iters.size());
118 AttachMapNode* pnode = CopyOnWrite();
119 std::unordered_map<IterKey, std::vector<StageKey>> new_iter_to_attached_stages;
120 for (size_t i = 0; i < original_iters.size(); ++i) {
121 auto entry = pnode->iter_to_attached_stages.find(original_iters[i]);
122 // We get <IterKey, std::vector<StageKey>> from this map
123 if (entry == pnode->iter_to_attached_stages.end()) {
124 // Skip if this iterator does not have any attach relations
125 continue;
126 }
127
128 // Update the attaching target of an stage to the new iter in `stage_to_attach_iter`
129 for (const auto& s : entry->second) {
130 pnode->stage_to_attach_iter[s] = new_iters[i];
131 }
132
133 // Remove the original iterator relation from `iter_to_attached_stages` and add the new
134 // iterator to it
135 std::vector<int> attached_stages = std::move(entry->second);
136 pnode->iter_to_attached_stages.erase(entry);
137 new_iter_to_attached_stages[new_iters[i]] = std::move(attached_stages);
138 }
139
140 // Update new entries
141 for (auto& it : new_iter_to_attached_stages) {
142 pnode->iter_to_attached_stages[it.first] = std::move(it.second);
143 }
144 }
145
DeleteStageEntry(AttachMapNode * pnode,int stage_id)146 void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) {
147 auto old_entry = pnode->stage_to_attach_iter.find(stage_id);
148 // We get <StageKey, IterKey> from this map
149 if (old_entry != pnode->stage_to_attach_iter.end()) {
150 // Delete the stage in `iter_to_attached_stages`, if the corresponding iterator does not have
151 // any attached stage, delete this iterm too
152 auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second);
153 // We get <IterKey, std::vector<StageKey>> from this map
154 FindAndDeleteItem(&entry2->second, stage_id);
155 if (entry2->second.size() == 0) {
156 pnode->iter_to_attached_stages.erase(entry2);
157 }
158 // Delete the stage in `stage_to_attach_iter`
159 pnode->stage_to_attach_iter.erase(old_entry);
160 }
161 }
162
ApplyStageIdOffset(int start_id,int offset) const163 AttachMap AttachMap::ApplyStageIdOffset(int start_id, int offset) const {
164 AttachMap map = AttachMap(make_object<AttachMapNode>());
165 auto pmap = map.CopyOnWrite();
166 for (const auto& x : operator->()->stage_to_attach_iter) {
167 auto key = x.first;
168 if (key >= start_id) {
169 key += offset;
170 }
171 auto value = x.second;
172 if (value.first >= start_id) {
173 value.first += offset;
174 }
175 pmap->stage_to_attach_iter.insert(std::make_pair(key, value));
176 }
177 for (const auto& x : operator->()->iter_to_attached_stages) {
178 auto key = x.first;
179 if (key.first >= start_id) {
180 key.first += offset;
181 }
182 auto value = x.second;
183 for (auto& i : value) {
184 if (i >= start_id) {
185 i += offset;
186 }
187 }
188 pmap->iter_to_attached_stages.insert(std::make_pair(key, value));
189 }
190 return map;
191 }
192
193 /********** State **********/
State(const Array<te::Operation> & ops)194 State::State(const Array<te::Operation>& ops) {
195 auto node = make_object<StateNode>();
196 for (const auto& op : ops) {
197 node->stages.push_back(Stage(op));
198 }
199 node->attach_map = AttachMap(make_object<AttachMapNode>());
200 node->concrete = true;
201 data_ = std::move(node);
202 }
203
204 /********** Schedule primitives apis for state **********/
/*!
 * \brief Record and apply an AnnotationStep binding `it` to `thread_type`.
 *
 * Only annotations in the closed range [kVThread, kThreadZ] are accepted; anything else is a
 * fatal error.
 * \return The iterator returned by applying the step.
 */
Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) {
  const Stage& stage = operator->()->stages[stage_id];
  const bool in_thread_range =
      thread_type >= IteratorAnnotation::kVThread && thread_type <= IteratorAnnotation::kThreadZ;
  if (!in_thread_range) {
    LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, "
               << "kThreadX, kThreadY, kBlockZ, kThreadZ";
  }
  const auto iter_id = GetIndex(stage->iters, it);
  AnnotationStep step(stage_id, iter_id, thread_type);
  CopyOnWrite()->transform_steps.push_back(step);
  return step->ApplyToState(this);
}
215
// Record and apply a kParallel AnnotationStep on `it`; returns the iterator from ApplyToState.
Iterator State::parallel(int stage_id, const Iterator& it) {
  const Stage& target_stage = operator->()->stages[stage_id];
  const auto iter_id = GetIndex(target_stage->iters, it);
  AnnotationStep step(stage_id, iter_id, IteratorAnnotation::kParallel);
  CopyOnWrite()->transform_steps.push_back(step);
  return step->ApplyToState(this);
}
223
/*!
 * \brief Record and apply a kUnroll AnnotationStep on `it`.
 *
 * If `max_unroll` is not -1 and `it` has a constant (IntImm) extent larger than `max_unroll`,
 * nothing is recorded and `it` is returned unchanged.
 */
Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) {
  const Stage& target_stage = operator->()->stages[stage_id];

  const bool limit_applies = max_unroll != -1 && it->range.defined();
  if (limit_applies) {
    const auto* const_extent = it->range->extent.as<IntImmNode>();
    if (const_extent != nullptr && const_extent->value > max_unroll) {
      return it;
    }
  }

  const auto iter_id = GetIndex(target_stage->iters, it);
  AnnotationStep step(stage_id, iter_id, IteratorAnnotation::kUnroll);
  CopyOnWrite()->transform_steps.push_back(step);
  return step->ApplyToState(this);
}
241
// Record and apply a kVectorize AnnotationStep on `it`; returns the iterator from ApplyToState.
Iterator State::vectorize(int stage_id, const Iterator& it) {
  const Stage& target_stage = operator->()->stages[stage_id];
  const auto iter_id = GetIndex(target_stage->iters, it);
  AnnotationStep step(stage_id, iter_id, IteratorAnnotation::kVectorize);
  CopyOnWrite()->transform_steps.push_back(step);
  return step->ApplyToState(this);
}
249
fuse(int stage_id,const Array<Iterator> & iters)250 Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
251 const Stage& stage = operator->()->stages[stage_id];
252 Array<Integer> indices;
253 GetIndices(stage->iters, iters, &indices);
254 FuseStep step = FuseStep(stage_id, indices);
255 CopyOnWrite()->transform_steps.push_back(step);
256 return step->ApplyToState(this);
257 }
258
pragma(int stage_id,const Iterator & it,const String & pragma_type)259 void State::pragma(int stage_id, const Iterator& it, const String& pragma_type) {
260 const Stage& stage = operator->()->stages[stage_id];
261 PragmaStep step = PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type);
262 CopyOnWrite()->transform_steps.push_back(step);
263 return step->ApplyToState(this);
264 }
265
reorder(int stage_id,const Array<Iterator> & order)266 void State::reorder(int stage_id, const Array<Iterator>& order) {
267 const Stage& stage = operator->()->stages[stage_id];
268 CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
269 << "should be specified";
270 Array<Integer> after_ids;
271 GetIndices(stage->iters, order, &after_ids);
272 ReorderStep step = ReorderStep(stage_id, after_ids);
273 CopyOnWrite()->transform_steps.push_back(step);
274 step->ApplyToState(this);
275 }
276
/*!
 * \brief Record and apply a SplitStep on `it` with the given `lengths`.
 *
 * The step is given the current extent of `it`, or an undefined PrimExpr if `it` has no range.
 * \return The iterators returned by applying the step.
 */
Array<Iterator> State::split(int stage_id, const Iterator& it,
                             const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
  const Stage& target_stage = operator->()->stages[stage_id];
  PrimExpr extent;
  if (it->range.defined()) {
    extent = it->range->extent;
  }
  const auto iter_id = GetIndex(target_stage->iters, it);
  SplitStep step(stage_id, iter_id, extent, lengths, inner_to_outer);
  CopyOnWrite()->transform_steps.push_back(step);
  return step->ApplyToState(this);
}
286
follow_split(int stage_id,const Iterator & it,int src_step_id,int n_split)287 Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
288 int n_split) {
289 const Stage& stage = operator->()->stages[stage_id];
290 FollowSplitStep step =
291 FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split);
292 CopyOnWrite()->transform_steps.push_back(step);
293 return step->ApplyToState(this);
294 }
295
// Record and apply a FollowFusedSplitStep on `it` that refers to the earlier steps `src_step_ids`.
Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
                                          const Array<Integer>& src_step_ids, int level,
                                          bool factor_or_nparts) {
  const Stage& target_stage = operator->()->stages[stage_id];
  const auto iter_id = GetIndex(target_stage->iters, it);
  FollowFusedSplitStep step(stage_id, iter_id, src_step_ids, level, factor_or_nparts);
  CopyOnWrite()->transform_steps.push_back(step);
  return step->ApplyToState(this);
}
305
storage_align(int stage_id,const Iterator & it,int factor,int offset)306 void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) {
307 const Stage& stage = operator->()->stages[stage_id];
308 StorageAlignStep step = StorageAlignStep(stage_id, GetIndex(stage->iters, it), factor, offset);
309 CopyOnWrite()->transform_steps.push_back(step);
310 return step->ApplyToState(this);
311 }
312
compute_at(int stage_id,int target_stage_id,const Iterator & target_iter)313 void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
314 const Stage& target_stage = operator->()->stages[target_stage_id];
315 ComputeAtStep step =
316 ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter));
317 CopyOnWrite()->transform_steps.push_back(step);
318 step->ApplyToState(this);
319 }
320
compute_inline(int stage_id)321 void State::compute_inline(int stage_id) {
322 ComputeInlineStep step = ComputeInlineStep(stage_id);
323 CopyOnWrite()->transform_steps.push_back(step);
324 step->ApplyToState(this);
325 }
326
compute_root(int stage_id)327 void State::compute_root(int stage_id) {
328 ComputeRootStep step = ComputeRootStep(stage_id);
329 CopyOnWrite()->transform_steps.push_back(step);
330 step->ApplyToState(this);
331 }
332
cache_read(int stage_id,const String & scope_name,const Array<Integer> & reader_stage_ids,const ComputeDAG & dag)333 int State::cache_read(int stage_id, const String& scope_name,
334 const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
335 CacheReadStep step = CacheReadStep(stage_id, scope_name, reader_stage_ids);
336 CopyOnWrite()->transform_steps.push_back(step);
337 return step->ApplyToState(this, dag);
338 }
339
cache_write(int stage_id,const String & scope_name,const ComputeDAG & dag)340 int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag) {
341 CacheWriteStep step = CacheWriteStep(stage_id, scope_name);
342 CopyOnWrite()->transform_steps.push_back(step);
343 return step->ApplyToState(this, dag);
344 }
345
rfactor(int stage_id,const Iterator & it,int factor_iter_id,const ComputeDAG & dag)346 int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) {
347 const Stage& stage = operator->()->stages[stage_id];
348 RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it), factor_iter_id);
349 CopyOnWrite()->transform_steps.push_back(step);
350 return step->ApplyToState(this, dag);
351 }
352
353 // Print stage to ostream
PrintStage(std::ostream * os,int stage_id,const State & state,size_t base_indent,bool delete_trivial_loop)354 void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent,
355 bool delete_trivial_loop) {
356 const Stage& stage = state->stages[stage_id];
357
358 if (stage->attrs.auto_unroll_max_step != 0) {
359 for (size_t j = 0; j < base_indent; ++j) {
360 *os << " ";
361 }
362 *os << stage->op->name << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n";
363 }
364 if (stage->attrs.storage_offset != 0) {
365 for (size_t j = 0; j < base_indent; ++j) {
366 *os << " ";
367 }
368 *os << stage->op->name << " storage_offset: " << stage->attrs.storage_offset << "\n";
369 }
370
371 size_t indent = 0;
372 for (size_t i = 0; i < stage->iters.size(); ++i) {
373 const Iterator& iter = stage->iters[i];
374
375 if (!(delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent))) {
376 for (size_t j = 0; j < base_indent + indent; ++j) {
377 *os << " ";
378 }
379 *os << IteratorAnnotationString[static_cast<int>(iter->annotation)] << " ";
380 if (iter->range.defined()) {
381 *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")";
382 } else {
383 *os << iter->name << " (None)";
384 }
385 *os << "\n";
386
387 indent += 2;
388 }
389
390 if (state.defined()) {
391 IterKey iter_key(stage_id, i);
392 auto pair = state->attach_map->iter_to_attached_stages.find(iter_key);
393 if (pair != state->attach_map->iter_to_attached_stages.end()) {
394 // Print the attached stage
395 for (const auto& attach_stage_id : pair->second) {
396 PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop);
397 }
398 }
399 }
400 }
401
402 for (size_t j = 0; j < base_indent + indent; ++j) {
403 *os << " ";
404 }
405 *os << stage->op->name << " = ...\n";
406 }
407
408 // Print state to ostream
PrintState(std::ostream * os,const State & state,bool delete_trivial_loop)409 void PrintState(std::ostream* os, const State& state, bool delete_trivial_loop) {
410 // Gather placeholders
411 Array<String> placeholders;
412 for (const auto& stage : state->stages) {
413 if (stage->op_type == StageKind::kPlaceholder) {
414 placeholders.push_back(stage->op->name);
415 }
416 }
417
418 *os << "Placeholder: ";
419 for (size_t i = 0; i < placeholders.size(); ++i) {
420 *os << placeholders[i];
421 if (i != placeholders.size() - 1) {
422 *os << ", ";
423 }
424 }
425 *os << "\n";
426
427 // Print all stages
428 for (size_t i = 0; i < state->stages.size(); ++i) {
429 const Stage& stage = state->stages[i];
430 if (stage->op_type == StageKind::kPlaceholder) {
431 continue;
432 } else if (stage->op_type == StageKind::kCompute) {
433 if (stage->compute_at == ComputeAtKind::kRoot) {
434 PrintStage(os, i, state, 0, delete_trivial_loop);
435 }
436 } else {
437 LOG(FATAL) << "Invalid op type";
438 }
439 }
440 }
441
ToStr(bool delete_trivial_loop) const442 String State::ToStr(bool delete_trivial_loop) const {
443 std::ostringstream os;
444 PrintState(&os, (*this), delete_trivial_loop);
445 return os.str();
446 }
447
448 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
__anon42f1274b0102(const ObjectRef& ref, ReprPrinter* p) 449 .set_dispatch<StateNode>([](const ObjectRef& ref, ReprPrinter* p) {
450 PrintState(&p->stream, tvm::Downcast<State>(ref), true);
451 });
452
453 /********** State interface API for ffi **********/
454 TVM_REGISTER_GLOBAL("auto_scheduler.StateBind")
__anon42f1274b0202(State state, int stage_id, const Iterator& it, int thread_type) 455 .set_body_typed([](State state, int stage_id, const Iterator& it, int thread_type) {
456 const auto& res = state.bind(stage_id, it, IteratorAnnotation(thread_type));
457 return Array<ObjectRef>{state, res};
458 });
459
460 TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel")
__anon42f1274b0302(State state, int stage_id, const Iterator& it) 461 .set_body_typed([](State state, int stage_id, const Iterator& it) {
462 const auto& res = state.parallel(stage_id, it);
463 return Array<ObjectRef>{state, res};
464 });
465
466 TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll")
__anon42f1274b0402(State state, int stage_id, const Iterator& it, int max_unroll) 467 .set_body_typed([](State state, int stage_id, const Iterator& it, int max_unroll) {
468 const auto& res = state.unroll(stage_id, it, max_unroll);
469 return Array<ObjectRef>{state, res};
470 });
471
472 TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize")
__anon42f1274b0502(State state, int stage_id, const Iterator& it) 473 .set_body_typed([](State state, int stage_id, const Iterator& it) {
474 const auto& res = state.vectorize(stage_id, it);
475 return Array<ObjectRef>{state, res};
476 });
477
478 TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse")
__anon42f1274b0602(State state, int stage_id, const Array<Iterator>& iters) 479 .set_body_typed([](State state, int stage_id, const Array<Iterator>& iters) {
480 const auto& res = state.fuse(stage_id, iters);
481 return Array<ObjectRef>{state, res};
482 });
483
484 TVM_REGISTER_GLOBAL("auto_scheduler.StatePragma")
__anon42f1274b0702(State state, int stage_id, const Iterator& it, const String& pragma_type) 485 .set_body_typed([](State state, int stage_id, const Iterator& it, const String& pragma_type) {
486 state.pragma(stage_id, it, pragma_type);
487 return state;
488 });
489
490 TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder")
__anon42f1274b0802(State state, int stage_id, const Array<Iterator>& order) 491 .set_body_typed([](State state, int stage_id, const Array<Iterator>& order) {
492 state.reorder(stage_id, order);
493 return state;
494 });
495
496 TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
497 .set_body_typed([](State state, int stage_id, const Iterator& it,
__anon42f1274b0902(State state, int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths, bool inner_to_outer) 498 const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
499 const auto& res = state.split(stage_id, it, lengths, inner_to_outer);
500 return Array<ObjectRef>{state, res};
501 });
502
503 TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
504 .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id,
__anon42f1274b0a02(State state, int stage_id, const Iterator& it, int src_step_id, int n_split) 505 int n_split) {
506 const auto& res = state.follow_split(stage_id, it, src_step_id, n_split);
507 return Array<ObjectRef>{state, Array<Iterator>(res)};
508 });
509
510 TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
511 .set_body_typed([](State state, int stage_id, const Iterator& it,
__anon42f1274b0b02(State state, int stage_id, const Iterator& it, const Array<Integer>& src_step_ids, int level, bool factor_or_nparts) 512 const Array<Integer>& src_step_ids, int level, bool factor_or_nparts) {
513 const auto& res =
514 state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts);
515 return Array<ObjectRef>{state, Array<Iterator>(res)};
516 });
517
518 TVM_REGISTER_GLOBAL("auto_scheduler.StateStorageAlign")
__anon42f1274b0c02(State state, int stage_id, const Iterator& it, int factor, int offset) 519 .set_body_typed([](State state, int stage_id, const Iterator& it, int factor, int offset) {
520 state.storage_align(stage_id, it, factor, offset);
521 return state;
522 });
523
524 TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
525 .set_body_typed([](State state, int stage_id, int target_stage_id,
__anon42f1274b0d02(State state, int stage_id, int target_stage_id, const Iterator& target_iter) 526 const Iterator& target_iter) {
527 state.compute_at(stage_id, target_stage_id, target_iter);
528 return state;
529 });
530
531 TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeInline")
__anon42f1274b0e02(State state, int stage_id) 532 .set_body_typed([](State state, int stage_id) {
533 state.compute_inline(stage_id);
534 return state;
535 });
536
537 TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot")
__anon42f1274b0f02(State state, int stage_id) 538 .set_body_typed([](State state, int stage_id) {
539 state.compute_root(stage_id);
540 return state;
541 });
542
543 TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheRead")
544 .set_body_typed([](State state, int stage_id, const String& scope_name,
__anon42f1274b1002(State state, int stage_id, const String& scope_name, const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) 545 const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
546 int res = state.cache_read(stage_id, scope_name, reader_stage_ids, dag);
547 return Array<ObjectRef>{state, Integer(res)};
548 });
549
550 TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite")
551 .set_body_typed([](State state, int stage_id, const String& scope_name,
__anon42f1274b1102(State state, int stage_id, const String& scope_name, const ComputeDAG& task_dag) 552 const ComputeDAG& task_dag) {
553 int res = state.cache_write(stage_id, scope_name, task_dag);
554 return Array<ObjectRef>{state, Integer(res)};
555 });
556
557 TVM_REGISTER_GLOBAL("auto_scheduler.StateRfactor")
558 .set_body_typed([](State state, int stage_id, const Iterator& it, int factor_iter_id,
__anon42f1274b1202(State state, int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) 559 const ComputeDAG& dag) {
560 int res = state.rfactor(stage_id, it, factor_iter_id, dag);
561 return Array<ObjectRef>{state, Integer(res)};
562 });
563
// Compare two states using std::equal_to<State> (presumably a specialization declared in
// loop_state.h — verify which equality it implements).
TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) {
  std::equal_to<State> states_equal;
  return states_equal(state1, state2);
});
567
568 } // namespace auto_scheduler
569 } // namespace tvm
570