1 #include "LICM.h"
2 #include "CSE.h"
3 #include "ExprUsesVar.h"
4 #include "IREquality.h"
5 #include "IRMutator.h"
6 #include "IROperator.h"
7 #include "Scope.h"
8 #include "Simplify.h"
9 #include "Substitute.h"
10 
11 namespace Halide {
12 namespace Internal {
13 
14 using std::map;
15 using std::pair;
16 using std::set;
17 using std::string;
18 using std::vector;
19 
20 // Is it safe to lift an Expr out of a loop (and potentially across a device boundary)
21 class CanLift : public IRVisitor {
22     using IRVisitor::visit;
23 
visit(const Call * op)24     void visit(const Call *op) override {
25         if (!op->is_pure()) {
26             result = false;
27         } else {
28             IRVisitor::visit(op);
29         }
30     }
31 
visit(const Load * op)32     void visit(const Load *op) override {
33         result = false;
34     }
35 
visit(const Variable * op)36     void visit(const Variable *op) override {
37         if (varying.contains(op->name)) {
38             result = false;
39         }
40     }
41 
42     const Scope<> &varying;
43 
44 public:
45     bool result{true};
46 
CanLift(const Scope<> & v)47     CanLift(const Scope<> &v)
48         : varying(v) {
49     }
50 };
51 
// Lift pure loop invariants to the top level. Applied independently
// to each loop.
//
// Mutating a Stmt with this pass replaces every subexpression accepted by
// should_lift() with a fresh Variable. It records the mapping from the
// (simplified) subexpression to that Variable's name in `lifted`. The pass
// does not emit LetStmts for the lifted values itself; the caller must
// bind them.
class LiftLoopInvariants : public IRMutator {
    using IRMutator::visit;

    // Names treated as varying: loop variables and let-bound names
    // encountered while mutating.
    Scope<> varying;

    // True if `e` contains no impure calls, no loads, and no varying names.
    bool can_lift(const Expr &e) {
        CanLift check(varying);
        e.accept(&check);
        return check.result;
    }

    // Profitability heuristic layered on can_lift(). Some expressions are
    // legal to lift but are rejected here as not worth it: variables,
    // broadcasts, constants, vectors, widening casts, Int(32) adds of a
    // constant, and some wrapper intrinsics whose argument is not itself
    // worth lifting.
    bool should_lift(const Expr &e) {
        if (!can_lift(e)) return false;
        if (e.as<Variable>()) return false;
        if (e.as<Broadcast>()) return false;
        if (is_const(e)) return false;
        // bool vectors are buggy enough in LLVM that lifting them is a bad idea.
        // (We just skip all vectors on the principle that we don't want them
        // on the stack anyway.)
        if (e.type().is_vector()) return false;
        if (const Cast *cast = e.as<Cast>()) {
            if (cast->type.bytes() > cast->value.type().bytes()) {
                // Don't lift widening casts.
                return false;
            }
        }
        if (const Add *add = e.as<Add>()) {
            if (add->type == Int(32) &&
                is_const(add->b)) {
                // Don't lift constant integer offsets. They're often free.
                return false;
            }
        }
        if (const Call *call = e.as<Call>()) {
            if (call->is_intrinsic(Call::strict_float) ||
                call->is_intrinsic(Call::likely) ||
                call->is_intrinsic(Call::likely_if_innermost) ||
                call->is_intrinsic(Call::reinterpret)) {
                // Don't lift these intrinsics. They're free.
                // Decide based on the wrapped argument instead.
                return should_lift(call->args[0]);
            }
            if (call->is_intrinsic(Call::size_of_halide_buffer_t)) {
                // NOTE(review): identical to the fall-through below; this
                // branch is currently redundant.
                return true;
            }
        }
        return true;
    }

    // Shared implementation for Let (Body = Expr) and LetStmt (Body = Stmt).
    // Each value is mutated *before* its name is bound as varying, so a
    // value is analyzed in the scope that encloses its let. The name stays
    // bound (via Frame::binding) while the innermost body is mutated. Nodes
    // are rebuilt innermost-first, reusing the original node when neither
    // its value nor its body changed.
    template<typename T, typename Body>
    Body visit_let(const T *op) {
        // Visit an entire chain of lets in a single method to conserve stack space.
        struct Frame {
            const T *op;
            Expr new_value;
            ScopedBinding<> binding;
            Frame(const T *op, Expr v, Scope<> &scope)
                : op(op), new_value(std::move(v)), binding(scope, op->name) {
            }
        };
        vector<Frame> frames;
        Body result;
        do {
            frames.emplace_back(op, mutate(op->value), varying);
            result = op->body;
        } while ((op = result.template as<T>()));

        result = mutate(result);

        for (auto it = frames.rbegin(); it != frames.rend(); it++) {
            if (it->new_value.same_as(it->op->value) && result.same_as(it->op->body)) {
                result = it->op;
            } else {
                result = T::make(it->op->name, std::move(it->new_value), result);
            }
        }

        return result;
    }

    Expr visit(const Let *op) override {
        return visit_let<Let, Expr>(op);
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let<LetStmt, Stmt>(op);
    }

    // The loop variable is varying for the whole IRMutator::visit call,
    // i.e. for everything this visit mutates.
    Stmt visit(const For *op) override {
        ScopedBinding<> p(varying, op->name);
        return IRMutator::visit(op);
    }

public:
    using IRMutator::mutate;

    // If `e` should be lifted, replace the whole expression (without
    // recursing into it) with a Variable of the same type. Expressions that
    // simplify to deeply-equal forms share one name. Otherwise recurse
    // normally.
    Expr mutate(const Expr &e) override {
        if (should_lift(e)) {
            // Lift it in canonical form
            Expr lifted_expr = simplify(e);
            auto it = lifted.find(lifted_expr);
            if (it == lifted.end()) {
                string name = unique_name('t');
                lifted[lifted_expr] = name;
                return Variable::make(e.type(), name);
            } else {
                return Variable::make(e.type(), it->second);
            }
        } else {
            return IRMutator::mutate(e);
        }
    }

    // Simplified lifted expression -> name of the Variable that replaced it.
    map<Expr, string, IRDeepCompare> lifted;
};
168 
169 // The pass above can lift out the value of lets entirely, leaving
170 // them as just renamings of other variables. Easier to substitute
171 // them in as a post-pass rather than make the pass above more clever.
172 class SubstituteTrivialLets : public IRMutator {
173     using IRMutator::visit;
174 
visit(const Let * op)175     Expr visit(const Let *op) override {
176         if (op->value.as<Variable>()) {
177             return mutate(substitute(op->name, op->value, op->body));
178         } else {
179             return IRMutator::visit(op);
180         }
181     }
182 
visit(const LetStmt * op)183     Stmt visit(const LetStmt *op) override {
184         if (op->value.as<Variable>()) {
185             return mutate(substitute(op->name, op->value, op->body));
186         } else {
187             return IRMutator::visit(op);
188         }
189     }
190 };
191 
// Loop-invariant code motion. For each loop it is allowed to touch, this
// pass lifts pure loop-invariant subexpressions (found by
// LiftLoopInvariants) into LetStmts wrapped around the loop, then recurses
// into the loop body.
class LICM : public IRMutator {
    using IRMutator::visit;

    // Whether the For currently being visited is a GPUBlock or GPUThread
    // loop. Saved and restored around each For by ScopedValue.
    bool in_gpu_loop{false};

    // Compute the cost of computing an expression inside the inner
    // loop, compared to just loading it as a parameter.
    //
    // Costs:
    //  - constants, and variables already in `vars`: 0
    //  - any other variable (it would have to be loaded): 1
    //  - Add/Sub/Mul: 1 plus the cost of the operands
    //  - reinterpret: the cost of its argument
    //  - every other node: 100, i.e. effectively never worth recomputing
    int cost(const Expr &e, const set<string> &vars) {
        if (is_const(e)) {
            return 0;
        } else if (const Variable *var = e.as<Variable>()) {
            if (vars.count(var->name)) {
                // We're loading this already
                return 0;
            } else {
                // Would have to load this
                return 1;
            }
        } else if (const Add *add = e.as<Add>()) {
            return cost(add->a, vars) + cost(add->b, vars) + 1;
        } else if (const Sub *sub = e.as<Sub>()) {
            return cost(sub->a, vars) + cost(sub->b, vars) + 1;
        } else if (const Mul *mul = e.as<Mul>()) {
            return cost(mul->a, vars) + cost(mul->b, vars) + 1;
        } else if (const Call *call = e.as<Call>()) {
            if (call->is_intrinsic(Call::reinterpret)) {
                internal_assert(call->args.size() == 1);
                return cost(call->args[0], vars);
            } else {
                return 100;
            }
        } else {
            return 100;
        }
    }

    Stmt visit(const For *op) override {
        // Save the enclosing loop's flag (restored when this visit returns),
        // then record whether this loop is itself a GPU loop.
        ScopedValue<bool> old_in_gpu_loop(in_gpu_loop);
        in_gpu_loop =
            (op->for_type == ForType::GPUBlock ||
             op->for_type == ForType::GPUThread);

        if (old_in_gpu_loop && in_gpu_loop) {
            // Don't lift lets to in-between gpu blocks/threads
            return IRMutator::visit(op);
        } else if (op->device_api == DeviceAPI::GLSL) {
            // GLSL uses magic names for varying things. Just skip LICM.
            return IRMutator::visit(op);
        } else {

            // Lift invariants
            LiftLoopInvariants lifter;
            Stmt new_stmt = lifter.mutate(op);
            new_stmt = SubstituteTrivialLets().mutate(new_stmt);

            // As an optimization to reduce register pressure, take
            // the set of expressions to lift and check if any can
            // cheaply be computed from others. If so it's better to
            // do that than to load multiple related values off the
            // stack. We currently only consider expressions that are
            // the sum, difference, or product of two variables
            // already used in the kernel, or a variable plus a
            // constant.

            // Linearize all the exprs and names
            vector<Expr> exprs;
            vector<string> names;
            for (const auto &p : lifter.lifted) {
                exprs.push_back(p.first);
                names.push_back(p.second);
            }

            // Jointly CSE the lifted exprs by putting them together into a dummy Expr
            Expr dummy_call = Call::make(Int(32), Call::bundle, exprs, Call::PureIntrinsic);
            dummy_call = common_subexpression_elimination(dummy_call, true);

            // Peel off containing lets. These will be lifted.
            vector<pair<string, Expr>> lets;
            while (const Let *let = dummy_call.as<Let>()) {
                lets.emplace_back(let->name, let->value);
                dummy_call = let->body;
            }

            // Track the set of variables used by the inner loop
            class CollectVars : public IRVisitor {
                using IRVisitor::visit;
                void visit(const Variable *op) override {
                    vars.insert(op->name);
                }

            public:
                set<string> vars;
            } vars;
            new_stmt.accept(&vars);

            // Now consider substituting back in each use
            // call->args[i] is the CSE'd form of exprs[i] (same order).
            const Call *call = dummy_call.as<Call>();
            internal_assert(call->is_intrinsic(Call::bundle));
            // Iterate to a fixed point: substituting a value adds its
            // variables to vars.vars, which can make other candidates
            // cheap enough to substitute on the next pass.
            bool converged;
            do {
                converged = true;
                for (size_t i = 0; i < exprs.size(); i++) {
                    if (!exprs[i].defined()) continue;
                    Expr e = call->args[i];
                    if (cost(e, vars.vars) <= 1) {
                        // Just subs it back in - computing it is as cheap
                        // as loading it.
                        e.accept(&vars);
                        new_stmt = substitute(names[i], e, new_stmt);
                        // Mark as handled: no LetStmt is emitted for it.
                        names[i].clear();
                        exprs[i] = Expr();
                        converged = false;
                    } else {
                        // Keep the CSE'd form; it may refer to the CSE lets.
                        exprs[i] = e;
                    }
                }
            } while (!converged);

            // Recurse
            const For *loop = new_stmt.as<For>();
            internal_assert(loop);

            new_stmt = For::make(loop->name, loop->min, loop->extent,
                                 loop->for_type, loop->device_api, mutate(loop->body));

            // Wrap lets for the lifted invariants
            for (size_t i = 0; i < exprs.size(); i++) {
                if (exprs[i].defined()) {
                    new_stmt = LetStmt::make(names[i], exprs[i], new_stmt);
                }
            }

            // Wrap the lets pulled out by CSE
            // These go outermost, because the lifted values above may refer
            // to them. Popping from the back restores their original nesting.
            while (!lets.empty()) {
                new_stmt = LetStmt::make(lets.back().first, lets.back().second, new_stmt);
                lets.pop_back();
            }

            return new_stmt;
        }
    }
};
334 
// Reassociate summations to group together the loop invariants. Useful to run before LICM.
//
// Each non-float Add/Sub tree (except an Int(32) one with a constant right
// operand) is flattened into signed terms. The terms are stably ordered by
// the depth of the innermost loop whose variables they reference. The sum is
// then rebuilt starting from the shallowest (most invariant) terms, so that
// invariant terms end up adjacent in a common sub-sum.
class GroupLoopInvariants : public IRMutator {
    using IRMutator::visit;

    // Name -> loop depth. Loop variables map to their nesting depth (the
    // outermost loop is 1). Let-bound names map to the depth of their value
    // (0 when outside all loops).
    Scope<int> var_depth;

    // Computes the maximum depth, according to `depth`, of any variable an
    // Expr references. Names not in the scope contribute 0.
    class ExprDepth : public IRVisitor {
        using IRVisitor::visit;
        const Scope<int> &depth;

        void visit(const Variable *op) override {
            if (depth.contains(op->name)) {
                result = std::max(result, depth.get(op->name));
            }
        }

    public:
        int result = 0;
        ExprDepth(const Scope<int> &var_depth)
            : depth(var_depth) {
        }
    };

    int expr_depth(const Expr &e) {
        ExprDepth depth(var_depth);
        e.accept(&depth);
        return depth.result;
    }

    // One addend of a flattened summation.
    struct Term {
        Expr expr;      // the addend
        bool positive;  // false if it is subtracted
        int depth;      // expr_depth(expr); only filled in for final terms
    };

    // Flattens `e` into signed terms using an explicit work list. Each
    // non-Add/Sub leaf is mutated, which reassociates its own children. If
    // mutation turns a leaf into an Add/Sub, it is flattened further. The
    // result is stably sorted by decreasing depth, so the shallowest terms
    // are at the back.
    vector<Term> extract_summation(const Expr &e) {
        vector<Term> pending, terms;
        pending.push_back({e, true, 0});
        while (!pending.empty()) {
            Term next = pending.back();
            pending.pop_back();
            const Add *add = next.expr.as<Add>();
            const Sub *sub = next.expr.as<Sub>();
            if (add) {
                pending.push_back({add->a, next.positive, 0});
                pending.push_back({add->b, next.positive, 0});
            } else if (sub) {
                pending.push_back({sub->a, next.positive, 0});
                pending.push_back({sub->b, !next.positive, 0});
            } else {
                next.expr = mutate(next.expr);
                if (next.expr.as<Add>() || next.expr.as<Sub>()) {
                    // After mutation it became an add or sub, throw it back on the pending queue.
                    pending.push_back(next);
                } else {
                    next.depth = expr_depth(next.expr);
                    terms.push_back(next);
                }
            }
        }

        // Sort the terms by loop depth. Terms of equal depth are
        // likely already in a good order, so don't mess with them.
        std::stable_sort(terms.begin(), terms.end(),
                         [](const Term &a, const Term &b) {
                             return a.depth > b.depth;
                         });

        return terms;
    }

    // Rebuilds the summation, consuming terms from the back (shallowest
    // first). Invariant: the terms consumed so far sum to `result` when
    // `positive` is true, and to `-result` otherwise. Subtraction is chosen
    // over negation where possible; a final negation is applied only if
    // needed.
    Expr reassociate_summation(const Expr &e) {
        vector<Term> terms = extract_summation(e);

        Expr result;
        bool positive = true;
        while (!terms.empty()) {
            Term next = terms.back();
            terms.pop_back();
            if (result.defined()) {
                if (next.positive == positive) {
                    result += next.expr;
                } else if (next.positive) {
                    result = next.expr - result;
                    positive = true;
                } else {
                    result -= next.expr;
                }
            } else {
                result = next.expr;
                positive = next.positive;
            }
        }

        if (!positive) {
            result = make_zero(result.type()) - result;
        }

        return result;
    }

    Expr visit(const Add *op) override {
        if (op->type.is_float() || (op->type == Int(32) && is_const(op->b))) {
            // Don't reassociate float exprs.  (If strict_float is
            // off, we're allowed to reassociate, and we do
            // reassociate elsewhere, but there's no benefit to it
            // here and it's friendlier not to.)
            //
            // Also don't reassociate trailing integer constants. They're the
            // ultimate loop invariant, but doing this to stencils
            // causes inner loops to track N different pointers
            // instead of one pointer with constant offsets, and that
            // complicates aliasing analysis.
            return IRMutator::visit(op);
        }

        return reassociate_summation(op);
    }

    Expr visit(const Sub *op) override {
        // Same exclusions as visit(const Add *): floats and Int(32)
        // subtraction of a trailing constant are left alone.
        if (op->type.is_float() || (op->type == Int(32) && is_const(op->b))) {
            return IRMutator::visit(op);
        }

        return reassociate_summation(op);
    }

    // Nesting depth of the For currently being visited (0 outside loops).
    int depth = 0;

    Stmt visit(const For *op) override {
        depth++;
        ScopedBinding<int> bind(var_depth, op->name, depth);
        Stmt stmt = IRMutator::visit(op);
        depth--;
        return stmt;
    }

    // Shared implementation for Let and LetStmt chains. Each name is bound
    // to the depth of its original (unmutated) value, or 0 outside all
    // loops. Each value is mutated before its name is bound, and the names
    // stay bound while the innermost body is mutated. Nodes are rebuilt
    // innermost-first, reusing a node when nothing under it changed.
    template<typename T, typename Body>
    Body visit_let(const T *op) {
        struct Frame {
            const T *op;
            Expr new_value;
            ScopedBinding<int> binding;
            Frame(const T *op, Expr v, int depth, Scope<int> &scope)
                : op(op),
                  new_value(std::move(v)),
                  binding(scope, op->name, depth) {
            }
        };
        std::vector<Frame> frames;
        Body result;

        do {
            result = op->body;
            int d = 0;
            if (depth > 0) {
                d = expr_depth(op->value);
            }
            frames.emplace_back(op, mutate(op->value), d, var_depth);
        } while ((op = result.template as<T>()));

        result = mutate(result);

        for (auto it = frames.rbegin(); it != frames.rend(); it++) {
            if (it->new_value.same_as(it->op->value) && result.same_as(it->op->body)) {
                result = it->op;
            } else {
                result = T::make(it->op->name, it->new_value, result);
            }
        }

        return result;
    }

    Expr visit(const Let *op) override {
        return visit_let<Let, Expr>(op);
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let<LetStmt, Stmt>(op);
    }
};
517 
hoist_loop_invariant_values(Stmt s)518 Stmt hoist_loop_invariant_values(Stmt s) {
519     s = GroupLoopInvariants().mutate(s);
520     s = common_subexpression_elimination(s);
521     s = LICM().mutate(s);
522     s = simplify_exprs(s);
523     return s;
524 }
525 
526 namespace {
527 
528 // Move IfThenElse nodes from the inside of a piece of Stmt IR to the
529 // outside when legal.
530 class HoistIfStatements : public IRMutator {
531     using IRMutator::visit;
532 
visit(const LetStmt * op)533     Stmt visit(const LetStmt *op) override {
534         Stmt body = mutate(op->body);
535         if (const IfThenElse *i = body.as<IfThenElse>()) {
536             if (!i->else_case.defined() &&
537                 is_pure(op->value) &&
538                 is_pure(i->condition) &&
539                 !expr_uses_var(i->condition, op->name)) {
540                 Stmt s = LetStmt::make(op->name, op->value, i->then_case);
541                 return IfThenElse::make(i->condition, s);
542             }
543         }
544         return LetStmt::make(op->name, op->value, body);
545     }
546 
visit(const For * op)547     Stmt visit(const For *op) override {
548         Stmt body = mutate(op->body);
549         if (const IfThenElse *i = body.as<IfThenElse>()) {
550             if (!i->else_case.defined() &&
551                 is_pure(i->condition) &&
552                 !expr_uses_var(i->condition, op->name)) {
553                 Stmt s = For::make(op->name, op->min, op->extent,
554                                    op->for_type, op->device_api, i->then_case);
555                 return IfThenElse::make(i->condition, s);
556             }
557         }
558         return For::make(op->name, op->min, op->extent,
559                          op->for_type, op->device_api, body);
560     }
561 
visit(const ProducerConsumer * op)562     Stmt visit(const ProducerConsumer *op) override {
563         Stmt body = mutate(op->body);
564         if (const IfThenElse *i = body.as<IfThenElse>()) {
565             if (!i->else_case.defined() &&
566                 is_pure(i->condition)) {
567                 Stmt s = ProducerConsumer::make(op->name, op->is_producer, i->then_case);
568                 return IfThenElse::make(i->condition, s);
569             }
570         }
571         return ProducerConsumer::make(op->name, op->is_producer, body);
572     }
573 
visit(const IfThenElse * op)574     Stmt visit(const IfThenElse *op) override {
575         Stmt then_case = mutate(op->then_case);
576         if (!op->else_case.defined() &&
577             is_pure(op->condition)) {
578             if (const IfThenElse *i = then_case.as<IfThenElse>()) {
579                 if (!i->else_case.defined() &&
580                     is_pure(i->condition)) {
581                     return IfThenElse::make(op->condition && i->condition, then_case);
582                 }
583             }
584         }
585         return IfThenElse::make(op->condition, then_case, mutate(op->else_case));
586     }
587 
visit(const Allocate * op)588     Stmt visit(const Allocate *op) override {
589         Stmt body = mutate(op->body);
590         if (const IfThenElse *i = body.as<IfThenElse>()) {
591             if (!i->else_case.defined() &&
592                 is_pure(i->condition)) {
593                 Stmt s = Allocate::make(op->name, op->type, op->memory_type,
594                                         op->extents, op->condition, i->then_case,
595                                         op->new_expr, op->free_function);
596                 return IfThenElse::make(i->condition, s);
597             }
598         }
599         return Allocate::make(op->name, op->type, op->memory_type,
600                               op->extents, op->condition, body,
601                               op->new_expr, op->free_function);
602     }
603 
visit(const Block * op)604     Stmt visit(const Block *op) override {
605         Stmt first = mutate(op->first);
606         Stmt rest = mutate(op->rest);
607 
608         const IfThenElse *i1 = first.as<IfThenElse>();
609         const Block *b = rest.as<Block>();
610         const IfThenElse *i2 = b ? b->first.as<IfThenElse>() : rest.as<IfThenElse>();
611 
612         if (i1 &&
613             i2 &&
614             !i1->else_case.defined() &&
615             !i2->else_case.defined() &&
616             is_pure(i1->condition) &&
617             can_prove(i1->condition == i2->condition)) {
618             Stmt s = Block::make(i1->then_case, i2->then_case);
619             s = IfThenElse::make(i1->condition, s);
620             if (b) {
621                 s = Block::make(s, b->rest);
622             }
623             return s;
624         } else {
625             return Block::make(first, rest);
626         }
627     }
628 };
629 
630 }  // namespace
631 
hoist_loop_invariant_if_statements(Stmt s)632 Stmt hoist_loop_invariant_if_statements(Stmt s) {
633     s = HoistIfStatements().mutate(s);
634     return s;
635 }
636 
637 }  // namespace Internal
638 }  // namespace Halide
639