1 #include "LICM.h"
2 #include "CSE.h"
3 #include "ExprUsesVar.h"
4 #include "IREquality.h"
5 #include "IRMutator.h"
6 #include "IROperator.h"
7 #include "Scope.h"
8 #include "Simplify.h"
9 #include "Substitute.h"
10
11 namespace Halide {
12 namespace Internal {
13
14 using std::map;
15 using std::pair;
16 using std::set;
17 using std::string;
18 using std::vector;
19
20 // Is it safe to lift an Expr out of a loop (and potentially across a device boundary)
21 class CanLift : public IRVisitor {
22 using IRVisitor::visit;
23
visit(const Call * op)24 void visit(const Call *op) override {
25 if (!op->is_pure()) {
26 result = false;
27 } else {
28 IRVisitor::visit(op);
29 }
30 }
31
visit(const Load * op)32 void visit(const Load *op) override {
33 result = false;
34 }
35
visit(const Variable * op)36 void visit(const Variable *op) override {
37 if (varying.contains(op->name)) {
38 result = false;
39 }
40 }
41
42 const Scope<> &varying;
43
44 public:
45 bool result{true};
46
CanLift(const Scope<> & v)47 CanLift(const Scope<> &v)
48 : varying(v) {
49 }
50 };
51
52 // Lift pure loop invariants to the top level. Applied independently
53 // to each loop.
class LiftLoopInvariants : public IRMutator {
    using IRMutator::visit;

    // Names that vary inside the loop currently being processed: loop
    // variables plus any let-bound names (which may depend on them).
    Scope<> varying;

    // Is it legal to lift this Expr out of the loop?
    bool can_lift(const Expr &e) {
        CanLift check(varying);
        e.accept(&check);
        return check.result;
    }

    // Is it legal *and* profitable to lift this Expr?
    bool should_lift(const Expr &e) {
        if (!can_lift(e)) return false;
        // Lifting a lone variable or a constant would just introduce a
        // rename without saving any work.
        if (e.as<Variable>()) return false;
        if (e.as<Broadcast>()) return false;
        if (is_const(e)) return false;
        // bool vectors are buggy enough in LLVM that lifting them is a bad idea.
        // (We just skip all vectors on the principle that we don't want them
        // on the stack anyway.)
        if (e.type().is_vector()) return false;
        if (const Cast *cast = e.as<Cast>()) {
            if (cast->type.bytes() > cast->value.type().bytes()) {
                // Don't lift widening casts.
                return false;
            }
        }
        if (const Add *add = e.as<Add>()) {
            if (add->type == Int(32) &&
                is_const(add->b)) {
                // Don't lift constant integer offsets. They're often free.
                return false;
            }
        }
        if (const Call *call = e.as<Call>()) {
            if (call->is_intrinsic(Call::strict_float) ||
                call->is_intrinsic(Call::likely) ||
                call->is_intrinsic(Call::likely_if_innermost) ||
                call->is_intrinsic(Call::reinterpret)) {
                // Don't lift these intrinsics. They're free. Decide based
                // on the wrapped argument instead.
                return should_lift(call->args[0]);
            }
            if (call->is_intrinsic(Call::size_of_halide_buffer_t)) {
                return true;
            }
        }
        return true;
    }

    template<typename T, typename Body>
    Body visit_let(const T *op) {
        // Visit an entire chain of lets in a single method to conserve stack space.
        struct Frame {
            const T *op;
            Expr new_value;
            ScopedBinding<> binding;
            Frame(const T *op, Expr v, Scope<> &scope)
                : op(op), new_value(std::move(v)), binding(scope, op->name) {
            }
        };
        vector<Frame> frames;
        Body result;
        do {
            // mutate(op->value) runs before the Frame's ScopedBinding, so
            // the value is processed without its own name in scope; the
            // name is then marked varying for the rest of the chain/body.
            frames.emplace_back(op, mutate(op->value), varying);
            result = op->body;
        } while ((op = result.template as<T>()));

        result = mutate(result);

        // Rewrap the lets innermost-first, reusing the original nodes
        // whenever nothing changed.
        for (auto it = frames.rbegin(); it != frames.rend(); it++) {
            if (it->new_value.same_as(it->op->value) && result.same_as(it->op->body)) {
                result = it->op;
            } else {
                result = T::make(it->op->name, std::move(it->new_value), result);
            }
        }

        return result;
    }

    Expr visit(const Let *op) override {
        return visit_let<Let, Expr>(op);
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let<LetStmt, Stmt>(op);
    }

    Stmt visit(const For *op) override {
        // The loop variable varies within the loop body.
        ScopedBinding<> p(varying, op->name);
        return IRMutator::visit(op);
    }

public:
    using IRMutator::mutate;

    Expr mutate(const Expr &e) override {
        if (should_lift(e)) {
            // Lift it in canonical form
            Expr lifted_expr = simplify(e);
            auto it = lifted.find(lifted_expr);
            if (it == lifted.end()) {
                // First occurrence: invent a fresh name for it.
                string name = unique_name('t');
                lifted[lifted_expr] = name;
                return Variable::make(e.type(), name);
            } else {
                // Already lifted an equal expr; reuse its name.
                return Variable::make(e.type(), it->second);
            }
        } else {
            return IRMutator::mutate(e);
        }
    }

    // Lifted exprs (in simplified canonical form) mapped to the fresh
    // names that now stand for them in the mutated IR.
    map<Expr, string, IRDeepCompare> lifted;
};
168
169 // The pass above can lift out the value of lets entirely, leaving
170 // them as just renamings of other variables. Easier to substitute
171 // them in as a post-pass rather than make the pass above more clever.
172 class SubstituteTrivialLets : public IRMutator {
173 using IRMutator::visit;
174
visit(const Let * op)175 Expr visit(const Let *op) override {
176 if (op->value.as<Variable>()) {
177 return mutate(substitute(op->name, op->value, op->body));
178 } else {
179 return IRMutator::visit(op);
180 }
181 }
182
visit(const LetStmt * op)183 Stmt visit(const LetStmt *op) override {
184 if (op->value.as<Variable>()) {
185 return mutate(substitute(op->name, op->value, op->body));
186 } else {
187 return IRMutator::visit(op);
188 }
189 }
190 };
191
class LICM : public IRMutator {
    using IRMutator::visit;

    // True while visiting the interior of a GPUBlock/GPUThread loop.
    bool in_gpu_loop{false};

    // Compute the cost of computing an expression inside the inner
    // loop, compared to just loading it as a parameter.
    int cost(const Expr &e, const set<string> &vars) {
        if (is_const(e)) {
            return 0;
        } else if (const Variable *var = e.as<Variable>()) {
            if (vars.count(var->name)) {
                // We're loading this already
                return 0;
            } else {
                // Would have to load this
                return 1;
            }
        } else if (const Add *add = e.as<Add>()) {
            return cost(add->a, vars) + cost(add->b, vars) + 1;
        } else if (const Sub *sub = e.as<Sub>()) {
            return cost(sub->a, vars) + cost(sub->b, vars) + 1;
        } else if (const Mul *mul = e.as<Mul>()) {
            return cost(mul->a, vars) + cost(mul->b, vars) + 1;
        } else if (const Call *call = e.as<Call>()) {
            if (call->is_intrinsic(Call::reinterpret)) {
                internal_assert(call->args.size() == 1);
                // reinterpret is free; cost is that of its argument.
                return cost(call->args[0], vars);
            } else {
                // Anything else is assumed too expensive to recompute.
                return 100;
            }
        } else {
            return 100;
        }
    }

    Stmt visit(const For *op) override {
        // ScopedValue restores the previous flag on scope exit; the
        // variable itself still holds the old value for the test below.
        ScopedValue<bool> old_in_gpu_loop(in_gpu_loop);
        in_gpu_loop =
            (op->for_type == ForType::GPUBlock ||
             op->for_type == ForType::GPUThread);

        if (old_in_gpu_loop && in_gpu_loop) {
            // Don't lift lets to in-between gpu blocks/threads
            return IRMutator::visit(op);
        } else if (op->device_api == DeviceAPI::GLSL) {
            // GLSL uses magic names for varying things. Just skip LICM.
            return IRMutator::visit(op);
        } else {

            // Lift invariants
            LiftLoopInvariants lifter;
            Stmt new_stmt = lifter.mutate(op);
            new_stmt = SubstituteTrivialLets().mutate(new_stmt);

            // As an optimization to reduce register pressure, take
            // the set of expressions to lift and check if any can
            // cheaply be computed from others. If so it's better to
            // do that than to load multiple related values off the
            // stack. We currently only consider expressions that are
            // the sum, difference, or product of two variables
            // already used in the kernel, or a variable plus a
            // constant.

            // Linearize all the exprs and names
            vector<Expr> exprs;
            vector<string> names;
            for (const auto &p : lifter.lifted) {
                exprs.push_back(p.first);
                names.push_back(p.second);
            }

            // Jointly CSE the lifted exprs by putting them together into a dummy Expr
            Expr dummy_call = Call::make(Int(32), Call::bundle, exprs, Call::PureIntrinsic);
            dummy_call = common_subexpression_elimination(dummy_call, true);

            // Peel off containing lets. These will be lifted.
            vector<pair<string, Expr>> lets;
            while (const Let *let = dummy_call.as<Let>()) {
                lets.emplace_back(let->name, let->value);
                dummy_call = let->body;
            }

            // Track the set of variables used by the inner loop
            class CollectVars : public IRVisitor {
                using IRVisitor::visit;
                void visit(const Variable *op) override {
                    vars.insert(op->name);
                }

            public:
                set<string> vars;
            } vars;
            new_stmt.accept(&vars);

            // Now consider substituting back in each use
            const Call *call = dummy_call.as<Call>();
            internal_assert(call->is_intrinsic(Call::bundle));
            // Iterate to a fixed point: substituting one expr back in
            // makes its variables "already loaded", which can make other
            // exprs cheap enough to substitute too.
            bool converged;
            do {
                converged = true;
                for (size_t i = 0; i < exprs.size(); i++) {
                    if (!exprs[i].defined()) continue;
                    Expr e = call->args[i];
                    if (cost(e, vars.vars) <= 1) {
                        // Just subs it back in - computing it is as cheap
                        // as loading it.
                        e.accept(&vars);
                        new_stmt = substitute(names[i], e, new_stmt);
                        names[i].clear();
                        exprs[i] = Expr();
                        converged = false;
                    } else {
                        exprs[i] = e;
                    }
                }
            } while (!converged);

            // Recurse
            const For *loop = new_stmt.as<For>();
            internal_assert(loop);

            new_stmt = For::make(loop->name, loop->min, loop->extent,
                                 loop->for_type, loop->device_api, mutate(loop->body));

            // Wrap lets for the lifted invariants
            for (size_t i = 0; i < exprs.size(); i++) {
                if (exprs[i].defined()) {
                    new_stmt = LetStmt::make(names[i], exprs[i], new_stmt);
                }
            }

            // Wrap the lets pulled out by CSE
            while (!lets.empty()) {
                new_stmt = LetStmt::make(lets.back().first, lets.back().second, new_stmt);
                lets.pop_back();
            }

            return new_stmt;
        }
    }
};
334
335 // Reassociate summations to group together the loop invariants. Useful to run before LICM.
class GroupLoopInvariants : public IRMutator {
    using IRMutator::visit;

    // Maps each in-scope name to the loop-nesting depth at which it was
    // bound (loop variables and let names).
    Scope<int> var_depth;

    // Computes the maximum loop depth of any variable an Expr refers to.
    class ExprDepth : public IRVisitor {
        using IRVisitor::visit;
        const Scope<int> &depth;

        void visit(const Variable *op) override {
            if (depth.contains(op->name)) {
                result = std::max(result, depth.get(op->name));
            }
        }

    public:
        // 0 means the Expr uses no loop-bound variables (fully invariant).
        int result = 0;
        ExprDepth(const Scope<int> &var_depth)
            : depth(var_depth) {
        }
    };

    int expr_depth(const Expr &e) {
        ExprDepth depth(var_depth);
        e.accept(&depth);
        return depth.result;
    }

    // One signed term of a flattened summation, tagged with its loop depth.
    struct Term {
        Expr expr;
        bool positive;
        int depth;
    };

    // Flatten a tree of Adds/Subs into a list of signed leaf terms,
    // sorted so that deeper (more loop-varying) terms come first.
    vector<Term> extract_summation(const Expr &e) {
        vector<Term> pending, terms;
        pending.push_back({e, true, 0});
        while (!pending.empty()) {
            Term next = pending.back();
            pending.pop_back();
            const Add *add = next.expr.as<Add>();
            const Sub *sub = next.expr.as<Sub>();
            if (add) {
                pending.push_back({add->a, next.positive, 0});
                pending.push_back({add->b, next.positive, 0});
            } else if (sub) {
                pending.push_back({sub->a, next.positive, 0});
                // Subtraction flips the sign of its right operand.
                pending.push_back({sub->b, !next.positive, 0});
            } else {
                next.expr = mutate(next.expr);
                if (next.expr.as<Add>() || next.expr.as<Sub>()) {
                    // After mutation it became an add or sub, throw it back on the pending queue.
                    pending.push_back(next);
                } else {
                    next.depth = expr_depth(next.expr);
                    terms.push_back(next);
                }
            }
        }

        // Sort the terms by loop depth. Terms of equal depth are
        // likely already in a good order, so don't mess with them.
        std::stable_sort(terms.begin(), terms.end(),
                         [](const Term &a, const Term &b) {
                             return a.depth > b.depth;
                         });

        return terms;
    }

    // Rebuild the sum, consuming terms from the back (lowest depth
    // first), so the loop-invariant part becomes a single subtree that
    // LICM can lift.
    Expr reassociate_summation(const Expr &e) {
        vector<Term> terms = extract_summation(e);

        Expr result;
        // 'positive' tracks the sign of the accumulated 'result'.
        bool positive = true;
        while (!terms.empty()) {
            Term next = terms.back();
            terms.pop_back();
            if (result.defined()) {
                if (next.positive == positive) {
                    result += next.expr;
                } else if (next.positive) {
                    // result is currently negated; next - result fixes
                    // both the term and the accumulated sign.
                    result = next.expr - result;
                    positive = true;
                } else {
                    result -= next.expr;
                }
            } else {
                result = next.expr;
                positive = next.positive;
            }
        }

        // If only negative terms remained, negate the whole thing.
        if (!positive) {
            result = make_zero(result.type()) - result;
        }

        return result;
    }

    Expr visit(const Add *op) override {
        if (op->type.is_float() || (op->type == Int(32) && is_const(op->b))) {
            // Don't reassociate float exprs. (If strict_float is
            // off, we're allowed to reassociate, and we do
            // reassociate elsewhere, but there's no benefit to it
            // here and it's friendlier not to.)
            //
            // Also don't reassociate trailing integer constants. They're the
            // ultimate loop invariant, but doing this to stencils
            // causes inner loops to track N different pointers
            // instead of one pointer with constant offsets, and that
            // complicates aliasing analysis.
            return IRMutator::visit(op);
        }

        return reassociate_summation(op);
    }

    Expr visit(const Sub *op) override {
        // Same exclusions as for Add above.
        if (op->type.is_float() || (op->type == Int(32) && is_const(op->b))) {
            return IRMutator::visit(op);
        }

        return reassociate_summation(op);
    }

    // Current loop-nesting depth while traversing.
    int depth = 0;

    Stmt visit(const For *op) override {
        depth++;
        ScopedBinding<int> bind(var_depth, op->name, depth);
        Stmt stmt = IRMutator::visit(op);
        depth--;
        return stmt;
    }

    template<typename T, typename Body>
    Body visit_let(const T *op) {
        // Visit an entire chain of lets in one call to conserve stack space.
        struct Frame {
            const T *op;
            Expr new_value;
            ScopedBinding<int> binding;
            Frame(const T *op, Expr v, int depth, Scope<int> &scope)
                : op(op),
                  new_value(std::move(v)),
                  binding(scope, op->name, depth) {
            }
        };
        std::vector<Frame> frames;
        Body result;

        do {
            result = op->body;
            int d = 0;
            if (depth > 0) {
                // Outside any loop every depth is zero, so skip the walk.
                // Note this measures the depth of the *unmutated* value.
                d = expr_depth(op->value);
            }
            frames.emplace_back(op, mutate(op->value), d, var_depth);
        } while ((op = result.template as<T>()));

        result = mutate(result);

        // Rewrap the lets, reusing nodes where nothing changed.
        for (auto it = frames.rbegin(); it != frames.rend(); it++) {
            if (it->new_value.same_as(it->op->value) && result.same_as(it->op->body)) {
                result = it->op;
            } else {
                result = T::make(it->op->name, it->new_value, result);
            }
        }

        return result;
    }

    Expr visit(const Let *op) override {
        return visit_let<Let, Expr>(op);
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let<LetStmt, Stmt>(op);
    }
};
517
hoist_loop_invariant_values(Stmt s)518 Stmt hoist_loop_invariant_values(Stmt s) {
519 s = GroupLoopInvariants().mutate(s);
520 s = common_subexpression_elimination(s);
521 s = LICM().mutate(s);
522 s = simplify_exprs(s);
523 return s;
524 }
525
526 namespace {
527
528 // Move IfThenElse nodes from the inside of a piece of Stmt IR to the
529 // outside when legal.
530 class HoistIfStatements : public IRMutator {
531 using IRMutator::visit;
532
visit(const LetStmt * op)533 Stmt visit(const LetStmt *op) override {
534 Stmt body = mutate(op->body);
535 if (const IfThenElse *i = body.as<IfThenElse>()) {
536 if (!i->else_case.defined() &&
537 is_pure(op->value) &&
538 is_pure(i->condition) &&
539 !expr_uses_var(i->condition, op->name)) {
540 Stmt s = LetStmt::make(op->name, op->value, i->then_case);
541 return IfThenElse::make(i->condition, s);
542 }
543 }
544 return LetStmt::make(op->name, op->value, body);
545 }
546
visit(const For * op)547 Stmt visit(const For *op) override {
548 Stmt body = mutate(op->body);
549 if (const IfThenElse *i = body.as<IfThenElse>()) {
550 if (!i->else_case.defined() &&
551 is_pure(i->condition) &&
552 !expr_uses_var(i->condition, op->name)) {
553 Stmt s = For::make(op->name, op->min, op->extent,
554 op->for_type, op->device_api, i->then_case);
555 return IfThenElse::make(i->condition, s);
556 }
557 }
558 return For::make(op->name, op->min, op->extent,
559 op->for_type, op->device_api, body);
560 }
561
visit(const ProducerConsumer * op)562 Stmt visit(const ProducerConsumer *op) override {
563 Stmt body = mutate(op->body);
564 if (const IfThenElse *i = body.as<IfThenElse>()) {
565 if (!i->else_case.defined() &&
566 is_pure(i->condition)) {
567 Stmt s = ProducerConsumer::make(op->name, op->is_producer, i->then_case);
568 return IfThenElse::make(i->condition, s);
569 }
570 }
571 return ProducerConsumer::make(op->name, op->is_producer, body);
572 }
573
visit(const IfThenElse * op)574 Stmt visit(const IfThenElse *op) override {
575 Stmt then_case = mutate(op->then_case);
576 if (!op->else_case.defined() &&
577 is_pure(op->condition)) {
578 if (const IfThenElse *i = then_case.as<IfThenElse>()) {
579 if (!i->else_case.defined() &&
580 is_pure(i->condition)) {
581 return IfThenElse::make(op->condition && i->condition, then_case);
582 }
583 }
584 }
585 return IfThenElse::make(op->condition, then_case, mutate(op->else_case));
586 }
587
visit(const Allocate * op)588 Stmt visit(const Allocate *op) override {
589 Stmt body = mutate(op->body);
590 if (const IfThenElse *i = body.as<IfThenElse>()) {
591 if (!i->else_case.defined() &&
592 is_pure(i->condition)) {
593 Stmt s = Allocate::make(op->name, op->type, op->memory_type,
594 op->extents, op->condition, i->then_case,
595 op->new_expr, op->free_function);
596 return IfThenElse::make(i->condition, s);
597 }
598 }
599 return Allocate::make(op->name, op->type, op->memory_type,
600 op->extents, op->condition, body,
601 op->new_expr, op->free_function);
602 }
603
visit(const Block * op)604 Stmt visit(const Block *op) override {
605 Stmt first = mutate(op->first);
606 Stmt rest = mutate(op->rest);
607
608 const IfThenElse *i1 = first.as<IfThenElse>();
609 const Block *b = rest.as<Block>();
610 const IfThenElse *i2 = b ? b->first.as<IfThenElse>() : rest.as<IfThenElse>();
611
612 if (i1 &&
613 i2 &&
614 !i1->else_case.defined() &&
615 !i2->else_case.defined() &&
616 is_pure(i1->condition) &&
617 can_prove(i1->condition == i2->condition)) {
618 Stmt s = Block::make(i1->then_case, i2->then_case);
619 s = IfThenElse::make(i1->condition, s);
620 if (b) {
621 s = Block::make(s, b->rest);
622 }
623 return s;
624 } else {
625 return Block::make(first, rest);
626 }
627 }
628 };
629
630 } // namespace
631
hoist_loop_invariant_if_statements(Stmt s)632 Stmt hoist_loop_invariant_if_statements(Stmt s) {
633 s = HoistIfStatements().mutate(s);
634 return s;
635 }
636
637 } // namespace Internal
638 } // namespace Halide
639