1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file auto_scheduler/search_policy/utils.cc
22 * \brief Common utilities
23 */
24
25 #include "utils.h"
26
27 #include <algorithm>
28
29 namespace tvm {
30 namespace auto_scheduler {
31
/*!
 * \brief Return the indices (into `s->transform_steps`) of the SplitSteps that
 *        were applied to the spatial axes of a stage.
 * \param s The state whose transform history is inspected.
 * \param stage_id The id of the stage in the *current* state.
 * \return The indices of the spatial SplitSteps in the transform history.
 */
Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id) {
  const auto& stage = s->stages[stage_id];
  const auto& pop = s->stages[stage_id]->op.as<te::ComputeOpNode>();
  CHECK(pop != nullptr);
  // Axes named in the "no_split_at_inner" attribute are never split, so they
  // must be excluded from the reduction-split count below.
  const std::set<std::string>& no_split_at_inner_name_set =
      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
          : std::set<std::string>();
  // Count the reduction axes that may have been split.
  size_t reduce_count = 0;
  for (const auto axis : pop->reduce_axis) {
    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
      reduce_count++;
    }
  }

  Array<Integer> spatial_split_step_ids;
  // Walk the transform history backwards. `stage_id` refers to the current
  // state; steps that add/remove stages shift the ids of later stages, so the
  // id is rewound as those steps are passed.
  for (int i = s->transform_steps.size() - 1; i >= 0; --i) {
    if (IsStageNumberChangingStep(s->transform_steps[i])) {
      if (stage_id > s->transform_steps[i]->stage_id) {
        stage_id--;
      }
    } else if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
      if (stage_id == ps->stage_id) {
        // Assume SplitStep on reduction axes are always after SplitStep on spatial axes.
        // Scanning backwards, the first `reduce_count` SplitSteps for this stage
        // therefore belong to reduction axes and are skipped.
        if (reduce_count) {
          reduce_count--;
        } else {
          spatial_split_step_ids.push_back(i);
        }
      }
    }
  }

  return spatial_split_step_ids;
}
67
/*!
 * \brief Enumerate the candidate compute_at locations for a stage.
 * \param task The search task (used to resolve the stage's single consumer).
 * \param state The state to inspect.
 * \param stage_id The id of the stage that will be computed at another stage.
 * \return A list of (target_stage_id, iterator_index) attach points; empty if
 *         the stage does not have exactly one consumer.
 */
std::vector<std::pair<int, int>> GetComputeLocationCandidates(const SearchTask& task,
                                                              const State& state, int stage_id) {
  // Only a stage with a single consumer can be compute_at-ed into it.
  int target_stage_id = GetSingleConsumerId(task, state, stage_id);
  if (target_stage_id < 0) {
    return {};
  }
  const Stage& target_stage = state->stages[target_stage_id];

  std::vector<std::pair<int, int>> candidates;
  bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter;
  bool target_is_tiled = IsTiled(target_stage);

  bool visited_reduce = false;
  // Enumerate compute_at location at target_stage
  // TODO(merrymercy): More analysis here to make smarter choices
  for (size_t i = 0; i < target_stage->iters.size(); ++i) {
    const Iterator& target_iter = target_stage->iters[i];
    if (target_iter->iter_kind == IteratorKind::kReduction) {
      visited_reduce = true;
      if (!target_is_tiled) {  // Do not go into reduce iter
        break;
      }
    } else if (target_iter->iter_kind == IteratorKind::kSpatial) {
      if (visited_reduce) {  // Do not go into inner tile
        break;
      }
    }

    if (target_iter->annotation == IteratorAnnotation::kUnroll) {
      // Do not go into the unroll region of const tensor indices
      break;
    }

    if (GetExtent(target_iter) == 1) {
      // Skip iterators with length of 1
      continue;
    }
    if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial &&
        StrEndsWith(target_iter->name, ".0")) {
      // Skip the first level iterators if target stage compute_at another stage
      // In this case, the lengths of first level iterators are always one
      continue;
    }
    candidates.emplace_back(target_stage_id, i);

    // Stop once we reach an iterator that already has other stages attached.
    if (state->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) {
      break;
    }
  }

  // if the target_stage is already compute_at another stage X, try also compute_at X
  // We call stage X as `target_target_stage`
  if (target_compute_at_other) {
    int target_target_stage_id;
    target_target_stage_id = state->attach_map->stage_to_attach_iter.at(target_stage_id).first;
    const Stage& target_target_stage = state->stages[target_target_stage_id];

    for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
      const Iterator& target_target_iter = target_target_stage->iters[i];
      // Stop at reduction iterators and at iterators with attached stages.
      if (target_target_iter->iter_kind == IteratorKind::kReduction ||
          state->attach_map->iter_to_attached_stages.count(
              std::make_pair(target_target_stage_id, i))) {
        break;
      }

      if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
        // Do not go into the unroll region of const tensor indices
        break;
      }

      if (GetExtent(target_target_iter) == 1) {  // skip iterators with length of 1
        continue;
      }

      candidates.emplace_back(target_target_stage_id, i);
    }
  }

  return candidates;
}
148
/*!
 * \brief Apply a multi-level tiling structure to a stage according to a format
 *        string (e.g. "SSRSRS": 'S'/'s' = a spatial level, 'R'/'r' = a
 *        reduction level).  Each splittable spatial axis is split into
 *        (#'s' in format) parts and each reduction axis into (#'r' in format)
 *        parts, then all resulting iterators are reordered level by level
 *        following the format.
 * \param state The input state.
 * \param stage_id The id of the stage to be tiled.
 * \param format The tiling structure format string.
 * \param spatial_split_step_ids Output: indices of the generated spatial
 *        SplitSteps in the transform history.  May be nullptr.
 * \return The new state with the stage tiled.
 */
State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
                         std::vector<int>* spatial_split_step_ids) {
  // Temporal object to be used if the input pointer is nullptr
  std::vector<int> temp_split_step_ids;
  if (spatial_split_step_ids == nullptr) {
    spatial_split_step_ids = &temp_split_step_ids;
  }
  // space_levels[i] / reduce_levels[i] collect the iterators belonging to the
  // i-th 's' / 'r' level of the tiling structure.
  std::vector<std::vector<Iterator>> space_levels;
  std::vector<std::vector<Iterator>> reduce_levels;
  // Iterators marked "no_split_at_inner" are kept whole and appended to the
  // innermost level of their kind.
  // NOTE(review): the original code also declared `space_outer`/`reduce_outer`
  // vectors that were never filled; the merge code guarded by their !empty()
  // checks was dead and has been removed.
  std::vector<Iterator> space_inner, reduce_inner;
  Array<Iterator> split_res;

  for (const auto c : format) {
    if (tolower(c) == 's') {
      space_levels.emplace_back();
    } else if (tolower(c) == 'r') {
      reduce_levels.emplace_back();
    } else {
      LOG(FATAL) << "Invalid multi-level tiling format: " << format;
    }
  }
  size_t n_space = space_levels.size();
  size_t n_reduce = reduce_levels.size();

  spatial_split_step_ids->clear();

  State tmp_s = state;
  const Stage& stage = state->stages[stage_id];
  const std::set<std::string>& no_split_at_inner_name_set =
      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
          : std::set<std::string>();

  for (const auto& iter : state->stages[stage_id]->iters) {
    if (!no_split_at_inner_name_set.count(iter->name)) {
      if (iter->iter_kind == IteratorKind::kSpatial) {
        CHECK_GE(n_space, 1);

        if (n_space == 1) {
          // No split needed; the axis itself forms the single level.
          space_levels[0].push_back(iter);
        } else {
          // Split into n_space parts with factors left to be filled later.
          split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_space - 1, NullOpt));
          for (size_t i = 0; i < n_space; i++) {
            space_levels[i].push_back(split_res[i]);
          }
          spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
        }
      } else if (iter->iter_kind == IteratorKind::kReduction) {
        CHECK_GE(n_reduce, 1);

        if (n_reduce == 1) {
          reduce_levels[0].push_back(iter);
        } else {
          split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_reduce - 1, NullOpt));
          for (size_t i = 0; i < n_reduce; i++) {
            reduce_levels[i].push_back(split_res[i]);
          }
        }
      } else {
        LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
      }
    } else {
      // Unsplittable iterators go to the innermost level unchanged.
      if (iter->iter_kind == IteratorKind::kSpatial) {
        space_inner.push_back(iter);
      } else if (iter->iter_kind == IteratorKind::kReduction) {
        reduce_inner.push_back(iter);
      } else {
        LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
      }
    }
  }

  if (!space_inner.empty()) {
    CHECK(!space_levels.empty());
    space_levels.back().insert(space_levels.back().begin(),
                               std::make_move_iterator(space_inner.begin()),
                               std::make_move_iterator(space_inner.end()));
  }
  if (!reduce_inner.empty()) {
    CHECK(!reduce_levels.empty());
    reduce_levels.back().insert(reduce_levels.back().begin(),
                                std::make_move_iterator(reduce_inner.begin()),
                                std::make_move_iterator(reduce_inner.end()));
  }

  // Flatten the levels into the final loop order following the format string.
  Array<Iterator> order;
  int space_ct = 0, reduce_ct = 0;
  for (const auto c : format) {
    if (tolower(c) == 's') {
      order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()),
                   std::make_move_iterator(space_levels[space_ct].end()));
      space_ct++;
    } else if (tolower(c) == 'r') {
      order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()),
                   std::make_move_iterator(reduce_levels[reduce_ct].end()));
      reduce_ct++;
    } else {
      LOG(FATAL) << "Invalid multi level tiling format: " << format;
    }
  }

  tmp_s.reorder(stage_id, order);
  return tmp_s;
}
266
/*!
 * \brief Tile a stage by following the split history of another stage.
 * \param state The input state.
 * \param stage_id The id of the stage to be tiled.
 * \param split_step_ids The indices of the SplitSteps to follow, one per
 *        splittable spatial iterator of the stage.
 * \param n_split The number of split levels; only 1, 2 and 3 are supported.
 * \return The new state with the stage tiled and reordered.
 */
State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids,
                   int n_split) {
  if (n_split < 1 || n_split > 3) {
    LOG(FATAL) << "Invalid split parts, currently only support 1, 2 and 3";
  }
  // Apply up to three-level tiling structure:  space_L0, space_L1, space_L2
  std::vector<Iterator> space_0, space_1, space_2, space_3, tmp_order;
  Array<Iterator> split_res;

  auto pop = state->stages[stage_id]->op.as<te::ComputeOpNode>();
  CHECK(pop != nullptr);
  const Stage& stage = state->stages[stage_id];
  const std::set<std::string>& no_split_at_inner_name_set =
      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
          : std::set<std::string>();
  // Every splittable iterator must have a matching split step to follow.
  int no_split_at_inner_name_in_stage_cnt = 0;
  for (const auto& iter : state->stages[stage_id]->iters) {
    no_split_at_inner_name_in_stage_cnt += no_split_at_inner_name_set.count(iter->name);
  }

  CHECK_EQ(state->stages[stage_id]->iters.size() - no_split_at_inner_name_in_stage_cnt,
           split_step_ids.size());

  State tmp_s = state;
  int ct = 0;
  for (const auto& iter : state->stages[stage_id]->iters) {
    if (iter->iter_kind == IteratorKind::kSpatial) {
      if (!no_split_at_inner_name_set.count(iter->name)) {
        // For spatial iterator, split it into multi iterators
        IteratorAnnotation ann_type = iter->annotation;
        split_res = tmp_s.follow_split(stage_id, iter, split_step_ids[ct], n_split);
        // Restore annotation. Move unroll and vectorize to inner, move parallel
        // to outer
        switch (ann_type) {
          case IteratorAnnotation::kUnroll:
            split_res.Set(n_split, tmp_s.unroll(stage_id, split_res[n_split]));
            break;
          case IteratorAnnotation::kVectorize:
            split_res.Set(n_split, tmp_s.vectorize(stage_id, split_res[n_split]));
            break;
          case IteratorAnnotation::kParallel:
            split_res.Set(0, tmp_s.parallel(stage_id, split_res[0]));
            break;
          default:
            break;
        }

        space_0.push_back(split_res[0]);
        space_1.push_back(split_res[1]);
        if (n_split >= 2) {
          space_2.push_back(split_res[2]);
          if (n_split == 3) {
            space_3.push_back(split_res[3]);
          }
        }
        ct++;
      } else {
        // A "no_split_at_inner" iterator is kept whole and placed at the
        // innermost level.  (The original code re-tested set membership here,
        // which is always true inside this else branch; the redundant check
        // is removed.)
        if (n_split == 1) {
          space_1.push_back(iter);
        } else if (n_split == 2) {
          space_2.push_back(iter);
        } else {
          CHECK_EQ(n_split, 3);
          space_3.push_back(iter);
        }
      }
    } else {
      LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
    }
  }

  // Concatenate the levels outer-to-inner into the final loop order.
  if (n_split == 3) {
    ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2, &space_3);
  } else if (n_split == 2) {
    ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2);
  } else {
    ConcatenateMove(&tmp_order, &space_0, &space_1);
  }
  tmp_s.reorder(stage_id, tmp_order);
  return tmp_s;
}
351
352 // Return whether a state has nested parallel, which is invalid on CPUs
HasNestedParallel(const State & state)353 bool HasNestedParallel(const State& state) {
354 std::function<void(int stage_id, size_t*)> count_parallel_ct;
355
356 count_parallel_ct = [&state, &count_parallel_ct](int stage_id, size_t* parallel_ct) {
357 const Stage& stage = state->stages[stage_id];
358
359 if (stage->compute_at == ComputeAtKind::kInlined) {
360 return;
361 }
362
363 for (size_t i = 0; i < stage->iters.size(); ++i) {
364 if (stage->iters[i]->annotation == IteratorAnnotation::kParallel) {
365 (*parallel_ct)++;
366 }
367
368 IterKey iter_key(stage_id, i);
369 auto pair = state->attach_map->iter_to_attached_stages.find(iter_key);
370 if (pair != state->attach_map->iter_to_attached_stages.end()) {
371 for (const auto& attach_stage_id : pair->second) {
372 count_parallel_ct(attach_stage_id, parallel_ct);
373 }
374 }
375 }
376 };
377
378 for (size_t stage_id = 0; stage_id < state->stages.size(); ++stage_id) {
379 size_t parallel_ct = 0;
380
381 if (state->stages[stage_id]->compute_at == ComputeAtKind::kRoot) {
382 count_parallel_ct(stage_id, ¶llel_ct);
383 if (parallel_ct >= 2) {
384 return true;
385 }
386 }
387 }
388
389 return false;
390 }
391
PruneInvalidState(const SearchTask & task,Array<State> * states)392 void PruneInvalidState(const SearchTask& task, Array<State>* states) {
393 size_t pt = 0;
394 for (size_t i = 0; i < states->size(); ++i) {
395 if (!(*states)[i].defined()) {
396 continue;
397 }
398 if (!IsGPUTask(task) && HasNestedParallel((*states)[i])) {
399 continue;
400 }
401
402 if (i != pt) {
403 states->Set(pt, (*states)[i]);
404 }
405 pt++;
406 }
407
408 if (pt == 0) {
409 LOG(FATAL) << "Internal error: All states are invalid.";
410 } else {
411 states->resize(pt);
412 }
413 }
414
415 /********** SplitFactorizationMemo **********/
416
GetRead()417 void SplitFactorizationMemo::ReadWriteLock::GetRead() {
418 std::unique_lock<std::mutex> lock(cv_mutex_);
419 // Wake up and get the mutex lock if there's no writing thread
420 cv_.wait(lock, [this]() { return !this->is_writing_; });
421 read_count_++;
422 }
423
GetWrite()424 void SplitFactorizationMemo::ReadWriteLock::GetWrite() {
425 std::unique_lock<std::mutex> lock(cv_mutex_);
426 // Wake up and get the mutex lock if there's no reading or writing threads
427 cv_.wait(lock, [this]() { return this->read_count_ == 0 && !this->is_writing_; });
428 is_writing_ = true;
429 }
430
UnlockRead()431 void SplitFactorizationMemo::ReadWriteLock::UnlockRead() {
432 std::lock_guard<std::mutex> lock(cv_mutex_);
433 read_count_--;
434 // Notify the other blocked threads if this is the last reading thread
435 if (read_count_ == 0) {
436 cv_.notify_one();
437 }
438 }
439
UnlockWrite()440 void SplitFactorizationMemo::ReadWriteLock::UnlockWrite() {
441 std::lock_guard<std::mutex> lock(cv_mutex_);
442 is_writing_ = false;
443 // Notify the other blocked threads
444 cv_.notify_one();
445 }
446
/*!
 * \brief Return (and memoize) all schemes that factorize `extent` into
 *        `n_lengths` factors whose innermost factor does not exceed
 *        `max_innermost_factor`.
 * \note There is a check-then-act window between UnlockRead and GetWrite, so
 *       two threads can both miss the cache and recompute the same key.
 *       NOTE(review): returning `it->second` / `*results_` after the lock is
 *       released relies on references into `memory_` staying valid across
 *       later insertions — confirm the container type declared in utils.h
 *       guarantees reference stability (std::unordered_map does).
 */
const Array<Array<Integer>>& SplitFactorizationMemo::GetFactorizationSchemes(
    int extent, int n_lengths, int max_innermost_factor) {
  QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor);
  const auto& const_memory = memory_;
  // Fast path: probe the cache under the read lock.
  lock_.GetRead();
  const auto& it = const_memory.find(key);
  const auto& memory_end = const_memory.end();
  lock_.UnlockRead();
  if (it != memory_end) {
    return it->second;
  }

  // Slow path: enumerate all factorization schemes under the write lock and
  // store them into the cache slot for this key.
  lock_.GetWrite();
  tmp_stack_ = Array<Integer>(n_lengths, Integer());
  results_ = &memory_[key];
  n_lengths_ = n_lengths;
  DfsEnumerate(0, extent, max_innermost_factor);
  lock_.UnlockWrite();

  return *results_;
}
468
DfsEnumerate(int now,int remaining_length,int max_innermost_factor)469 void SplitFactorizationMemo::DfsEnumerate(int now, int remaining_length, int max_innermost_factor) {
470 if (now == n_lengths_) {
471 if (tmp_stack_.back().as<IntImmNode>()->value <= max_innermost_factor) {
472 results_->push_back(tmp_stack_);
473 }
474 } else {
475 for (const auto& f : GetFactors(remaining_length)) {
476 tmp_stack_.Set(now, Integer(f));
477 DfsEnumerate(now + 1, remaining_length / f, max_innermost_factor);
478 }
479 }
480 }
481
GetFactors(int n)482 const std::vector<int>& SplitFactorizationMemo::GetFactors(int n) {
483 auto it = factor_memory_.find(n);
484 if (it != factor_memory_.end()) {
485 return it->second;
486 }
487
488 std::vector<int>& res = factor_memory_[n];
489 int step = n % 2 == 0 ? 1 : 2;
490 for (size_t i = 1; i < static_cast<size_t>(std::sqrt(n)) + 1; i += step) {
491 if (n % i == 0) {
492 res.push_back(i);
493 if (n / i != i) {
494 res.push_back(n / i);
495 }
496 }
497 }
498 std::sort(res.begin(), res.end());
499 return res;
500 }
501
502 /********** Utils interface API for ffi **********/
503
// Expose the helper predicates above through the FFI under the
// "auto_scheduler." global-function prefix.
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsTiled")
    .set_body_typed([](const Stage& stage) { return IsTiled(stage); });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheReadStage")
    .set_body_typed([](const State& s, int stage_id) { return HasCacheReadStage(s, stage_id); });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheWriteStage")
    .set_body_typed([](const State& s, int stage_id) { return HasCacheWriteStage(s, stage_id); });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasRfactorStage")
    .set_body_typed([](const State& s, int stage_id) { return HasRfactorStage(s, stage_id); });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCrossThreadReduction")
    .set_body_typed([](const State& s, int stage_id) {
      return HasCrossThreadReduction(s, stage_id);
    });
520
521 } // namespace auto_scheduler
522 } // namespace tvm
523