1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file auto_scheduler/search_policy/utils.h
22  * \brief Common utilities for search policies.
23  */
24 
25 #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_
26 #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_
27 
#include <dmlc/common.h>
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/search_policy.h>
#include <tvm/ir/expr.h>
#include <tvm/te/operation.h>

#include <algorithm>
#include <cctype>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <random>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../utils.h"
45 
46 namespace tvm {
47 namespace auto_scheduler {
48 
49 /*! \brief Return whether the search task is targeting a CPU. */
IsCPUTask(const SearchTask & task)50 inline bool IsCPUTask(const SearchTask& task) {
51   return (task)->target->kind->device_type == kDLCPU;
52 }
53 
54 /*! \brief Return whether the search task is targeting a GPU. */
IsGPUTask(const SearchTask & task)55 inline bool IsGPUTask(const SearchTask& task) {
56   return (task)->target->kind->device_type == kDLGPU ||
57          (task)->target->kind->device_type == kDLOpenCL ||
58          (task)->target->kind->device_type == kDLVulkan ||
59          (task)->target->kind->device_type == kDLMetal ||
60          (task)->target->kind->device_type == kDLROCM ||
61          (task)->target->kind->device_type == kOpenGL;
62 }
63 
64 /*! \brief Return whether the search task is targeting a CUDA GPU. */
IsCUDATask(const SearchTask & task)65 inline bool IsCUDATask(const SearchTask& task) {
66   return (task)->target->kind->device_type == kDLGPU;
67 }
68 
69 /*! \brief Return whether the search task is targeting a OpenCL GPU. */
IsOpenCLTask(const SearchTask & task)70 inline bool IsOpenCLTask(const SearchTask& task) {
71   return (task)->target->kind->device_type == kDLOpenCL;
72 }
73 
/*!
 * \brief Argsort in descending order.
 * \tparam T Element type; must support `operator>`.
 * \param scores The values to rank.
 * \return Indices into `scores`, ordered from the largest score to the smallest.
 *         The relative order of equal scores is unspecified (std::sort is not stable).
 */
template <typename T>
inline std::vector<int> Argsort(const std::vector<T>& scores) {
  std::vector<int> order(scores.size());
  for (size_t pos = 0; pos < order.size(); ++pos) {
    order[pos] = static_cast<int>(pos);
  }
  std::sort(order.begin(), order.end(),
            [&scores](int lhs, int rhs) { return scores[lhs] > scores[rhs]; });
  return order;
}
86 
87 /*! \brief Convert operation to stage id. */
OperationToStage(const te::Operation & op,const State & state)88 inline int OperationToStage(const te::Operation& op, const State& state) {
89   for (size_t i = 0; i < state->stages.size(); ++i) {
90     if (op == state->stages[i]->op) {
91       return i;
92     }
93   }
94   LOG(FATAL) << "Cannot find op: " << op;
95   return -1;
96 }
97 
98 /********** Get Parameters **********/
99 
100 /*! \brief Get an integer from a tvm str Map. */
GetIntParam(const Map<String,ObjectRef> & attr_dict,const std::string & key)101 inline int GetIntParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
102   CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
103   auto pint = attr_dict[key].as<IntImmNode>();
104   CHECK(pint != nullptr);
105   return pint->value;
106 }
107 
108 /*! \brief Get a double from a tvm str Map. */
GetDoubleParam(const Map<String,ObjectRef> & attr_dict,const std::string & key)109 inline double GetDoubleParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
110   CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
111   auto pdouble = attr_dict[key].as<FloatImmNode>();
112   CHECK(pdouble != nullptr);
113   return pdouble->value;
114 }
115 
116 /*! \brief Get a string from a tvm str Map. */
GetStringParam(const Map<String,ObjectRef> & attr_dict,const std::string & key)117 inline std::string GetStringParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
118   CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
119   const auto& target = attr_dict[key];
120   if (auto pstr = target.as<StringImmNode>()) {
121     return pstr->value;
122   }
123   auto pstr = target.as<StringObj>();
124   CHECK(pstr != nullptr);
125   return pstr->data;
126 }
127 
128 /*! \brief Get a iterator name set from a tvm str Map. */
GetIterNameSetParam(const Map<String,ObjectRef> & attr_dict,const std::string & key)129 inline std::set<std::string> GetIterNameSetParam(const Map<String, ObjectRef>& attr_dict,
130                                                  const std::string& key) {
131   std::set<std::string> ret;
132   CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
133   auto names = attr_dict[key].as<ArrayNode>();
134   CHECK(names != nullptr);
135   for (const auto& name : *names) {
136     ret.insert(name.as<StringObj>()->data);
137   }
138   return ret;
139 }
140 
141 /********** Checks with ComputeDAG **********/
142 
143 /*! \brief Return whether an op is strictly-inlineable. */
IsStrictlyInlineable(const SearchTask & task,const State & state,int stage_id)144 inline bool IsStrictlyInlineable(const SearchTask& task, const State& state, int stage_id) {
145   if (state->current_compute_dag) {
146     return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsStrictlyInlineable(
147         state->stages[stage_id]->op);
148   } else {
149     return task->compute_dag->access_analyzer.IsStrictlyInlineable(state->stages[stage_id]->op);
150   }
151 }
152 
153 /*! \brief Return whether an op is an output op. */
IsOutputOp(const SearchTask & task,const State & state,int stage_id)154 inline bool IsOutputOp(const SearchTask& task, const State& state, int stage_id) {
155   if (state->current_compute_dag) {
156     return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsOutput(
157         state->stages[stage_id]->op);
158   } else {
159     return task->compute_dag->access_analyzer.IsOutput(state->stages[stage_id]->op);
160   }
161 }
162 
163 /*! \brief Return whether an op needs multi level tiling. */
NeedsMultilevelTiling(const SearchTask & task,const State & state,int stage_id)164 inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, int stage_id) {
165   if (state->current_compute_dag) {
166     return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.NeedsMultiLevelTiling(
167         state->stages[stage_id]->op);
168   } else {
169     return task->compute_dag->access_analyzer.NeedsMultiLevelTiling(state->stages[stage_id]->op);
170   }
171 }
172 
173 /*! \brief Get all consumers for a stage. This function propagates the relation for inlined ops. */
GetConsumers(const SearchTask & task,const State & state,int stage_id)174 inline std::set<int> GetConsumers(const SearchTask& task, const State& state, int stage_id) {
175   std::unordered_set<te::Operation, ObjectHash, ObjectEqual> consumers;
176   std::set<int> ret;
177 
178   if (state->current_compute_dag) {
179     consumers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetConsumers(
180         state, state->stages[stage_id]->op);
181   } else {
182     consumers = task->compute_dag->access_analyzer.GetConsumers(state, state->stages[stage_id]->op);
183   }
184 
185   for (const auto& op : consumers) {
186     ret.insert(OperationToStage(op, state));
187   }
188   return ret;
189 }
190 
191 /*! \brief Check if a stage has single consumer or all of its consumers share a common root, return
192  * the target consumer root or -1. */
GetSingleConsumerId(const SearchTask & task,const State & state,int stage_id)193 inline int GetSingleConsumerId(const SearchTask& task, const State& state, int stage_id) {
194   const std::set<int>& consumers = GetConsumers(task, state, stage_id);
195   if (consumers.empty()) {
196     return -1;
197   }
198 
199   if (consumers.size() == 1) {
200     return *consumers.begin();
201   } else {
202     // Check all consumers share a common root
203     int common_root_id = -1;
204     bool mismatch = false;
205     for (const auto& consumer_stage_id : consumers) {
206       int root_id = -1;
207       if (state->stages[consumer_stage_id]->compute_at == ComputeAtKind::kRoot) {
208         root_id = consumer_stage_id;
209       } else if (state->stages[consumer_stage_id]->compute_at == ComputeAtKind::kIter) {
210         root_id = state->attach_map->stage_to_attach_iter.at(consumer_stage_id).first;
211       } else {
212         LOG(FATAL) << "Invalid case";
213       }
214 
215       if (common_root_id == -1) {
216         common_root_id = root_id;
217       } else {
218         if (common_root_id != root_id) {
219           mismatch = true;
220           break;
221         }
222       }
223     }
224 
225     return mismatch ? -1 : common_root_id;
226   }
227 }
228 
229 /*! \brief Get all producers for a stage. This function propagates the relation for inlined ops. */
GetProducers(const SearchTask & task,const State & state,int stage_id)230 inline std::set<int> GetProducers(const SearchTask& task, const State& state, int stage_id) {
231   std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producers;
232   std::set<int> ret;
233 
234   if (state->current_compute_dag) {
235     producers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetProducers(
236         state, state->stages[stage_id]->op);
237   } else {
238     producers = task->compute_dag->access_analyzer.GetProducers(state, state->stages[stage_id]->op);
239   }
240 
241   for (const auto& op : producers) {
242     ret.insert(OperationToStage(op, state));
243   }
244   return ret;
245 }
246 
247 /*! \brief Get all producers for a stage. This function DOES NOT propagates the relation for
248  * inlined ops. */
GetDirectProducers(const SearchTask & task,const State & state,int stage_id)249 inline std::set<int> GetDirectProducers(const SearchTask& task, const State& state, int stage_id) {
250   std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producers;
251   std::set<int> ret;
252 
253   if (state->current_compute_dag) {
254     producers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetDirectProducers(
255         state->stages[stage_id]->op);
256   } else {
257     producers = task->compute_dag->access_analyzer.GetDirectProducers(state->stages[stage_id]->op);
258   }
259 
260   for (const auto& op : producers) {
261     ret.insert(OperationToStage(op, state));
262   }
263   return ret;
264 }
265 
266 /*! \brief Get the number of common outer iterators. This function propagates the relation for
267  * chains with multiple ops. */
GetNumCommonOuterIterator(const SearchTask & task,const State & state,int stage_id,int target_stage_id)268 inline int GetNumCommonOuterIterator(const SearchTask& task, const State& state, int stage_id,
269                                      int target_stage_id) {
270   if (state->current_compute_dag) {
271     return state->current_compute_dag.as<ComputeDAGNode>()
272         ->access_analyzer.GetNumCommonOuterIterator(state->stages[stage_id]->op,
273                                                     state->stages[target_stage_id]->op);
274   } else {
275     return task->compute_dag->access_analyzer.GetNumCommonOuterIterator(
276         state->stages[stage_id]->op, state->stages[target_stage_id]->op);
277   }
278 }
279 
280 /*! \brief Return whether two ops are elementwise-matched. */
ElementwiseMatch(const SearchTask & task,const State & state,int stage_id,int target_stage_id)281 inline bool ElementwiseMatch(const SearchTask& task, const State& state, int stage_id,
282                              int target_stage_id) {
283   const auto& op = state->stages[stage_id]->op;
284   const auto& target_op = state->stages[target_stage_id]->op;
285   if (state->current_compute_dag) {
286     return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.ElementWiseMatch(
287         op, target_op);
288   } else {
289     return task->compute_dag->access_analyzer.ElementWiseMatch(op, target_op);
290   }
291 }
292 
/********** Get information from Stage/Iterator **********/
294 
295 /*! \brief Return the extent of an iterator. */
GetExtent(const Iterator & it)296 inline int64_t GetExtent(const Iterator& it) {
297   if (it->range.defined()) {
298     if (auto pint = it->range->extent.as<IntImmNode>()) {
299       return pint->value;
300     }
301   }
302   return -1;
303 }
304 
305 /*! \brief Compute the product of lengths of all space iters and all reduce iters, respectively. */
GetCumulativeSpaceAndReductionLength(const Stage & stage)306 inline std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const Stage& stage) {
307   int64_t cum_space_len = 1, cum_reduce_len = 1;
308   for (const auto& iter : stage->iters) {
309     if (iter->iter_kind == IteratorKind::kSpatial) {
310       cum_space_len *= GetExtent(iter);
311     } else if (iter->iter_kind == IteratorKind::kReduction) {
312       cum_reduce_len *= GetExtent(iter);
313     }
314   }
315   return std::make_pair(cum_space_len, cum_reduce_len);
316 }
317 
318 /*! \brief Return whether this stage needs rfactor. */
NeedsRfactor(const SearchTask & task,const State & state,int stage_id)319 inline bool NeedsRfactor(const SearchTask& task, const State& state, int stage_id) {
320   const auto& op = state->stages[stage_id]->op;
321   if (op->IsInstance<te::ComputeOpNode>()) {
322     // Compute the product of lengths of all space iters and all reduce iters
323     int cum_space_len, cum_reduce_len;
324     std::tie(cum_space_len, cum_reduce_len) =
325         GetCumulativeSpaceAndReductionLength(state->stages[stage_id]);
326 
327     if (NeedsMultilevelTiling(task, state, stage_id)) {
328       // Do not use rfactor if we have enough parallelism on space iters
329       if (cum_space_len > cum_reduce_len || cum_space_len > task->hardware_params->num_cores * 16) {
330         return false;
331       } else {
332         return true;
333       }
334     } else if (cum_reduce_len > 1) {
335       // Always try rfactor for reduction ops
336       return cum_reduce_len > task->hardware_params->num_cores;
337     }
338   }
339 
340   return false;
341 }
342 
343 /*! \brief Return whether the stage has reduce iterators. */
HasReduceIter(const Stage & stage)344 inline bool HasReduceIter(const Stage& stage) {
345   for (const auto& iter : stage->iters) {
346     if (iter->iter_kind != IteratorKind::kSpatial) {
347       return true;
348     }
349   }
350   return false;
351 }
352 
353 /*! \brief Return whether the stage has specific annotated iterators. */
HasAnnotatedIter(const Stage & stage,IteratorAnnotation type)354 inline bool HasAnnotatedIter(const Stage& stage, IteratorAnnotation type) {
355   for (const auto& iter : stage->iters) {
356     if (iter->annotation == type) {
357       return true;
358     }
359   }
360   return false;
361 }
362 
363 /*! \brief Return whether the stage has only one consumer and they are elementwise-matched. */
364 inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, const State& state,
365                                                 int stage_id, int* target_stage_id = nullptr) {
366   // Temporal object to be used if the input pointer is nullptr
367   int temp_target_stage_id;
368   if (target_stage_id == nullptr) {
369     target_stage_id = &temp_target_stage_id;
370   }
371   const std::set<int>& consumers = GetConsumers(task, state, stage_id);
372   if (consumers.size() == 1) {
373     *target_stage_id = *consumers.begin();
374     if (ElementwiseMatch(task, state, stage_id, *target_stage_id) &&
375         (!(HasReduceIter(state->stages[stage_id]) &&
376            HasReduceIter(state->stages[*target_stage_id]))) &&
377         (!StrEndsWith(state->stages[*target_stage_id]->op->name, ".shared"))) {
378       return true;
379     }
380   }
381   return false;
382 }
383 
384 /*! \brief Return whether the step changes the number of stages */
IsStageNumberChangingStep(const Step & step)385 inline bool IsStageNumberChangingStep(const Step& step) {
386   return step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>() ||
387          step->IsInstance<RfactorStepNode>();
388 }
389 
390 /*! \brief Return whether the state does cache_read for stage_id. */
HasCacheReadStage(const State & s,int stage_id)391 inline bool HasCacheReadStage(const State& s, int stage_id) {
392   for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
393     if (auto ps = s->transform_steps[i].as<CacheReadStepNode>()) {
394       if (stage_id == ps->stage_id) {
395         return true;
396       }
397     }
398 
399     if (IsStageNumberChangingStep(s->transform_steps[i])) {
400       if (stage_id > s->transform_steps[i]->stage_id) {
401         stage_id--;
402       }
403     }
404   }
405   return false;
406 }
407 
408 /*! \brief Return whether the state does cache_write for stage_id. */
HasCacheWriteStage(const State & s,int stage_id)409 inline bool HasCacheWriteStage(const State& s, int stage_id) {
410   for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
411     if (auto ps = s->transform_steps[i].as<CacheWriteStepNode>()) {
412       if (stage_id == ps->stage_id) {
413         return true;
414       }
415     }
416 
417     if (IsStageNumberChangingStep(s->transform_steps[i])) {
418       if (stage_id > s->transform_steps[i]->stage_id) {
419         stage_id--;
420       }
421     }
422   }
423   return false;
424 }
425 
426 /*! \brief Return whether the state does rfactor for stage_id. */
HasRfactorStage(const State & s,int stage_id)427 inline bool HasRfactorStage(const State& s, int stage_id) {
428   for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
429     if (auto ps = s->transform_steps[i].as<RfactorStepNode>()) {
430       if (stage_id == ps->stage_id) {
431         return true;
432       }
433     }
434 
435     if (IsStageNumberChangingStep(s->transform_steps[i])) {
436       if (stage_id > s->transform_steps[i]->stage_id) {
437         stage_id--;
438       }
439     }
440   }
441   return false;
442 }
443 
444 /*! \brief Return whether the stage does cross thread reduction. */
HasCrossThreadReduction(const State & state,int stage_id)445 inline bool HasCrossThreadReduction(const State& state, int stage_id) {
446   std::function<bool(const Stage&)> check_stage = [](const Stage& in_stage) {
447     for (const auto& iter : in_stage->iters) {
448       if (iter->annotation == IteratorAnnotation::kThreadX &&
449           iter->iter_kind == IteratorKind::kReduction) {
450         return true;
451       }
452     }
453     return false;
454   };
455 
456   // Check the stage itself
457   if (check_stage(state->stages[stage_id])) {
458     return true;
459   }
460 
461   // Check the attached stages
462   for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); iter_id++) {
463     const auto& res =
464         state->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id));
465     if (res != state->attach_map->iter_to_attached_stages.end()) {
466       for (int attached_stage_id : res->second) {
467         if (check_stage(state->stages[attached_stage_id])) {
468           return true;
469         }
470       }
471     }
472   }
473 
474   return false;
475 }
476 
477 /*! \brief Return whether the stage has been tiled already. */
IsTiled(const Stage & stage)478 inline bool IsTiled(const Stage& stage) {
479   auto op = stage->op.as<te::ComputeOpNode>();
480   CHECK(op != nullptr);
481   return stage->iters.size() != op->axis.size() + op->reduce_axis.size();
482 }
483 
/*!
 * \brief Extract primitive iterator names from a nested fused or split iterator's name.
 *
 * The name is cut at every '@' (fuse) and '.' (split). Each non-empty segment whose first
 * character is not a digit (split indexes are digits) is inserted into `rets`.
 * \param name The (possibly fused / split) iterator name.
 * \param rets Output set; found names are added, existing contents are kept.
 */
inline void ExtractOriginalIterators(const std::string& name, std::set<std::string>* rets) {
  size_t seg_begin = 0;
  // Treat the end of the string as a final delimiter so the last segment is handled uniformly.
  for (size_t i = 0; i <= name.size(); ++i) {
    const bool at_boundary = i == name.size() || name[i] == '@' || name[i] == '.';
    if (!at_boundary) {
      continue;
    }
    // Cast to unsigned char: std::isdigit on a negative char (non-ASCII byte) is undefined.
    if (i > seg_begin && !std::isdigit(static_cast<unsigned char>(name[seg_begin]))) {
      rets->insert(name.substr(seg_begin, i - seg_begin));
    }
    seg_begin = i + 1;
  }
}
501 
502 /*! \brief Get the last reduce iterator in the outermost reduce tile. */
GetLastReduceIteratorInOutermostReduceTile(const Stage & stage)503 inline Iterator GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) {
504   auto pop = stage->op.as<te::ComputeOpNode>();
505   CHECK(pop != nullptr);
506   std::set<std::string> original_names;
507 
508   const std::set<std::string>& no_split_at_inner_name_set =
509       stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
510           ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
511           : std::set<std::string>();
512   size_t reduce_axis_size = 0;
513   for (const auto axis : pop->reduce_axis) {
514     if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
515       reduce_axis_size++;
516     }
517   }
518   if (reduce_axis_size) {
519     for (const auto& iter : stage->iters) {
520       if (iter->iter_kind == IteratorKind::kReduction) {
521         ExtractOriginalIterators(iter->name, &original_names);
522         if (original_names.size() == reduce_axis_size) {
523           return iter;
524         }
525       }
526     }
527   } else {
528     // Return the first reduce iterator
529     for (const auto& iter : stage->iters) {
530       if (iter->iter_kind == IteratorKind::kReduction) {
531         return iter;
532       }
533     }
534   }
535 
536   LOG(FATAL) << "Cannot find the iterator.";
537   return stage->iters[0];
538 }
539 
540 /*! \brief Get the target stage id of a history step in the new state.
541  * We need this because the stage_id in the history may be stale due to later steps */
GetTargetStageIDInState(const State & s,int step_id)542 inline int GetTargetStageIDInState(const State& s, int step_id) {
543   int stage_inc = 0;
544 
545   for (size_t i = step_id + 1; i < s->transform_steps.size(); ++i) {
546     if (IsStageNumberChangingStep(s->transform_steps[i])) {
547       if (s->transform_steps[i]->stage_id <= s->transform_steps[step_id]->stage_id + stage_inc)
548         stage_inc++;
549     }
550   }
551   return s->transform_steps[step_id]->stage_id + stage_inc;
552 }
553 
554 /*! \brief Get all split steps for one stage. */
GetSplitStepIds(const State & s,int stage_id,std::vector<int> * split_step_ids)555 inline void GetSplitStepIds(const State& s, int stage_id, std::vector<int>* split_step_ids) {
556   for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
557     if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
558       if (stage_id == ps->stage_id) {
559         split_step_ids->push_back(i);
560       }
561     }
562 
563     if (IsStageNumberChangingStep(s->transform_steps[i])) {
564       if (stage_id > s->transform_steps[i]->stage_id) {
565         stage_id--;
566       }
567     }
568   }
569 }
570 
571 /*! \brief Fuse all reduction iterators. */
FuseAllReductionIterators(const State & state,int stage_id,Iterator * fused_iter,Array<Iterator> * space_iters,Array<Iterator> * reduce_iters)572 inline State FuseAllReductionIterators(const State& state, int stage_id, Iterator* fused_iter,
573                                        Array<Iterator>* space_iters,
574                                        Array<Iterator>* reduce_iters) {
575   space_iters->clear();
576   reduce_iters->clear();
577 
578   for (const auto& iter : state->stages[stage_id]->iters) {
579     if (iter->iter_kind == IteratorKind::kSpatial) {
580       space_iters->push_back(iter);
581     } else if (iter->iter_kind == IteratorKind::kReduction) {
582       reduce_iters->push_back(iter);
583     }
584   }
585 
586   CHECK(!reduce_iters->empty());
587   State tmp_s = state;
588   if (reduce_iters->size() > 1) {
589     *fused_iter = tmp_s.fuse(stage_id, *reduce_iters);
590   } else {
591     *fused_iter = (*reduce_iters)[0];
592   }
593   return tmp_s;
594 }
595 
596 /*! \brief Fuse all outer level space iterators. */
FuseAllOuterSpaceIterators(const State & state,int stage_id,Iterator * fused_iter)597 inline State FuseAllOuterSpaceIterators(const State& state, int stage_id, Iterator* fused_iter) {
598   std::vector<Iterator> to_fuse;
599   for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); ++iter_id) {
600     const auto& it = state->stages[stage_id]->iters[iter_id];
601     // Stop at reduce iterator or annotated iterator
602     if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) {
603       break;
604     }
605     // Stop at compute_at attach point
606     if (state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id - 1))) {
607       break;
608     }
609     to_fuse.push_back(it);
610   }
611 
612   CHECK(!to_fuse.empty());
613   State tmp_s = state;
614   if (to_fuse.size() > 1) {
615     *fused_iter = tmp_s.fuse(stage_id, to_fuse);
616   } else {
617     *fused_iter = to_fuse[0];
618   }
619   return tmp_s;
620 }
621 
622 /*! \brief Random sample states. */
RandomSampleStates(const Array<State> & in_states,std::mt19937 * random_gen,size_t out_size)623 inline Array<State> RandomSampleStates(const Array<State>& in_states, std::mt19937* random_gen,
624                                        size_t out_size) {
625   Array<State> out_states;
626   for (size_t i = 0; i < out_size; i++) {
627     out_states.push_back(in_states[(*random_gen)() % in_states.size()]);
628   }
629   return out_states;
630 }
631 
/*!
 * \brief Compute prefix-sum probabilities based on the given weights.
 *
 * Negative weights are treated as 0. Entry i of the output is the normalized cumulative weight
 * of elements [0, i]; the last entry is 1.0 for non-empty input.
 * \param weights The (unnormalized) selection weights.
 * \param prefix_sum_probs Output; resized to weights.size().
 * \note If no weight is positive, a uniform distribution is produced instead of dividing by 0.
 */
inline void ComputePrefixSumProb(const std::vector<float>& weights,
                                 std::vector<double>* prefix_sum_probs) {
  // Accumulate in double: a float running sum silently drops small weights once it is large
  // (e.g. 1e8f + 1.0f == 1e8f), giving those entries zero probability.
  double sum = 0.0;
  prefix_sum_probs->resize(weights.size());
  for (size_t i = 0; i < weights.size(); ++i) {
    sum += std::max(weights[i], 0.0f);
    (*prefix_sum_probs)[i] = sum;
  }

  if (sum > 0.0) {
    for (double& prob : *prefix_sum_probs) {
      prob /= sum;
    }
  } else {
    // Normalizing by zero would fill the table with NaN; fall back to uniform selection.
    const double n = static_cast<double>(weights.size());
    for (size_t i = 0; i < weights.size(); ++i) {
      (*prefix_sum_probs)[i] = static_cast<double>(i + 1) / n;
    }
  }
}
646 
647 /*! \brief Random choose an index according to a prefix sum probability. */
RandomChoose(const std::vector<double> & prefix_sum_probs,std::mt19937 * random_gen)648 inline int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt19937* random_gen) {
649   std::uniform_real_distribution<> dis(0.0, 1.0);
650   double x = dis(*random_gen);
651 
652   CHECK(!prefix_sum_probs.empty());
653 
654   return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) -
655          prefix_sum_probs.begin();
656 }
657 
658 /*! \brief Print a title */
PrintTitle(const std::string & title,int verbose)659 inline void PrintTitle(const std::string& title, int verbose) {
660   StdCout(verbose) << Chars('-', 60) << "\n"
661                    << Chars('-', 25) << "  [ " << title << " ]\n"
662                    << Chars('-', 60) << std::endl;
663 }
664 
665 /*!
666  * \brief Enumerate all possible factorization schemes for splitting an axes.
667  * \note This class will memorize the results for reuse.
668  */
669 class SplitFactorizationMemo {
670  public:
671   using QueryKey = std::tuple<int, int, int>;
672 
673   const Array<Array<Integer>>& GetFactorizationSchemes(int extent, int n_lengths,
674                                                        int max_innermost_factor);
675   const std::vector<int>& GetFactors(int n);
676 
677  private:
678   void DfsEnumerate(int now, int remaining_length, int max_innermost_factor);
679 
680   /*!
681    * \brief A simple implementation of read-write lock.
682    * The guarded block can be read by multiple threads at the same time, while other operations will
683    * be blocked if one thread is writing.
684    * \note Writing threads will wait until all reading threads have finshed. If there're multiple
685    * writing threads, the process order of them is not guaranteed.
686    */
687   class ReadWriteLock {
688    public:
689     /*! \brief The method to get the read lock. One thread can process read if there's on other
690      * writing threads. */
691     void GetRead();
692     /*! \brief The method to get the write lock. One thread can process write if there's on other
693      * reading or writing threads. */
694     void GetWrite();
695     /*! \brief The method to release the read lock. */
696     void UnlockRead();
697     /*! \brief The method to release the write lock. */
698     void UnlockWrite();
699 
700    private:
701     uint32_t read_count_ = 0;
702     bool is_writing_ = false;
703     std::mutex cv_mutex_;
704     std::condition_variable cv_;
705   } lock_;
706 
707   std::unordered_map<QueryKey, Array<Array<Integer>>> memory_;
708 
709   int n_lengths_;
710   Array<Integer> tmp_stack_;
711   Array<Array<Integer>>* results_;
712   std::unordered_map<int, std::vector<int>> factor_memory_;
713 };
714 
/*!
 * \brief Get the indexes of the SplitSteps that process spatial iterators of a stage.
 * \param s The state whose transform history is inspected.
 * \param stage_id The id of the target stage.
 * \return The matching step indexes (presumably indexes into `s->transform_steps`;
 *         TODO confirm against the definition).
 */
Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id);

/*!
 * \brief Get the possible compute locations for a stage.
 * \return A list of int pairs; NOTE(review): presumably (target stage id, iterator id) —
 *         confirm against the definition.
 */
std::vector<std::pair<int, int>> GetComputeLocationCandidates(const SearchTask& task,
                                                              const State& state, int stage_id);

// Apply a multi-level tiling structure according to a string format,
// where "S" stands for a space level and "R" stands for a reduction level.
// For example, the format "SSRSRS" uses the tiling structure:
//   space_L0, space_L1, reduce_L0, space_L2, reduce_L1, space_L3
// Applying "SSRSRS" to matrix multiplication, which has space iterators i and j and
// reduce iterator k, gives the tiling structure:
//   i0, j0, i1, j1, k0, i2, j2, k1, i3, j3
// `spatial_split_step_ids` is an optional output (may be nullptr); presumably it receives the
// ids of the SplitSteps created for the space iterators — TODO confirm against the definition.
State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
                         std::vector<int>* spatial_split_step_ids = nullptr);

// Apply the tiling structure space, space, space, ... to a stage, taking the tile sizes
// from the other SplitSteps listed in `split_step_ids`.
State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids,
                   int n_split);

// Prune invalid states; the result is written back in-place into `states`.
void PruneInvalidState(const SearchTask& task, Array<State>* states);
738 
739 }  // namespace auto_scheduler
740 }  // namespace tvm
741 
742 #endif  // TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_
743