1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file hoist_if_then_else.cc
22 */
23 #include <tvm/ir.h>
24 #include <tvm/ir_visitor.h>
25 #include <tvm/ir_mutator.h>
26 #include <tvm/ir_pass.h>
27 #include <tvm/arithmetic.h>
28 #include <tvm/api_registry.h>
29 #include <unordered_map>
30 #include <unordered_set>
31 #include <queue>
32 #include "../arithmetic/int_set.h"
33 #include "../runtime/thread_storage_scope.h"
34
35 namespace tvm {
36 namespace ir {
37
38 using HoistMap = std::unordered_map<const Node*, std::vector<Stmt>>;
39 using VarMap = std::unordered_map<const Node*, std::unordered_set<const Node*>>;
40
41 /*
42 * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
43 * For example, given the following block:
44 * for (i = 0; i < 3; i++)
45 * for (j = 0; j < 4; j++)
46 * for (k = 0; k < 5; k++)
47 * if (likely(i*2 < 4))
48 * A[3*i+2j+k] = B[7*i+3j+k]
49 *
 * We first detect all IfThenElse stmts and, for each of them, find the highest
 * enclosing For stmt up to which its condition stays loop invariant.
 * Then we hoist each IfThenElse stmt across one For stmt per step:
52 *
53 * Step 1:
54 * for (i = 0; i < 3; i++)
55 * for (j = 0; j < 4; j++)
56 * if (likely(i*2 < 4))
57 * for (k = 0; k < 5; k++)
58 * A[3*i+2j+k] = B[7*i+3j+k]
59 *
60 * Step 2:
61 * for (i = 0; i < 3; i++)
62 * if (likely(i*2 < 4))
63 * for (j = 0; j < 4; j++)
64 * for (k = 0; k < 5; k++)
65 * A[3*i+2j+k] = B[7*i+3j+k]
66 *
67 * In this pass, we only continue detecting possible hoisting chance when visiting For,
68 * IfThenElse or AttrStmt Node. For example, for the following block:
69 * for (i = 0; i < 3; i++)
70 * for (j = 0; j < 4; j++)
71 * A[i + j] = A[i + j] - 1
72 * for (k = 0; k < 5; k++)
73 * if (likely(i*2 < 4))
74 * A[3*i+2j+k] = B[7*i+3j+k]
75 *
76 * Only the For with k variable will be considered and the resulting stmt would be:
77 * for (i = 0; i < 3; i++)
78 * for (j = 0; j < 4; j++)
79 * A[i + j] = A[i + j] - 1
80 * if (likely(i*2 < 4))
81 * for (k = 0; k < 5; k++)
82 * A[3*i+2j+k] = B[7*i+3j+k]
83 *
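 * This pass does not hoist consecutive IfThenElse stmts: when collecting
 * candidates only For, IfThenElse and AttrStmt nodes are traversed, so sibling
 * IfThenElse stmts inside a sequence are never considered. The following
 * block won't be optimized: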
86 * for (i = 0; i < 3; i++)
87 * for (j = 0; j < 4; j++)
88 * for (k = 0; k < 5; k++)
89 * if (likely(i*2 < 4))
90 * A[3*i+2j+k] = B[7*i+3j+k]
91 * if (likely(j > 2))
92 * A[i+j+k] = B[i+j+k]
93 *
94 */
95 class IfThenElseHoist {
96 public:
VisitAndMutate(const Stmt & stmt)97 Stmt VisitAndMutate(const Stmt& stmt) {
98 SelectCandidates(stmt);
99 LocateTopFor();
100 return PostOrderMutate(stmt);
101 }
102
103 private:
104 void SelectCandidates(const Stmt& stmt);
105 void LocateTopFor();
106 Stmt PostOrderMutate(const Stmt& stmt);
107 size_t GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt);
108 Stmt HoistIf(const Stmt& if_stmt);
109
110 // Map of all For nodes to all child IfThenElse nodes.
111 HoistMap for2if_map_;
112 // Map of all IfThenElse nodes to all For nodes which are loop invariant.
113 HoistMap if2for_map_;
114 // Map of highest loop invariant For to child IfThenElse.
115 HoistMap top_for_var_map_;
116 // Map of original For to list of update For nodes.
117 HoistMap for_tracking_map_;
118 // Map of all IfThenElse nodes to condition variable nodes.
119 VarMap cond_var_map_;
120 // List of For nodes added in post order DFS visiting.
121 std::vector<Stmt> ordered_for_list_;
122 };
123
124 // Check whether a given IfThenElse stmt is the first one appearing
125 // in a For stmt.
is_first_if(const Stmt & for_stmt,const Stmt & if_stmt)126 bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) {
127 std::vector<const Node*> if_node_list;
128 const For* for_node = for_stmt.as<For>();
129 CHECK(for_node);
130 CHECK(if_stmt.as<IfThenElse>());
131
132 PostOrderVisit(for_node->body, [&](const NodeRef& node) {
133 if (node.as<IfThenElse>()) {
134 if_node_list.push_back(node.get());
135 }
136 });
137 return if_node_list.empty() ? false : if_stmt.get() == if_node_list.back();
138 }
139
140 // Update upper level For node when current For node is modified.
141 // With this function we only need to visit and mutate top level For node
142 // in the main VisitAndMutate function.
update_for(const Stmt & parent_for_stmt,const Stmt & new_if_stmt)143 Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
144 const Node* top_for_node;
145 const For* parent_for_node = parent_for_stmt.as<For>();
146 CHECK(parent_for_node);
147 CHECK(new_if_stmt.as<IfThenElse>());
148
149 PostOrderVisit(parent_for_node->body, [&](const NodeRef& node) {
150 if (node.as<For>()) {
151 top_for_node = node.get();
152 }
153 });
154
155 PackedFunc replace_target_for = PackedFunc(
156 [&](TVMArgs args, TVMRetValue *ret){
157 const NodeRef& current_for = args[0];
158 if (current_for.get() == top_for_node) {
159 *ret = new_if_stmt;
160 }
161 });
162
163 return IRTransform(parent_for_stmt, nullptr, replace_target_for,
164 {Expr("For")});
165 }
166
167 // Remove IfThenElse node from a For node.
168 // A pair of For nodes will be generated.
RemoveIf(const Stmt & for_stmt,const Stmt & if_stmt)169 std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
170 Stmt then_for;
171 Stmt else_for;
172 CHECK(if_stmt.as<IfThenElse>());
173
174 PackedFunc replace_then_case = PackedFunc(
175 [&](TVMArgs args, TVMRetValue *ret){
176 const NodeRef& node = args[0];
177 if (node == if_stmt) {
178 *ret = node.as<IfThenElse>()->then_case;
179 }
180 });
181
182 PackedFunc replace_else_case = PackedFunc(
183 [&](TVMArgs args, TVMRetValue *ret){
184 const NodeRef& node = args[0];
185 if (node == if_stmt) {
186 *ret = node.as<IfThenElse>()->else_case;
187 }
188 });
189
190 then_for = IRTransform(for_stmt, nullptr, replace_then_case,
191 {Expr("IfThenElse")});
192 if (if_stmt.as<IfThenElse>()->else_case) {
193 else_for = IRTransform(for_stmt, nullptr, replace_else_case,
194 {Expr("IfThenElse")});
195 }
196
197 return std::make_pair(then_for, else_for);
198 }
199
200 // Locate all For nodes and capture child IfThenElse nodes.
SelectCandidates(const Stmt & stmt)201 void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
202 PostOrderVisit(stmt, [&](const NodeRef& node){
203 const For* for_node = node.as<For>();
204 if (!for_node) return;
205
206 std::queue<Stmt> tracker;
207 tracker.push(for_node->body);
208 Stmt for_stmt = Downcast<Stmt, NodeRef>(node);
209 for2if_map_.insert({for_stmt.get(), std::vector<Stmt>()});
210 while (!tracker.empty()) {
211 Stmt head = tracker.front();
212 tracker.pop();
213 if (head->IsInstance<For>()) {
214 for (const auto& if_stmt : for2if_map_.at(head.get())) {
215 for2if_map_[for_stmt.get()].push_back(if_stmt);
216 }
217 } else if (head->IsInstance<AttrStmt>()) {
218 const AttrStmt* attr_node = head.as<AttrStmt>();
219 tracker.push(attr_node->body);
220 } else if (head->IsInstance<IfThenElse>()) {
221 for2if_map_[for_stmt.get()].push_back(head);
222 const IfThenElse* if_node = head.as<IfThenElse>();
223 tracker.push(if_node->then_case);
224 if (if_node->else_case) {
225 tracker.push(if_node->else_case);
226 }
227
228 // Record condition variables.
229 if (!cond_var_map_.count(head.get())) {
230 std::unordered_set<const Node*> new_var_set;
231 cond_var_map_.insert({head.get(), new_var_set});
232 PostOrderVisit(if_node->condition, [&](const NodeRef& cond_node) {
233 if (cond_node.as<Variable>()) {
234 cond_var_map_[head.get()].insert(cond_node.get());
235 }
236 });
237 }
238 } else {
239 continue;
240 }
241 }
242 ordered_for_list_.emplace_back(Downcast<Stmt, NodeRef>(node));
243 });
244 }
245
246 // For each IfThenElse node, find the highest For node which
247 // meets loop invariant condition.
LocateTopFor()248 void IfThenElseHoist::LocateTopFor() {
249 std::unordered_map<const Node*, Stmt> if_position_map;
250 std::unordered_set<const Node*> top_for_var_set;
251
252 // Create IfThenElse -> For map.
253 for (const Stmt& for_stmt : ordered_for_list_) {
254 std::vector<Stmt> if_list = for2if_map_[for_stmt.get()];
255 const For* for_node = for_stmt.as<For>();
256 CHECK(for_node);
257 top_for_var_map_.insert({for_node->loop_var.get(), if_list});
258 for (const Stmt& if_stmt : if_list) {
259 const Node* if_node = if_stmt.get();
260 if2for_map_[if_node].push_back(for_stmt);
261 }
262 }
263
264 // Locate the highest For node which is loop invariant.
265 for (const auto& item : if2for_map_) {
266 Stmt top_for;
267 const Node* if_stmt = item.first;
268 std::vector<Stmt> for_list = item.second;
269 for (size_t i = 0; i < for_list.size(); ++i) {
270 const Stmt& for_stmt = for_list.at(i);
271 const For* for_node = for_stmt.as<For>();
272 CHECK(for_node);
273 std::vector<Stmt> new_for_list{for_stmt};
274 for_tracking_map_.insert({for_stmt.get(), new_for_list});
275 if (cond_var_map_[if_stmt]
276 .count(for_node->loop_var.get())) {
277 std::vector<Stmt> updated_for_list(for_list.begin(),
278 for_list.begin() + i);
279 if2for_map_[if_stmt] = updated_for_list;
280 break;
281 } else {
282 top_for = for_stmt;
283 }
284 }
285 if (top_for.as<For>()) {
286 if_position_map.insert({if_stmt, top_for});
287 }
288 }
289
290 for (const auto& item : if_position_map) {
291 top_for_var_set.insert(item.second.as<For>()->loop_var.get());
292 }
293
294 std::vector<const Node*> removed_for_var_list;
295 for (const auto& item : top_for_var_map_) {
296 const Node* top_for_var = item.first;
297 std::vector<Stmt> if_list = item.second;
298 if (!top_for_var_set.count(top_for_var)) {
299 removed_for_var_list.push_back(top_for_var);
300 } else {
301 std::vector<Stmt> actual_if_list;
302 for (const Stmt& if_stmt : if_list) {
303 if (if_position_map.count(if_stmt.get())) {
304 actual_if_list.push_back(if_stmt);
305 }
306 }
307 top_for_var_map_[top_for_var] = actual_if_list;
308 }
309 }
310 for (const Node* top_for_var : removed_for_var_list) {
311 top_for_var_map_.erase(top_for_var);
312 }
313 }
314
315 // When we try to mutate a For node, some child For nodes can have already
316 // been mutated. This function is to get the updated For node and further
317 // hoisting can be done based on this new node.
318 // We keep all For nodes tracing in for_tracking_map_. When we get a
319 // hoisted IfThenElse, we match it with tracing For nodes to pick
320 // the updated one.
GetUpdatedFor(const Stmt & for_stmt,const Stmt & if_stmt)321 size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt,
322 const Stmt& if_stmt) {
323 std::vector<Stmt> tracked_for_list = for_tracking_map_[for_stmt.get()];
324 size_t updated_for_idx = 0;
325 for (size_t i = 0; i < tracked_for_list.size(); ++i) {
326 const Stmt& current_for =
327 tracked_for_list.at(tracked_for_list.size() - 1 - i);
328 if (is_first_if(current_for, if_stmt)) {
329 updated_for_idx = tracked_for_list.size() - 1 - i;
330 break;
331 }
332 }
333 return updated_for_idx;
334 }
335
336 // Hoist an IfThenElse node as high as possible.
337 // This function iterates on all candidate For nodes. For each For node,
338 // it first removes IfThenElse nodes. Then it generates a new IfThenElse
339 // node using mutated For nodes.
HoistIf(const Stmt & if_stmt)340 Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) {
341 Stmt new_if = if_stmt;
342
343 for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) {
344 const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i);
345 size_t updated_for_idx = GetUpdatedFor(for_stmt, new_if);
346 const Stmt& updated_for_node =
347 for_tracking_map_[for_stmt.get()].at(updated_for_idx);
348 auto generated_for_pair = RemoveIf(updated_for_node, new_if);
349 const Stmt& then_for = generated_for_pair.first;
350 const Stmt& else_for = generated_for_pair.second;;
351 for_tracking_map_[for_stmt.get()].at(updated_for_idx) = then_for;
352
353 if (else_for.get()) {
354 for_tracking_map_[for_stmt.get()].push_back(else_for);
355 }
356
357 const IfThenElse* new_if_node = new_if.as<IfThenElse>();
358 CHECK(new_if_node);
359 new_if = IfThenElse::make(new_if_node->condition, then_for, else_for);
360 if (i < if2for_map_[if_stmt.get()].size() - 1) {
361 const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1);
362 const Stmt& actual_next_for =
363 for_tracking_map_[original_next_for.get()].at(updated_for_idx);
364 Stmt update_for_stmt = update_for(actual_next_for, new_if);
365
366 for_tracking_map_[original_next_for.get()].
367 at(updated_for_idx) = update_for_stmt;
368 }
369 }
370 return new_if;
371 }
372
373 // Mutate For nodes in post order DFS manner.
PostOrderMutate(const Stmt & stmt)374 Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
375 PackedFunc replace_top_for = PackedFunc(
376 [&](TVMArgs args, TVMRetValue *ret){
377 const NodeRef& current_for = args[0];
378 const For* for_node = current_for.as<For>();
379 if (!for_node) return;
380
381 if (top_for_var_map_.count(for_node->loop_var.get())) {
382 std::vector<Stmt> new_if_list;
383 for (const Stmt& if_stmt :
384 top_for_var_map_[for_node->loop_var.get()]) {
385 new_if_list.emplace_back(HoistIf(if_stmt));
386 }
387
388 const IfThenElse* next_if_node;
389 const IfThenElse* current_if_node =
390 new_if_list.back().as<IfThenElse>();
391 Stmt new_for = Stmt();
392 for (size_t i = new_if_list.size() - 1; i > 0; --i) {
393 CHECK(current_if_node);
394 const Stmt current_if_stmt =
395 IfThenElse::make(current_if_node->condition,
396 current_if_node->then_case,
397 current_if_node->else_case);
398 next_if_node = new_if_list[i - 1].as<IfThenElse>();
399 CHECK(next_if_node);
400 new_for = IfThenElse::make(next_if_node->condition, current_if_stmt,
401 next_if_node->else_case);
402 current_if_node = new_for.as<IfThenElse>();
403 }
404
405 if (!new_for.get()) {
406 const IfThenElse* first_if_node = new_if_list[0].as<IfThenElse>();
407 CHECK(first_if_node);
408 new_for = IfThenElse::make(first_if_node->condition,
409 first_if_node->then_case,
410 first_if_node->else_case);
411 }
412 *ret = new_for;
413 }
414 });
415 return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")});
416 }
417
HoistIfThenElse(Stmt stmt)418 Stmt HoistIfThenElse(Stmt stmt) {
419 return IfThenElseHoist().VisitAndMutate(stmt);
420 }
421
422 } // namespace ir
423 } // namespace tvm
424