1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file auto_scheduler/search_policy/utils.h
22 * \brief Common utilities for search policies.
23 */
24
25 #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_
26 #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_
27
#include <dmlc/common.h>
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/search_policy.h>
#include <tvm/ir/expr.h>
#include <tvm/te/operation.h>

#include <algorithm>
#include <cctype>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <random>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../utils.h"
45
46 namespace tvm {
47 namespace auto_scheduler {
48
49 /*! \brief Return whether the search task is targeting a CPU. */
IsCPUTask(const SearchTask & task)50 inline bool IsCPUTask(const SearchTask& task) {
51 return (task)->target->kind->device_type == kDLCPU;
52 }
53
54 /*! \brief Return whether the search task is targeting a GPU. */
IsGPUTask(const SearchTask & task)55 inline bool IsGPUTask(const SearchTask& task) {
56 return (task)->target->kind->device_type == kDLGPU ||
57 (task)->target->kind->device_type == kDLOpenCL ||
58 (task)->target->kind->device_type == kDLVulkan ||
59 (task)->target->kind->device_type == kDLMetal ||
60 (task)->target->kind->device_type == kDLROCM ||
61 (task)->target->kind->device_type == kOpenGL;
62 }
63
64 /*! \brief Return whether the search task is targeting a CUDA GPU. */
IsCUDATask(const SearchTask & task)65 inline bool IsCUDATask(const SearchTask& task) {
66 return (task)->target->kind->device_type == kDLGPU;
67 }
68
69 /*! \brief Return whether the search task is targeting a OpenCL GPU. */
IsOpenCLTask(const SearchTask & task)70 inline bool IsOpenCLTask(const SearchTask& task) {
71 return (task)->target->kind->device_type == kDLOpenCL;
72 }
73
/*!
 * \brief Argsort in descending order.
 * \tparam T The element type; it must be comparable with operator>.
 * \param scores The values to rank.
 * \return The indices of `scores`, ordered from the largest score to the smallest one.
 * \note std::sort is used, so the relative order of equal scores is unspecified.
 */
template <typename T>
inline std::vector<int> Argsort(const std::vector<T>& scores) {
  // Start from the identity permutation 0, 1, ..., n - 1.
  std::vector<int> order(scores.size());
  int next_id = 0;
  std::generate(order.begin(), order.end(), [&next_id]() { return next_id++; });

  // Rank the indices by the score they refer to, larger scores first.
  std::sort(order.begin(), order.end(),
            [&scores](int lhs, int rhs) { return scores[lhs] > scores[rhs]; });
  return order;
}
86
87 /*! \brief Convert operation to stage id. */
OperationToStage(const te::Operation & op,const State & state)88 inline int OperationToStage(const te::Operation& op, const State& state) {
89 for (size_t i = 0; i < state->stages.size(); ++i) {
90 if (op == state->stages[i]->op) {
91 return i;
92 }
93 }
94 LOG(FATAL) << "Cannot find op: " << op;
95 return -1;
96 }
97
98 /********** Get Parameters **********/
99
100 /*! \brief Get an integer from a tvm str Map. */
GetIntParam(const Map<String,ObjectRef> & attr_dict,const std::string & key)101 inline int GetIntParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
102 CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
103 auto pint = attr_dict[key].as<IntImmNode>();
104 CHECK(pint != nullptr);
105 return pint->value;
106 }
107
108 /*! \brief Get a double from a tvm str Map. */
GetDoubleParam(const Map<String,ObjectRef> & attr_dict,const std::string & key)109 inline double GetDoubleParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
110 CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
111 auto pdouble = attr_dict[key].as<FloatImmNode>();
112 CHECK(pdouble != nullptr);
113 return pdouble->value;
114 }
115
116 /*! \brief Get a string from a tvm str Map. */
GetStringParam(const Map<String,ObjectRef> & attr_dict,const std::string & key)117 inline std::string GetStringParam(const Map<String, ObjectRef>& attr_dict, const std::string& key) {
118 CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
119 const auto& target = attr_dict[key];
120 if (auto pstr = target.as<StringImmNode>()) {
121 return pstr->value;
122 }
123 auto pstr = target.as<StringObj>();
124 CHECK(pstr != nullptr);
125 return pstr->data;
126 }
127
128 /*! \brief Get a iterator name set from a tvm str Map. */
GetIterNameSetParam(const Map<String,ObjectRef> & attr_dict,const std::string & key)129 inline std::set<std::string> GetIterNameSetParam(const Map<String, ObjectRef>& attr_dict,
130 const std::string& key) {
131 std::set<std::string> ret;
132 CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict;
133 auto names = attr_dict[key].as<ArrayNode>();
134 CHECK(names != nullptr);
135 for (const auto& name : *names) {
136 ret.insert(name.as<StringObj>()->data);
137 }
138 return ret;
139 }
140
141 /********** Checks with ComputeDAG **********/
142
143 /*! \brief Return whether an op is strictly-inlineable. */
IsStrictlyInlineable(const SearchTask & task,const State & state,int stage_id)144 inline bool IsStrictlyInlineable(const SearchTask& task, const State& state, int stage_id) {
145 if (state->current_compute_dag) {
146 return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsStrictlyInlineable(
147 state->stages[stage_id]->op);
148 } else {
149 return task->compute_dag->access_analyzer.IsStrictlyInlineable(state->stages[stage_id]->op);
150 }
151 }
152
153 /*! \brief Return whether an op is an output op. */
IsOutputOp(const SearchTask & task,const State & state,int stage_id)154 inline bool IsOutputOp(const SearchTask& task, const State& state, int stage_id) {
155 if (state->current_compute_dag) {
156 return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.IsOutput(
157 state->stages[stage_id]->op);
158 } else {
159 return task->compute_dag->access_analyzer.IsOutput(state->stages[stage_id]->op);
160 }
161 }
162
163 /*! \brief Return whether an op needs multi level tiling. */
NeedsMultilevelTiling(const SearchTask & task,const State & state,int stage_id)164 inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, int stage_id) {
165 if (state->current_compute_dag) {
166 return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.NeedsMultiLevelTiling(
167 state->stages[stage_id]->op);
168 } else {
169 return task->compute_dag->access_analyzer.NeedsMultiLevelTiling(state->stages[stage_id]->op);
170 }
171 }
172
173 /*! \brief Get all consumers for a stage. This function propagates the relation for inlined ops. */
GetConsumers(const SearchTask & task,const State & state,int stage_id)174 inline std::set<int> GetConsumers(const SearchTask& task, const State& state, int stage_id) {
175 std::unordered_set<te::Operation, ObjectHash, ObjectEqual> consumers;
176 std::set<int> ret;
177
178 if (state->current_compute_dag) {
179 consumers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetConsumers(
180 state, state->stages[stage_id]->op);
181 } else {
182 consumers = task->compute_dag->access_analyzer.GetConsumers(state, state->stages[stage_id]->op);
183 }
184
185 for (const auto& op : consumers) {
186 ret.insert(OperationToStage(op, state));
187 }
188 return ret;
189 }
190
191 /*! \brief Check if a stage has single consumer or all of its consumers share a common root, return
192 * the target consumer root or -1. */
GetSingleConsumerId(const SearchTask & task,const State & state,int stage_id)193 inline int GetSingleConsumerId(const SearchTask& task, const State& state, int stage_id) {
194 const std::set<int>& consumers = GetConsumers(task, state, stage_id);
195 if (consumers.empty()) {
196 return -1;
197 }
198
199 if (consumers.size() == 1) {
200 return *consumers.begin();
201 } else {
202 // Check all consumers share a common root
203 int common_root_id = -1;
204 bool mismatch = false;
205 for (const auto& consumer_stage_id : consumers) {
206 int root_id = -1;
207 if (state->stages[consumer_stage_id]->compute_at == ComputeAtKind::kRoot) {
208 root_id = consumer_stage_id;
209 } else if (state->stages[consumer_stage_id]->compute_at == ComputeAtKind::kIter) {
210 root_id = state->attach_map->stage_to_attach_iter.at(consumer_stage_id).first;
211 } else {
212 LOG(FATAL) << "Invalid case";
213 }
214
215 if (common_root_id == -1) {
216 common_root_id = root_id;
217 } else {
218 if (common_root_id != root_id) {
219 mismatch = true;
220 break;
221 }
222 }
223 }
224
225 return mismatch ? -1 : common_root_id;
226 }
227 }
228
229 /*! \brief Get all producers for a stage. This function propagates the relation for inlined ops. */
GetProducers(const SearchTask & task,const State & state,int stage_id)230 inline std::set<int> GetProducers(const SearchTask& task, const State& state, int stage_id) {
231 std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producers;
232 std::set<int> ret;
233
234 if (state->current_compute_dag) {
235 producers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetProducers(
236 state, state->stages[stage_id]->op);
237 } else {
238 producers = task->compute_dag->access_analyzer.GetProducers(state, state->stages[stage_id]->op);
239 }
240
241 for (const auto& op : producers) {
242 ret.insert(OperationToStage(op, state));
243 }
244 return ret;
245 }
246
247 /*! \brief Get all producers for a stage. This function DOES NOT propagates the relation for
248 * inlined ops. */
GetDirectProducers(const SearchTask & task,const State & state,int stage_id)249 inline std::set<int> GetDirectProducers(const SearchTask& task, const State& state, int stage_id) {
250 std::unordered_set<te::Operation, ObjectHash, ObjectEqual> producers;
251 std::set<int> ret;
252
253 if (state->current_compute_dag) {
254 producers = state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.GetDirectProducers(
255 state->stages[stage_id]->op);
256 } else {
257 producers = task->compute_dag->access_analyzer.GetDirectProducers(state->stages[stage_id]->op);
258 }
259
260 for (const auto& op : producers) {
261 ret.insert(OperationToStage(op, state));
262 }
263 return ret;
264 }
265
266 /*! \brief Get the number of common outer iterators. This function propagates the relation for
267 * chains with multiple ops. */
GetNumCommonOuterIterator(const SearchTask & task,const State & state,int stage_id,int target_stage_id)268 inline int GetNumCommonOuterIterator(const SearchTask& task, const State& state, int stage_id,
269 int target_stage_id) {
270 if (state->current_compute_dag) {
271 return state->current_compute_dag.as<ComputeDAGNode>()
272 ->access_analyzer.GetNumCommonOuterIterator(state->stages[stage_id]->op,
273 state->stages[target_stage_id]->op);
274 } else {
275 return task->compute_dag->access_analyzer.GetNumCommonOuterIterator(
276 state->stages[stage_id]->op, state->stages[target_stage_id]->op);
277 }
278 }
279
280 /*! \brief Return whether two ops are elementwise-matched. */
ElementwiseMatch(const SearchTask & task,const State & state,int stage_id,int target_stage_id)281 inline bool ElementwiseMatch(const SearchTask& task, const State& state, int stage_id,
282 int target_stage_id) {
283 const auto& op = state->stages[stage_id]->op;
284 const auto& target_op = state->stages[target_stage_id]->op;
285 if (state->current_compute_dag) {
286 return state->current_compute_dag.as<ComputeDAGNode>()->access_analyzer.ElementWiseMatch(
287 op, target_op);
288 } else {
289 return task->compute_dag->access_analyzer.ElementWiseMatch(op, target_op);
290 }
291 }
292
/********** Get information from Stage/Iterator **********/
294
295 /*! \brief Return the extent of an iterator. */
GetExtent(const Iterator & it)296 inline int64_t GetExtent(const Iterator& it) {
297 if (it->range.defined()) {
298 if (auto pint = it->range->extent.as<IntImmNode>()) {
299 return pint->value;
300 }
301 }
302 return -1;
303 }
304
305 /*! \brief Compute the product of lengths of all space iters and all reduce iters, respectively. */
GetCumulativeSpaceAndReductionLength(const Stage & stage)306 inline std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const Stage& stage) {
307 int64_t cum_space_len = 1, cum_reduce_len = 1;
308 for (const auto& iter : stage->iters) {
309 if (iter->iter_kind == IteratorKind::kSpatial) {
310 cum_space_len *= GetExtent(iter);
311 } else if (iter->iter_kind == IteratorKind::kReduction) {
312 cum_reduce_len *= GetExtent(iter);
313 }
314 }
315 return std::make_pair(cum_space_len, cum_reduce_len);
316 }
317
318 /*! \brief Return whether this stage needs rfactor. */
NeedsRfactor(const SearchTask & task,const State & state,int stage_id)319 inline bool NeedsRfactor(const SearchTask& task, const State& state, int stage_id) {
320 const auto& op = state->stages[stage_id]->op;
321 if (op->IsInstance<te::ComputeOpNode>()) {
322 // Compute the product of lengths of all space iters and all reduce iters
323 int cum_space_len, cum_reduce_len;
324 std::tie(cum_space_len, cum_reduce_len) =
325 GetCumulativeSpaceAndReductionLength(state->stages[stage_id]);
326
327 if (NeedsMultilevelTiling(task, state, stage_id)) {
328 // Do not use rfactor if we have enough parallelism on space iters
329 if (cum_space_len > cum_reduce_len || cum_space_len > task->hardware_params->num_cores * 16) {
330 return false;
331 } else {
332 return true;
333 }
334 } else if (cum_reduce_len > 1) {
335 // Always try rfactor for reduction ops
336 return cum_reduce_len > task->hardware_params->num_cores;
337 }
338 }
339
340 return false;
341 }
342
343 /*! \brief Return whether the stage has reduce iterators. */
HasReduceIter(const Stage & stage)344 inline bool HasReduceIter(const Stage& stage) {
345 for (const auto& iter : stage->iters) {
346 if (iter->iter_kind != IteratorKind::kSpatial) {
347 return true;
348 }
349 }
350 return false;
351 }
352
353 /*! \brief Return whether the stage has specific annotated iterators. */
HasAnnotatedIter(const Stage & stage,IteratorAnnotation type)354 inline bool HasAnnotatedIter(const Stage& stage, IteratorAnnotation type) {
355 for (const auto& iter : stage->iters) {
356 if (iter->annotation == type) {
357 return true;
358 }
359 }
360 return false;
361 }
362
363 /*! \brief Return whether the stage has only one consumer and they are elementwise-matched. */
364 inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, const State& state,
365 int stage_id, int* target_stage_id = nullptr) {
366 // Temporal object to be used if the input pointer is nullptr
367 int temp_target_stage_id;
368 if (target_stage_id == nullptr) {
369 target_stage_id = &temp_target_stage_id;
370 }
371 const std::set<int>& consumers = GetConsumers(task, state, stage_id);
372 if (consumers.size() == 1) {
373 *target_stage_id = *consumers.begin();
374 if (ElementwiseMatch(task, state, stage_id, *target_stage_id) &&
375 (!(HasReduceIter(state->stages[stage_id]) &&
376 HasReduceIter(state->stages[*target_stage_id]))) &&
377 (!StrEndsWith(state->stages[*target_stage_id]->op->name, ".shared"))) {
378 return true;
379 }
380 }
381 return false;
382 }
383
384 /*! \brief Return whether the step changes the number of stages */
IsStageNumberChangingStep(const Step & step)385 inline bool IsStageNumberChangingStep(const Step& step) {
386 return step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>() ||
387 step->IsInstance<RfactorStepNode>();
388 }
389
390 /*! \brief Return whether the state does cache_read for stage_id. */
HasCacheReadStage(const State & s,int stage_id)391 inline bool HasCacheReadStage(const State& s, int stage_id) {
392 for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
393 if (auto ps = s->transform_steps[i].as<CacheReadStepNode>()) {
394 if (stage_id == ps->stage_id) {
395 return true;
396 }
397 }
398
399 if (IsStageNumberChangingStep(s->transform_steps[i])) {
400 if (stage_id > s->transform_steps[i]->stage_id) {
401 stage_id--;
402 }
403 }
404 }
405 return false;
406 }
407
408 /*! \brief Return whether the state does cache_write for stage_id. */
HasCacheWriteStage(const State & s,int stage_id)409 inline bool HasCacheWriteStage(const State& s, int stage_id) {
410 for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
411 if (auto ps = s->transform_steps[i].as<CacheWriteStepNode>()) {
412 if (stage_id == ps->stage_id) {
413 return true;
414 }
415 }
416
417 if (IsStageNumberChangingStep(s->transform_steps[i])) {
418 if (stage_id > s->transform_steps[i]->stage_id) {
419 stage_id--;
420 }
421 }
422 }
423 return false;
424 }
425
426 /*! \brief Return whether the state does rfactor for stage_id. */
HasRfactorStage(const State & s,int stage_id)427 inline bool HasRfactorStage(const State& s, int stage_id) {
428 for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
429 if (auto ps = s->transform_steps[i].as<RfactorStepNode>()) {
430 if (stage_id == ps->stage_id) {
431 return true;
432 }
433 }
434
435 if (IsStageNumberChangingStep(s->transform_steps[i])) {
436 if (stage_id > s->transform_steps[i]->stage_id) {
437 stage_id--;
438 }
439 }
440 }
441 return false;
442 }
443
444 /*! \brief Return whether the stage does cross thread reduction. */
HasCrossThreadReduction(const State & state,int stage_id)445 inline bool HasCrossThreadReduction(const State& state, int stage_id) {
446 std::function<bool(const Stage&)> check_stage = [](const Stage& in_stage) {
447 for (const auto& iter : in_stage->iters) {
448 if (iter->annotation == IteratorAnnotation::kThreadX &&
449 iter->iter_kind == IteratorKind::kReduction) {
450 return true;
451 }
452 }
453 return false;
454 };
455
456 // Check the stage itself
457 if (check_stage(state->stages[stage_id])) {
458 return true;
459 }
460
461 // Check the attached stages
462 for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); iter_id++) {
463 const auto& res =
464 state->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id));
465 if (res != state->attach_map->iter_to_attached_stages.end()) {
466 for (int attached_stage_id : res->second) {
467 if (check_stage(state->stages[attached_stage_id])) {
468 return true;
469 }
470 }
471 }
472 }
473
474 return false;
475 }
476
477 /*! \brief Return whether the stage has been tiled already. */
IsTiled(const Stage & stage)478 inline bool IsTiled(const Stage& stage) {
479 auto op = stage->op.as<te::ComputeOpNode>();
480 CHECK(op != nullptr);
481 return stage->iters.size() != op->axis.size() + op->reduce_axis.size();
482 }
483
/*!
 * \brief Extract primitive iterators from a nested fused or splitted iterator's name.
 *
 * The name is cut at every '@' (fuse) and '.' (split). Each non-empty piece that does not start
 * with a decimal digit is treated as an original iterator name, e.g. "i.0@j.1" yields "i" and
 * "j".
 *
 * \param name The (possibly nested) iterator name.
 * \param rets The set that receives the extracted names; existing entries are kept.
 */
inline void ExtractOriginalIterators(const std::string& name, std::set<std::string>* rets) {
  // Insert name[begin, end) if it is a non-empty piece that does not start with a digit.
  const auto add_piece = [&name, rets](size_t begin, size_t end) {
    if (begin >= end) {
      return;
    }
    // std::isdigit requires a value representable as unsigned char; a plain (signed) char holding
    // a non-ASCII byte would be undefined behavior.
    if (std::isdigit(static_cast<unsigned char>(name[begin]))) {
      return;
    }
    rets->insert(name.substr(begin, end - begin));
  };

  size_t last_pos = 0;
  for (size_t i = 0; i < name.size(); ++i) {
    if (name[i] == '@' || name[i] == '.') {  // '@' for fuse and '.' for split
      add_piece(last_pos, i);
      last_pos = i + 1;
    }
  }
  // The piece after the last separator (or the whole name if there is none)
  add_piece(last_pos, name.size());
}
501
502 /*! \brief Get the last reduce iterator in the outermost reduce tile. */
GetLastReduceIteratorInOutermostReduceTile(const Stage & stage)503 inline Iterator GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) {
504 auto pop = stage->op.as<te::ComputeOpNode>();
505 CHECK(pop != nullptr);
506 std::set<std::string> original_names;
507
508 const std::set<std::string>& no_split_at_inner_name_set =
509 stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
510 ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
511 : std::set<std::string>();
512 size_t reduce_axis_size = 0;
513 for (const auto axis : pop->reduce_axis) {
514 if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
515 reduce_axis_size++;
516 }
517 }
518 if (reduce_axis_size) {
519 for (const auto& iter : stage->iters) {
520 if (iter->iter_kind == IteratorKind::kReduction) {
521 ExtractOriginalIterators(iter->name, &original_names);
522 if (original_names.size() == reduce_axis_size) {
523 return iter;
524 }
525 }
526 }
527 } else {
528 // Return the first reduce iterator
529 for (const auto& iter : stage->iters) {
530 if (iter->iter_kind == IteratorKind::kReduction) {
531 return iter;
532 }
533 }
534 }
535
536 LOG(FATAL) << "Cannot find the iterator.";
537 return stage->iters[0];
538 }
539
540 /*! \brief Get the target stage id of a history step in the new state.
541 * We need this because the stage_id in the history may be stale due to later steps */
GetTargetStageIDInState(const State & s,int step_id)542 inline int GetTargetStageIDInState(const State& s, int step_id) {
543 int stage_inc = 0;
544
545 for (size_t i = step_id + 1; i < s->transform_steps.size(); ++i) {
546 if (IsStageNumberChangingStep(s->transform_steps[i])) {
547 if (s->transform_steps[i]->stage_id <= s->transform_steps[step_id]->stage_id + stage_inc)
548 stage_inc++;
549 }
550 }
551 return s->transform_steps[step_id]->stage_id + stage_inc;
552 }
553
554 /*! \brief Get all split steps for one stage. */
GetSplitStepIds(const State & s,int stage_id,std::vector<int> * split_step_ids)555 inline void GetSplitStepIds(const State& s, int stage_id, std::vector<int>* split_step_ids) {
556 for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
557 if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
558 if (stage_id == ps->stage_id) {
559 split_step_ids->push_back(i);
560 }
561 }
562
563 if (IsStageNumberChangingStep(s->transform_steps[i])) {
564 if (stage_id > s->transform_steps[i]->stage_id) {
565 stage_id--;
566 }
567 }
568 }
569 }
570
571 /*! \brief Fuse all reduction iterators. */
FuseAllReductionIterators(const State & state,int stage_id,Iterator * fused_iter,Array<Iterator> * space_iters,Array<Iterator> * reduce_iters)572 inline State FuseAllReductionIterators(const State& state, int stage_id, Iterator* fused_iter,
573 Array<Iterator>* space_iters,
574 Array<Iterator>* reduce_iters) {
575 space_iters->clear();
576 reduce_iters->clear();
577
578 for (const auto& iter : state->stages[stage_id]->iters) {
579 if (iter->iter_kind == IteratorKind::kSpatial) {
580 space_iters->push_back(iter);
581 } else if (iter->iter_kind == IteratorKind::kReduction) {
582 reduce_iters->push_back(iter);
583 }
584 }
585
586 CHECK(!reduce_iters->empty());
587 State tmp_s = state;
588 if (reduce_iters->size() > 1) {
589 *fused_iter = tmp_s.fuse(stage_id, *reduce_iters);
590 } else {
591 *fused_iter = (*reduce_iters)[0];
592 }
593 return tmp_s;
594 }
595
596 /*! \brief Fuse all outer level space iterators. */
FuseAllOuterSpaceIterators(const State & state,int stage_id,Iterator * fused_iter)597 inline State FuseAllOuterSpaceIterators(const State& state, int stage_id, Iterator* fused_iter) {
598 std::vector<Iterator> to_fuse;
599 for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); ++iter_id) {
600 const auto& it = state->stages[stage_id]->iters[iter_id];
601 // Stop at reduce iterator or annotated iterator
602 if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) {
603 break;
604 }
605 // Stop at compute_at attach point
606 if (state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id - 1))) {
607 break;
608 }
609 to_fuse.push_back(it);
610 }
611
612 CHECK(!to_fuse.empty());
613 State tmp_s = state;
614 if (to_fuse.size() > 1) {
615 *fused_iter = tmp_s.fuse(stage_id, to_fuse);
616 } else {
617 *fused_iter = to_fuse[0];
618 }
619 return tmp_s;
620 }
621
622 /*! \brief Random sample states. */
RandomSampleStates(const Array<State> & in_states,std::mt19937 * random_gen,size_t out_size)623 inline Array<State> RandomSampleStates(const Array<State>& in_states, std::mt19937* random_gen,
624 size_t out_size) {
625 Array<State> out_states;
626 for (size_t i = 0; i < out_size; i++) {
627 out_states.push_back(in_states[(*random_gen)() % in_states.size()]);
628 }
629 return out_states;
630 }
631
/*!
 * \brief Compute prefix-sum probability based on the given weights.
 *
 * Negative weights are treated as 0. Entry i of the result is the sum of the first i + 1
 * clamped weights divided by the sum of all clamped weights, so the last entry is 1.
 * If the total is not positive (e.g. every weight is <= 0), a uniform distribution is used
 * instead, i.e. entry i is (i + 1) / n.
 *
 * \param weights The selection weights.
 * \param prefix_sum_probs Output. Resized to `weights.size()` and overwritten.
 */
inline void ComputePrefixSumProb(const std::vector<float>& weights,
                                 std::vector<double>* prefix_sum_probs) {
  const size_t num_weights = weights.size();
  prefix_sum_probs->resize(num_weights);

  // Accumulate in double: the results are stored as doubles, so a float accumulator would only
  // lose precision.
  double sum = 0.0;
  for (size_t i = 0; i < num_weights; ++i) {
    sum += std::max(weights[i], 0.0f);
    (*prefix_sum_probs)[i] = sum;
  }

  if (sum > 0.0) {
    for (double& prob : *prefix_sum_probs) {
      prob /= sum;
    }
  } else {
    // Dividing by a zero total would fill the output with NaN; fall back to uniform selection.
    for (size_t i = 0; i < num_weights; ++i) {
      (*prefix_sum_probs)[i] = static_cast<double>(i + 1) / static_cast<double>(num_weights);
    }
  }
}
646
647 /*! \brief Random choose an index according to a prefix sum probability. */
RandomChoose(const std::vector<double> & prefix_sum_probs,std::mt19937 * random_gen)648 inline int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt19937* random_gen) {
649 std::uniform_real_distribution<> dis(0.0, 1.0);
650 double x = dis(*random_gen);
651
652 CHECK(!prefix_sum_probs.empty());
653
654 return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) -
655 prefix_sum_probs.begin();
656 }
657
658 /*! \brief Print a title */
PrintTitle(const std::string & title,int verbose)659 inline void PrintTitle(const std::string& title, int verbose) {
660 StdCout(verbose) << Chars('-', 60) << "\n"
661 << Chars('-', 25) << " [ " << title << " ]\n"
662 << Chars('-', 60) << std::endl;
663 }
664
665 /*!
666 * \brief Enumerate all possible factorization schemes for splitting an axes.
667 * \note This class will memorize the results for reuse.
668 */
669 class SplitFactorizationMemo {
670 public:
671 using QueryKey = std::tuple<int, int, int>;
672
673 const Array<Array<Integer>>& GetFactorizationSchemes(int extent, int n_lengths,
674 int max_innermost_factor);
675 const std::vector<int>& GetFactors(int n);
676
677 private:
678 void DfsEnumerate(int now, int remaining_length, int max_innermost_factor);
679
680 /*!
681 * \brief A simple implementation of read-write lock.
682 * The guarded block can be read by multiple threads at the same time, while other operations will
683 * be blocked if one thread is writing.
684 * \note Writing threads will wait until all reading threads have finshed. If there're multiple
685 * writing threads, the process order of them is not guaranteed.
686 */
687 class ReadWriteLock {
688 public:
689 /*! \brief The method to get the read lock. One thread can process read if there's on other
690 * writing threads. */
691 void GetRead();
692 /*! \brief The method to get the write lock. One thread can process write if there's on other
693 * reading or writing threads. */
694 void GetWrite();
695 /*! \brief The method to release the read lock. */
696 void UnlockRead();
697 /*! \brief The method to release the write lock. */
698 void UnlockWrite();
699
700 private:
701 uint32_t read_count_ = 0;
702 bool is_writing_ = false;
703 std::mutex cv_mutex_;
704 std::condition_variable cv_;
705 } lock_;
706
707 std::unordered_map<QueryKey, Array<Array<Integer>>> memory_;
708
709 int n_lengths_;
710 Array<Integer> tmp_stack_;
711 Array<Array<Integer>>* results_;
712 std::unordered_map<int, std::vector<int>> factor_memory_;
713 };
714
/*!
 * \brief Get the indexes of the SplitSteps that process spatial iterators.
 * \param s The state whose transform steps are inspected.
 * \param stage_id The id of the stage of interest.
 * \return The indexes of the matching split steps (the definition lives in the source file).
 */
Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id);
717
/*!
 * \brief Get the possible compute locations for a stage.
 * \param task The search task.
 * \param state The current state.
 * \param stage_id The id of the stage of interest.
 * \return The candidate locations as pairs of ints. NOTE(review): presumably
 *  (target stage id, iterator id) -- confirm against the definition in the source file.
 */
std::vector<std::pair<int, int>> GetComputeLocationCandidates(const SearchTask& task,
                                                              const State& state, int stage_id);
721
/*!
 * \brief Apply multi-level tiling structure according to a string format,
 *  where "S" stands for a space level and "R" stands for a reduction level.
 *
 * For example, if the format is "SSRSRS", then we will use the tiling structure:
 * space_L0, space_L1, reduce_L0, space_L2, reduce_L1, space_L3.
 * If "SSRSRS" is applied to matrix multiplication, with space iterators i and j and reduce
 * iterator k, the tiling structure is: i0, j0, i1, j1, k0, i2, j2, k1, i3, j3.
 *
 * \param state The input state.
 * \param stage_id The id of the stage to tile.
 * \param format The tiling format string.
 * \param spatial_split_step_ids Optional output, defaults to nullptr. NOTE(review): presumably
 *  receives the ids of the split steps applied to spatial iterators -- confirm against the
 *  definition in the source file.
 * \return The state after tiling.
 */
State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
                         std::vector<int>* spatial_split_step_ids = nullptr);
731
/*!
 * \brief Apply tiling structure: space, space, space, ..., with tile sizes from other SplitSteps.
 * \param state The input state.
 * \param stage_id The id of the stage to tile.
 * \param split_step_ids The ids of the split steps whose tile sizes are followed.
 * \param n_split NOTE(review): presumably the number of split levels -- confirm against the
 *  definition in the source file.
 * \return The state after tiling.
 */
State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids,
                   int n_split);
735
/*!
 * \brief Prune invalid states and return the results in-place.
 * \param task The search task.
 * \param states In/out. The states to filter; the array is modified in place.
 */
void PruneInvalidState(const SearchTask& task, Array<State>* states);
738
739 } // namespace auto_scheduler
740 } // namespace tvm
741
742 #endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_UTILS_H_
743