1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file hoist_if_then_else.cc
22  */
23 #include <tvm/ir.h>
24 #include <tvm/ir_visitor.h>
25 #include <tvm/ir_mutator.h>
26 #include <tvm/ir_pass.h>
27 #include <tvm/arithmetic.h>
28 #include <tvm/api_registry.h>
29 #include <unordered_map>
30 #include <unordered_set>
31 #include <queue>
32 #include "../arithmetic/int_set.h"
33 #include "../runtime/thread_storage_scope.h"
34 
35 namespace tvm {
36 namespace ir {
37 
38 using HoistMap = std::unordered_map<const Node*, std::vector<Stmt>>;
39 using VarMap = std::unordered_map<const Node*, std::unordered_set<const Node*>>;
40 
41 /*
42  * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
43  * For example, given the following block:
44  * for (i = 0; i < 3; i++)
45  *    for (j = 0; j < 4; j++)
46  *        for (k = 0; k < 5; k++)
47  *            if (likely(i*2 < 4))
48  *                A[3*i+2j+k] = B[7*i+3j+k]
49  *
 * We first detect all IfThenElse stmts and find the corresponding loop-invariant For stmts.
 * Then we hoist each IfThenElse stmt out of one For stmt per step:
52  *
53  * Step 1:
54  * for (i = 0; i < 3; i++)
55  *     for (j = 0; j < 4; j++)
56  *         if (likely(i*2 < 4))
57  *             for (k = 0; k < 5; k++)
58  *                 A[3*i+2j+k] = B[7*i+3j+k]
59  *
60  * Step 2:
61  * for (i = 0; i < 3; i++)
62  *     if (likely(i*2 < 4))
63  *         for (j = 0; j < 4; j++)
64  *             for (k = 0; k < 5; k++)
65  *                 A[3*i+2j+k] = B[7*i+3j+k]
66  *
 * In this pass, we only continue looking for hoisting opportunities when visiting For,
 * IfThenElse or AttrStmt nodes. For example, for the following block:
69  * for (i = 0; i < 3; i++)
70  *    for (j = 0; j < 4; j++)
71  *        A[i + j] = A[i + j] - 1
72  *        for (k = 0; k < 5; k++)
73  *            if (likely(i*2 < 4))
74  *                A[3*i+2j+k] = B[7*i+3j+k]
75  *
76  * Only the For with k variable will be considered and the resulting stmt would be:
77  * for (i = 0; i < 3; i++)
78  *    for (j = 0; j < 4; j++)
79  *        A[i + j] = A[i + j] - 1
80  *        if (likely(i*2 < 4))
81  *            for (k = 0; k < 5; k++)
82  *                A[3*i+2j+k] = B[7*i+3j+k]
83  *
 * This pass doesn't do hoisting for consecutive IfThenElse stmts. The following
85  * block won't be optimized:
86  * for (i = 0; i < 3; i++)
87  *    for (j = 0; j < 4; j++)
88  *        for (k = 0; k < 5; k++)
89  *            if (likely(i*2 < 4))
90  *                A[3*i+2j+k] = B[7*i+3j+k]
91  *            if (likely(j > 2))
92  *                A[i+j+k] = B[i+j+k]
93  *
94  */
95 class IfThenElseHoist {
96  public:
VisitAndMutate(const Stmt & stmt)97   Stmt VisitAndMutate(const Stmt& stmt) {
98     SelectCandidates(stmt);
99     LocateTopFor();
100     return PostOrderMutate(stmt);
101   }
102 
103  private:
104   void SelectCandidates(const Stmt& stmt);
105   void LocateTopFor();
106   Stmt PostOrderMutate(const Stmt& stmt);
107   size_t GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt);
108   Stmt HoistIf(const Stmt& if_stmt);
109 
110   // Map of all For nodes to all child IfThenElse nodes.
111   HoistMap for2if_map_;
112   // Map of all IfThenElse nodes to all For nodes which are loop invariant.
113   HoistMap if2for_map_;
114   // Map of highest loop invariant For to child IfThenElse.
115   HoistMap top_for_var_map_;
116   // Map of original For to list of update For nodes.
117   HoistMap for_tracking_map_;
118   // Map of all IfThenElse nodes to condition variable nodes.
119   VarMap cond_var_map_;
120   // List of For nodes added in post order DFS visiting.
121   std::vector<Stmt> ordered_for_list_;
122 };
123 
124 // Check whether a given IfThenElse stmt is the first one appearing
125 // in a For stmt.
is_first_if(const Stmt & for_stmt,const Stmt & if_stmt)126 bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) {
127   std::vector<const Node*> if_node_list;
128   const For* for_node = for_stmt.as<For>();
129   CHECK(for_node);
130   CHECK(if_stmt.as<IfThenElse>());
131 
132   PostOrderVisit(for_node->body, [&](const NodeRef& node) {
133     if (node.as<IfThenElse>()) {
134       if_node_list.push_back(node.get());
135     }
136   });
137   return if_node_list.empty() ? false : if_stmt.get() == if_node_list.back();
138 }
139 
140 // Update upper level For node when current For node is modified.
141 // With this function we only need to visit and mutate top level For node
142 // in the main VisitAndMutate function.
update_for(const Stmt & parent_for_stmt,const Stmt & new_if_stmt)143 Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
144   const Node* top_for_node;
145   const For* parent_for_node = parent_for_stmt.as<For>();
146   CHECK(parent_for_node);
147   CHECK(new_if_stmt.as<IfThenElse>());
148 
149   PostOrderVisit(parent_for_node->body, [&](const NodeRef& node) {
150     if (node.as<For>()) {
151       top_for_node = node.get();
152     }
153   });
154 
155   PackedFunc replace_target_for = PackedFunc(
156     [&](TVMArgs args, TVMRetValue *ret){
157       const NodeRef& current_for = args[0];
158       if (current_for.get() == top_for_node) {
159         *ret = new_if_stmt;
160       }
161     });
162 
163   return IRTransform(parent_for_stmt, nullptr, replace_target_for,
164                      {Expr("For")});
165 }
166 
167 // Remove IfThenElse node from a For node.
168 // A pair of For nodes will be generated.
RemoveIf(const Stmt & for_stmt,const Stmt & if_stmt)169 std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
170   Stmt then_for;
171   Stmt else_for;
172   CHECK(if_stmt.as<IfThenElse>());
173 
174   PackedFunc replace_then_case = PackedFunc(
175     [&](TVMArgs args, TVMRetValue *ret){
176       const NodeRef& node  = args[0];
177       if (node == if_stmt) {
178         *ret = node.as<IfThenElse>()->then_case;
179       }
180     });
181 
182   PackedFunc replace_else_case = PackedFunc(
183     [&](TVMArgs args, TVMRetValue *ret){
184       const NodeRef& node  = args[0];
185       if (node == if_stmt) {
186         *ret = node.as<IfThenElse>()->else_case;
187       }
188     });
189 
190   then_for = IRTransform(for_stmt, nullptr, replace_then_case,
191                          {Expr("IfThenElse")});
192   if (if_stmt.as<IfThenElse>()->else_case) {
193     else_for = IRTransform(for_stmt, nullptr, replace_else_case,
194                            {Expr("IfThenElse")});
195   }
196 
197   return std::make_pair(then_for, else_for);
198 }
199 
200 // Locate all For nodes and capture child IfThenElse nodes.
SelectCandidates(const Stmt & stmt)201 void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
202   PostOrderVisit(stmt, [&](const NodeRef& node){
203     const For* for_node = node.as<For>();
204     if (!for_node) return;
205 
206     std::queue<Stmt> tracker;
207     tracker.push(for_node->body);
208     Stmt for_stmt = Downcast<Stmt, NodeRef>(node);
209     for2if_map_.insert({for_stmt.get(), std::vector<Stmt>()});
210     while (!tracker.empty()) {
211       Stmt head = tracker.front();
212       tracker.pop();
213       if (head->IsInstance<For>()) {
214         for (const auto& if_stmt : for2if_map_.at(head.get())) {
215           for2if_map_[for_stmt.get()].push_back(if_stmt);
216         }
217       } else if (head->IsInstance<AttrStmt>()) {
218         const AttrStmt* attr_node = head.as<AttrStmt>();
219         tracker.push(attr_node->body);
220       } else if (head->IsInstance<IfThenElse>()) {
221         for2if_map_[for_stmt.get()].push_back(head);
222         const IfThenElse* if_node = head.as<IfThenElse>();
223         tracker.push(if_node->then_case);
224         if (if_node->else_case) {
225           tracker.push(if_node->else_case);
226         }
227 
228         // Record condition variables.
229         if (!cond_var_map_.count(head.get())) {
230           std::unordered_set<const Node*> new_var_set;
231           cond_var_map_.insert({head.get(), new_var_set});
232           PostOrderVisit(if_node->condition, [&](const NodeRef& cond_node) {
233             if (cond_node.as<Variable>()) {
234               cond_var_map_[head.get()].insert(cond_node.get());
235             }
236           });
237         }
238       } else {
239         continue;
240       }
241     }
242     ordered_for_list_.emplace_back(Downcast<Stmt, NodeRef>(node));
243   });
244 }
245 
// For each IfThenElse node, find the highest For node which
// meets loop invariant condition.
//
// Inputs: for2if_map_, cond_var_map_ and ordered_for_list_ (filled by
// SelectCandidates).
// Outputs:
//  - if2for_map_: each candidate IfThenElse -> its enclosing For nodes in
//    ordered_for_list_ (post) order, truncated just before the first For
//    whose loop variable appears in the condition.
//  - for_tracking_map_: a one-element list {for_stmt} for every For examined
//    by the truncation scan.
//  - top_for_var_map_: loop variable of a highest invariant For -> the
//    IfThenElse nodes kept for it; entries for other loop variables are erased.
void IfThenElseHoist::LocateTopFor() {
  // IfThenElse -> highest For it is loop invariant with respect to.
  std::unordered_map<const Node*, Stmt> if_position_map;
  // Loop variables of all For nodes that appear in if_position_map.
  std::unordered_set<const Node*> top_for_var_set;

  // Create IfThenElse -> For map.
  // ordered_for_list_ is in post order, so inner For nodes are pushed first.
  for (const Stmt& for_stmt : ordered_for_list_) {
    std::vector<Stmt> if_list = for2if_map_[for_stmt.get()];
    const For* for_node = for_stmt.as<For>();
    CHECK(for_node);
    // Keyed by loop variable; insert() keeps the first entry if two For
    // nodes share the same loop variable.
    top_for_var_map_.insert({for_node->loop_var.get(), if_list});
    for (const Stmt& if_stmt : if_list) {
      const Node* if_node = if_stmt.get();
      if2for_map_[if_node].push_back(for_stmt);
    }
  }

  // Locate the highest For node which is loop invariant.
  for (const auto& item : if2for_map_) {
    Stmt top_for;
    const Node* if_stmt = item.first;
    // Copied on purpose: if2for_map_[if_stmt] is reassigned inside the loop.
    std::vector<Stmt> for_list = item.second;
    for (size_t i = 0; i < for_list.size(); ++i) {
      const Stmt& for_stmt = for_list.at(i);
      const For* for_node = for_stmt.as<For>();
      CHECK(for_node);
      std::vector<Stmt> new_for_list{for_stmt};
      for_tracking_map_.insert({for_stmt.get(), new_for_list});
      // The condition depends on this loop's variable: stop here and keep
      // only the For nodes below it.
      if (cond_var_map_[if_stmt]
        .count(for_node->loop_var.get())) {
        std::vector<Stmt> updated_for_list(for_list.begin(),
                                           for_list.begin() + i);
        if2for_map_[if_stmt] = updated_for_list;
        break;
      } else {
        top_for = for_stmt;
      }
    }
    // top_for stays undefined when even the first For is not invariant.
    if (top_for.as<For>()) {
      if_position_map.insert({if_stmt, top_for});
    }
  }

  for (const auto& item : if_position_map) {
    top_for_var_set.insert(item.second.as<For>()->loop_var.get());
  }

  // Keep only top-For loop variables, and for each keep only IfThenElse
  // nodes that have some top For.
  std::vector<const Node*> removed_for_var_list;
  for (const auto& item : top_for_var_map_) {
    const Node* top_for_var = item.first;
    // Copied on purpose: top_for_var_map_[top_for_var] is reassigned below.
    std::vector<Stmt> if_list = item.second;
    if (!top_for_var_set.count(top_for_var)) {
      removed_for_var_list.push_back(top_for_var);
    } else {
      std::vector<Stmt> actual_if_list;
      for (const Stmt& if_stmt : if_list) {
        // NOTE(review): this keeps any IfThenElse that has a top For anywhere,
        // not only those whose top For is the loop of `top_for_var`; confirm
        // that this is intended.
        if (if_position_map.count(if_stmt.get())) {
          actual_if_list.push_back(if_stmt);
        }
      }
      top_for_var_map_[top_for_var] = actual_if_list;
    }
  }
  // Erase after the loop so the range-for above is not invalidated.
  for (const Node* top_for_var : removed_for_var_list) {
    top_for_var_map_.erase(top_for_var);
  }
}
314 
315 // When we try to mutate a For node, some child For nodes can have already
316 // been mutated. This function is to get the updated For node and further
317 // hoisting can be done based on this new node.
318 // We keep all For nodes tracing in for_tracking_map_. When we get a
319 // hoisted IfThenElse, we match it with tracing For nodes to pick
320 // the updated one.
GetUpdatedFor(const Stmt & for_stmt,const Stmt & if_stmt)321 size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt,
322                                        const Stmt& if_stmt) {
323   std::vector<Stmt> tracked_for_list = for_tracking_map_[for_stmt.get()];
324   size_t updated_for_idx = 0;
325   for (size_t i = 0; i < tracked_for_list.size(); ++i) {
326     const Stmt& current_for =
327       tracked_for_list.at(tracked_for_list.size() - 1 - i);
328     if (is_first_if(current_for, if_stmt)) {
329       updated_for_idx = tracked_for_list.size() - 1 - i;
330       break;
331     }
332   }
333   return updated_for_idx;
334 }
335 
336 // Hoist an IfThenElse node as high as possible.
337 // This function iterates on all candidate For nodes. For each For node,
338 // it first removes IfThenElse nodes. Then it generates a new IfThenElse
339 // node using mutated For nodes.
HoistIf(const Stmt & if_stmt)340 Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) {
341   Stmt new_if = if_stmt;
342 
343   for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) {
344     const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i);
345     size_t updated_for_idx = GetUpdatedFor(for_stmt, new_if);
346     const Stmt& updated_for_node =
347       for_tracking_map_[for_stmt.get()].at(updated_for_idx);
348     auto generated_for_pair = RemoveIf(updated_for_node, new_if);
349     const Stmt& then_for = generated_for_pair.first;
350     const Stmt& else_for = generated_for_pair.second;;
351     for_tracking_map_[for_stmt.get()].at(updated_for_idx) = then_for;
352 
353     if (else_for.get()) {
354       for_tracking_map_[for_stmt.get()].push_back(else_for);
355     }
356 
357     const IfThenElse* new_if_node = new_if.as<IfThenElse>();
358     CHECK(new_if_node);
359     new_if = IfThenElse::make(new_if_node->condition, then_for, else_for);
360     if (i < if2for_map_[if_stmt.get()].size() - 1) {
361       const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1);
362       const Stmt& actual_next_for =
363         for_tracking_map_[original_next_for.get()].at(updated_for_idx);
364       Stmt update_for_stmt = update_for(actual_next_for, new_if);
365 
366       for_tracking_map_[original_next_for.get()].
367         at(updated_for_idx) = update_for_stmt;
368     }
369   }
370   return new_if;
371 }
372 
373 // Mutate For nodes in post order DFS manner.
PostOrderMutate(const Stmt & stmt)374 Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
375   PackedFunc replace_top_for = PackedFunc(
376     [&](TVMArgs args, TVMRetValue *ret){
377       const NodeRef& current_for = args[0];
378       const For* for_node = current_for.as<For>();
379       if (!for_node) return;
380 
381       if (top_for_var_map_.count(for_node->loop_var.get())) {
382         std::vector<Stmt> new_if_list;
383         for (const Stmt& if_stmt :
384           top_for_var_map_[for_node->loop_var.get()]) {
385           new_if_list.emplace_back(HoistIf(if_stmt));
386         }
387 
388         const IfThenElse* next_if_node;
389         const IfThenElse* current_if_node =
390           new_if_list.back().as<IfThenElse>();
391         Stmt new_for = Stmt();
392         for (size_t i = new_if_list.size() - 1; i > 0; --i) {
393           CHECK(current_if_node);
394           const Stmt current_if_stmt =
395             IfThenElse::make(current_if_node->condition,
396                              current_if_node->then_case,
397                              current_if_node->else_case);
398           next_if_node = new_if_list[i - 1].as<IfThenElse>();
399           CHECK(next_if_node);
400           new_for = IfThenElse::make(next_if_node->condition, current_if_stmt,
401                                      next_if_node->else_case);
402           current_if_node = new_for.as<IfThenElse>();
403         }
404 
405         if (!new_for.get()) {
406           const IfThenElse* first_if_node = new_if_list[0].as<IfThenElse>();
407           CHECK(first_if_node);
408           new_for = IfThenElse::make(first_if_node->condition,
409                                      first_if_node->then_case,
410                                      first_if_node->else_case);
411         }
412         *ret = new_for;
413       }
414     });
415   return IRTransform(stmt, nullptr, replace_top_for, {Expr("For")});
416 }
417 
HoistIfThenElse(Stmt stmt)418 Stmt HoistIfThenElse(Stmt stmt) {
419   return IfThenElseHoist().VisitAndMutate(stmt);
420 }
421 
422 }  // namespace ir
423 }  // namespace tvm
424