/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file auto_scheduler/search_policy/utils.cc
 * \brief Common utilities
 */

#include "utils.h"

#include <algorithm>

namespace tvm {
namespace auto_scheduler {
Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id) {
  const auto& stage = s->stages[stage_id];
  const auto& pop = s->stages[stage_id]->op.as<te::ComputeOpNode>();
  CHECK(pop != nullptr);
  const std::set<std::string>& no_split_at_inner_name_set =
      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
          : std::set<std::string>();
  size_t reduce_count = 0;
  for (const auto axis : pop->reduce_axis) {
    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
      reduce_count++;
    }
  }

  Array<Integer> spatial_split_step_ids;
  for (int i = s->transform_steps.size() - 1; i >= 0; --i) {
    if (IsStageNumberChangingStep(s->transform_steps[i])) {
      if (stage_id > s->transform_steps[i]->stage_id) {
        stage_id--;
      }
    } else if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
      if (stage_id == ps->stage_id) {
        // Assume SplitSteps on reduction axes always come after SplitSteps on spatial axes.
        if (reduce_count) {
          reduce_count--;
        } else {
          spatial_split_step_ids.push_back(i);
        }
      }
    }
  }

  return spatial_split_step_ids;
}

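// Enumerate (stage_id, iterator index) pairs that are valid compute_at locations for
// the given stage: iterators of its single consumer, and, if that consumer is itself
// computed at another stage, iterators of that stage as well.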
std::vector<std::pair<int, int>> GetComputeLocationCandidates(const SearchTask& task,
                                                              const State& state, int stage_id) {
  int target_stage_id = GetSingleConsumerId(task, state, stage_id);
  if (target_stage_id < 0) {
    return {};
  }
  const Stage& target_stage = state->stages[target_stage_id];

  std::vector<std::pair<int, int>> candidates;
  bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter;
  bool target_is_tiled = IsTiled(target_stage);

  bool visited_reduce = false;
  // Enumerate compute_at locations at target_stage
  // TODO(merrymercy): More analysis here to make smarter choices
  for (size_t i = 0; i < target_stage->iters.size(); ++i) {
    const Iterator& target_iter = target_stage->iters[i];
    if (target_iter->iter_kind == IteratorKind::kReduction) {
      visited_reduce = true;
      if (!target_is_tiled) {  // Do not go into reduce iter
        break;
      }
    } else if (target_iter->iter_kind == IteratorKind::kSpatial) {
      if (visited_reduce) {  // Do not go into inner tile
        break;
      }
    }

    if (target_iter->annotation == IteratorAnnotation::kUnroll) {
      // Do not go into the unroll region of const tensor indices
      break;
    }

    if (GetExtent(target_iter) == 1) {
      // Skip iterators with length of 1
      continue;
    }
    if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial &&
        StrEndsWith(target_iter->name, ".0")) {
      // Skip the first-level iterators if the target stage is computed at another stage.
      // In this case, the lengths of the first-level iterators are always one.
      continue;
    }
    candidates.emplace_back(target_stage_id, i);

    if (state->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) {
      break;
    }
  }

  // If target_stage is already computed at another stage X, also try compute_at X.
  // We call stage X `target_target_stage`.
  if (target_compute_at_other) {
    int target_target_stage_id = state->attach_map->stage_to_attach_iter.at(target_stage_id).first;
    const Stage& target_target_stage = state->stages[target_target_stage_id];

    for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
      const Iterator& target_target_iter = target_target_stage->iters[i];
      if (target_target_iter->iter_kind == IteratorKind::kReduction ||
          state->attach_map->iter_to_attached_stages.count(
              std::make_pair(target_target_stage_id, i))) {
        break;
      }

      if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
        // Do not go into the unroll region of const tensor indices
        break;
      }

      if (GetExtent(target_target_iter) == 1) {  // Skip iterators with length of 1
        continue;
      }

      candidates.emplace_back(target_target_stage_id, i);
    }
  }

  return candidates;
}

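// Apply multi-level tiling to a stage according to a format string (e.g. "SSRSRS"),
// where each 'S' adds a level of spatial tiling and each 'R' a level of reduction
// tiling. Returns the new state; the ids of the generated spatial SplitSteps are
// written to `spatial_split_step_ids` when it is non-null.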
State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
                         std::vector<int>* spatial_split_step_ids) {
  // Temporary object to be used if the input pointer is nullptr
  std::vector<int> temp_split_step_ids;
  if (spatial_split_step_ids == nullptr) {
    spatial_split_step_ids = &temp_split_step_ids;
  }
  std::vector<std::vector<Iterator>> space_levels;
  std::vector<std::vector<Iterator>> reduce_levels;
  std::vector<Iterator> space_outer, space_inner, reduce_outer, reduce_inner;
  Array<Iterator> split_res;

  for (const auto c : format) {
    if (tolower(c) == 's') {
      space_levels.emplace_back();
    } else if (tolower(c) == 'r') {
      reduce_levels.emplace_back();
    } else {
      LOG(FATAL) << "Invalid multi-level tiling format: " << format;
    }
  }
  size_t n_space = space_levels.size();
  size_t n_reduce = reduce_levels.size();

  spatial_split_step_ids->clear();

  State tmp_s = state;
  const Stage& stage = state->stages[stage_id];
  const std::set<std::string>& no_split_at_inner_name_set =
      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
          : std::set<std::string>();

  for (const auto& iter : state->stages[stage_id]->iters) {
    if (!no_split_at_inner_name_set.count(iter->name)) {
      if (iter->iter_kind == IteratorKind::kSpatial) {
        CHECK_GE(n_space, 1);

        if (n_space == 1) {
          space_levels[0].push_back(iter);
        } else {
          split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_space - 1, NullOpt));
          for (size_t i = 0; i < n_space; i++) {
            space_levels[i].push_back(split_res[i]);
          }
          spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
        }
      } else if (iter->iter_kind == IteratorKind::kReduction) {
        CHECK_GE(n_reduce, 1);

        if (n_reduce == 1) {
          reduce_levels[0].push_back(iter);
        } else {
          split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_reduce - 1, NullOpt));
          for (size_t i = 0; i < n_reduce; i++) {
            reduce_levels[i].push_back(split_res[i]);
          }
        }
      } else {
        LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
      }
    } else {
      if (iter->iter_kind == IteratorKind::kSpatial) {
        space_inner.push_back(iter);
      } else if (iter->iter_kind == IteratorKind::kReduction) {
        reduce_inner.push_back(iter);
      } else {
        LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
      }
    }
  }

  if (!space_outer.empty()) {
    CHECK(!space_levels.empty());
    space_levels.front().insert(space_levels.front().begin(),
                                std::make_move_iterator(space_outer.begin()),
                                std::make_move_iterator(space_outer.end()));
  }
  if (!space_inner.empty()) {
    CHECK(!space_levels.empty());
    space_levels.back().insert(space_levels.back().begin(),
                               std::make_move_iterator(space_inner.begin()),
                               std::make_move_iterator(space_inner.end()));
  }

  if (!reduce_outer.empty()) {
    CHECK(!reduce_levels.empty());
    reduce_levels.front().insert(reduce_levels.front().begin(),
                                 std::make_move_iterator(reduce_outer.begin()),
                                 std::make_move_iterator(reduce_outer.end()));
  }
  if (!reduce_inner.empty()) {
    CHECK(!reduce_levels.empty());
    reduce_levels.back().insert(reduce_levels.back().begin(),
                                std::make_move_iterator(reduce_inner.begin()),
                                std::make_move_iterator(reduce_inner.end()));
  }

  Array<Iterator> order;
  int space_ct = 0, reduce_ct = 0;
  for (const auto c : format) {
    if (tolower(c) == 's') {
      order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()),
                   std::make_move_iterator(space_levels[space_ct].end()));
      space_ct++;
    } else if (tolower(c) == 'r') {
      order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()),
                   std::make_move_iterator(reduce_levels[reduce_ct].end()));
      reduce_ct++;
    } else {
      LOG(FATAL) << "Invalid multi-level tiling format: " << format;
    }
  }

  tmp_s.reorder(stage_id, order);
  return tmp_s;
}

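// Tile the spatial iterators of a stage by following the factors of the SplitSteps
// referenced in `split_step_ids` (see `follow_split`), then restore the iterator
// annotations and reorder the resulting loop levels.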
State FollowTiling(const State& state, int stage_id, const std::vector<int>& split_step_ids,
                   int n_split) {
  if (n_split < 1 || n_split > 3) {
    LOG(FATAL) << "Invalid number of split parts, currently only 1, 2 and 3 are supported";
  }
  // Apply an up to three-level split, producing up to four loop levels:
  // space_0, space_1, space_2, space_3
  std::vector<Iterator> space_0, space_1, space_2, space_3, tmp_order;
  Array<Iterator> split_res;

  auto pop = state->stages[stage_id]->op.as<te::ComputeOpNode>();
  CHECK(pop != nullptr);
  const Stage& stage = state->stages[stage_id];
  const std::set<std::string>& no_split_at_inner_name_set =
      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
          : std::set<std::string>();
  int no_split_at_inner_name_in_stage_cnt = 0;
  for (const auto& iter : state->stages[stage_id]->iters) {
    no_split_at_inner_name_in_stage_cnt += no_split_at_inner_name_set.count(iter->name);
  }

  CHECK_EQ(state->stages[stage_id]->iters.size() - no_split_at_inner_name_in_stage_cnt,
           split_step_ids.size());

  State tmp_s = state;
  int ct = 0;
  for (const auto& iter : state->stages[stage_id]->iters) {
    if (iter->iter_kind == IteratorKind::kSpatial) {
      if (!no_split_at_inner_name_set.count(iter->name)) {
        // Split this spatial iterator into multiple iterators, following the factors
        // recorded in the referenced SplitStep.
        IteratorAnnotation ann_type = iter->annotation;
        split_res = tmp_s.follow_split(stage_id, iter, split_step_ids[ct], n_split);
        // Restore the annotation. Move unroll and vectorize to the innermost level;
        // move parallel to the outermost level.
        switch (ann_type) {
          case IteratorAnnotation::kUnroll:
            split_res.Set(n_split, tmp_s.unroll(stage_id, split_res[n_split]));
            break;
          case IteratorAnnotation::kVectorize:
            split_res.Set(n_split, tmp_s.vectorize(stage_id, split_res[n_split]));
            break;
          case IteratorAnnotation::kParallel:
            split_res.Set(0, tmp_s.parallel(stage_id, split_res[0]));
            break;
          default:
            break;
        }

        space_0.push_back(split_res[0]);
        space_1.push_back(split_res[1]);
        if (n_split >= 2) {
          space_2.push_back(split_res[2]);
          if (n_split == 3) {
            space_3.push_back(split_res[3]);
          }
        }
        ct++;
      } else {
        // Keep no-split iterators at the innermost level
        if (n_split == 1) {
          space_1.push_back(iter);
        } else if (n_split == 2) {
          space_2.push_back(iter);
        } else {
          CHECK_EQ(n_split, 3);
          space_3.push_back(iter);
        }
      }
    } else {
      LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
    }
  }

  if (n_split == 3) {
    ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2, &space_3);
  } else if (n_split == 2) {
    ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2);
  } else {
    ConcatenateMove(&tmp_order, &space_0, &space_1);
  }
  tmp_s.reorder(stage_id, tmp_order);
  return tmp_s;
}

// Return whether a state has nested parallelism, which is invalid on CPUs
bool HasNestedParallel(const State& state) {
  std::function<void(int stage_id, size_t*)> count_parallel_ct;

  count_parallel_ct = [&state, &count_parallel_ct](int stage_id, size_t* parallel_ct) {
    const Stage& stage = state->stages[stage_id];

    if (stage->compute_at == ComputeAtKind::kInlined) {
      return;
    }

    for (size_t i = 0; i < stage->iters.size(); ++i) {
      if (stage->iters[i]->annotation == IteratorAnnotation::kParallel) {
        (*parallel_ct)++;
      }

      IterKey iter_key(stage_id, i);
      auto pair = state->attach_map->iter_to_attached_stages.find(iter_key);
      if (pair != state->attach_map->iter_to_attached_stages.end()) {
        for (const auto& attach_stage_id : pair->second) {
          count_parallel_ct(attach_stage_id, parallel_ct);
        }
      }
    }
  };

  for (size_t stage_id = 0; stage_id < state->stages.size(); ++stage_id) {
    size_t parallel_ct = 0;

    if (state->stages[stage_id]->compute_at == ComputeAtKind::kRoot) {
      count_parallel_ct(stage_id, &parallel_ct);
      if (parallel_ct >= 2) {
        return true;
      }
    }
  }

  return false;
}

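// Remove undefined states, and states with nested parallelism on non-GPU tasks,
// compacting `states` in place while preserving the original order.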
void PruneInvalidState(const SearchTask& task, Array<State>* states) {
  size_t pt = 0;
  for (size_t i = 0; i < states->size(); ++i) {
    if (!(*states)[i].defined()) {
      continue;
    }
    if (!IsGPUTask(task) && HasNestedParallel((*states)[i])) {
      continue;
    }

    if (i != pt) {
      states->Set(pt, (*states)[i]);
    }
    pt++;
  }

  if (pt == 0) {
    LOG(FATAL) << "Internal error: All states are invalid.";
  } else {
    states->resize(pt);
  }
}

/********** SplitFactorizationMemo **********/

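// A simple reader-writer lock built on a mutex and a condition variable: multiple
// readers may hold the lock simultaneously, while a writer requires exclusive access.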
void SplitFactorizationMemo::ReadWriteLock::GetRead() {
  std::unique_lock<std::mutex> lock(cv_mutex_);
  // Block until no thread is writing, then register as a reader
  cv_.wait(lock, [this]() { return !this->is_writing_; });
  read_count_++;
}

void SplitFactorizationMemo::ReadWriteLock::GetWrite() {
  std::unique_lock<std::mutex> lock(cv_mutex_);
  // Block until there is no reading or writing thread
  cv_.wait(lock, [this]() { return this->read_count_ == 0 && !this->is_writing_; });
  is_writing_ = true;
}

void SplitFactorizationMemo::ReadWriteLock::UnlockRead() {
  std::lock_guard<std::mutex> lock(cv_mutex_);
  read_count_--;
  // If this was the last reader, wake up one blocked writer
  if (read_count_ == 0) {
    cv_.notify_one();
  }
}

void SplitFactorizationMemo::ReadWriteLock::UnlockWrite() {
  std::lock_guard<std::mutex> lock(cv_mutex_);
  is_writing_ = false;
  // Wake up all blocked threads. `notify_all` is needed here because multiple
  // readers may proceed concurrently; `notify_one` would wake only one of them.
  cv_.notify_all();
}

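// Return the memoized set of ways to factorize `extent` into `n_lengths` parts (each
// tuple's product divides `extent`), with the innermost factor bounded by
// `max_innermost_factor`. Guarded by the reader-writer lock above, so this is safe to
// call from multiple threads.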
const Array<Array<Integer>>& SplitFactorizationMemo::GetFactorizationSchemes(
    int extent, int n_lengths, int max_innermost_factor) {
  QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor);
  const auto& const_memory = memory_;
  lock_.GetRead();
  const auto& it = const_memory.find(key);
  const auto& memory_end = const_memory.end();
  lock_.UnlockRead();
  if (it != memory_end) {
    return it->second;
  }

  lock_.GetWrite();
  // Re-check under the write lock: another thread may have computed this entry between
  // UnlockRead() and GetWrite(). Also return a reference to the map entry itself, since
  // the `results_` member may be repointed by a concurrent writer. An empty cached
  // result is simply recomputed, which is harmless.
  Array<Array<Integer>>& entry = memory_[key];
  if (entry.empty()) {
    tmp_stack_ = Array<Integer>(n_lengths, Integer());
    results_ = &entry;
    n_lengths_ = n_lengths;
    DfsEnumerate(0, extent, max_innermost_factor);
  }
  lock_.UnlockWrite();

  return entry;
}

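// Depth-first enumeration of factor tuples: position `now` takes each factor of
// `remaining_length` in turn; a complete tuple is kept only if its innermost factor
// does not exceed `max_innermost_factor`.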
void SplitFactorizationMemo::DfsEnumerate(int now, int remaining_length, int max_innermost_factor) {
  if (now == n_lengths_) {
    if (tmp_stack_.back().as<IntImmNode>()->value <= max_innermost_factor) {
      results_->push_back(tmp_stack_);
    }
  } else {
    for (const auto& f : GetFactors(remaining_length)) {
      tmp_stack_.Set(now, Integer(f));
      DfsEnumerate(now + 1, remaining_length / f, max_innermost_factor);
    }
  }
}

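// Return the sorted factors of n, memoized in `factor_memory_`.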
const std::vector<int>& SplitFactorizationMemo::GetFactors(int n) {
  auto it = factor_memory_.find(n);
  if (it != factor_memory_.end()) {
    return it->second;
  }

  std::vector<int>& res = factor_memory_[n];
  // If n is odd, even numbers cannot be factors, so step over them
  int step = n % 2 == 0 ? 1 : 2;
  for (size_t i = 1; i < static_cast<size_t>(std::sqrt(n)) + 1; i += step) {
    if (n % i == 0) {
      res.push_back(i);
      if (n / i != i) {
        res.push_back(n / i);
      }
    }
  }
  std::sort(res.begin(), res.end());
  return res;
}

/********** Utils interface API for ffi **********/

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsTiled")
    .set_body_typed([](const Stage& stage) { return IsTiled(stage); });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheReadStage")
    .set_body_typed([](const State& s, int stage_id) { return HasCacheReadStage(s, stage_id); });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheWriteStage")
    .set_body_typed([](const State& s, int stage_id) { return HasCacheWriteStage(s, stage_id); });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasRfactorStage")
    .set_body_typed([](const State& s, int stage_id) { return HasRfactorStage(s, stage_id); });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCrossThreadReduction")
    .set_body_typed([](const State& s, int stage_id) {
      return HasCrossThreadReduction(s, stage_id);
    });

}  // namespace auto_scheduler
}  // namespace tvm