1 #include <algorithm>
2 #include <numeric>
3 #include <utility>
4
5 #include "CSE.h"
6 #include "CodeGen_GPU_Dev.h"
7 #include "ExprUsesVar.h"
8 #include "IREquality.h"
9 #include "IRMutator.h"
10 #include "IROperator.h"
11 #include "PartitionLoops.h"
12 #include "Simplify.h"
13 #include "Solve.h"
14 #include "Substitute.h"
15 #include "Var.h"
16
17 namespace Halide {
18 namespace Internal {
19
20 using std::pair;
21 using std::string;
22 using std::vector;
23
24 namespace {
25
26 // Loop partitioning only applies to things marked as 'likely'. Loads
27 // through hand-written boundary conditions will produce clamped
28 // ramps, which will turn into gathers. This pass injects likely
29 // intrinsics so that these clamped ramps are picked up by loop
30 // partitioning.
31 class MarkClampedRampsAsLikely : public IRMutator {
32 using IRMutator::visit;
visit(const Min * op)33 Expr visit(const Min *op) override {
34 if (in_index && op->a.as<Ramp>()) {
35 // No point recursing into the ramp - it can't contain
36 // another ramp.
37 return min(likely(op->a), mutate(op->b));
38 } else if (in_index && op->b.as<Ramp>()) {
39 return min(mutate(op->a), likely(op->b));
40 } else {
41 return IRMutator::visit(op);
42 }
43 }
44
visit(const Max * op)45 Expr visit(const Max *op) override {
46 if (in_index && op->a.as<Ramp>()) {
47 return max(likely(op->a), mutate(op->b));
48 } else if (in_index && op->b.as<Ramp>()) {
49 return max(mutate(op->a), likely(op->b));
50 } else {
51 return IRMutator::visit(op);
52 }
53 }
54
visit(const Load * op)55 Expr visit(const Load *op) override {
56 bool old_in_index = in_index;
57 in_index = true;
58 Expr expr = IRMutator::visit(op);
59 in_index = old_in_index;
60 return expr;
61 }
62
visit(const Store * op)63 Stmt visit(const Store *op) override {
64 bool old_in_index = in_index;
65 in_index = true;
66 Expr index = mutate(op->index);
67 in_index = old_in_index;
68 Expr value = mutate(op->value);
69 Expr predicate = mutate(op->predicate);
70 if (predicate.same_as(op->predicate) && index.same_as(op->index) && value.same_as(op->value)) {
71 return op;
72 } else {
73 return Store::make(op->name, value, index, op->param, predicate, op->alignment);
74 }
75 }
76
77 bool in_index = false;
78 };
79
80 // Check if an expression or statement uses a likely tag
81 class HasLikelyTag : public IRVisitor {
82 protected:
83 using IRVisitor::visit;
visit(const Call * op)84 void visit(const Call *op) override {
85 if (op->is_intrinsic(Call::likely)) {
86 result = true;
87 } else {
88 IRVisitor::visit(op);
89 }
90 }
91
92 public:
93 bool result = false;
94 };
95
96 class HasUncapturedLikelyTag : public HasLikelyTag {
97 using HasLikelyTag::visit;
98
99 // Any likelies buried inside the following ops are captured the by respective ops
visit(const Select * op)100 void visit(const Select *op) override {
101 }
visit(const Min * op)102 void visit(const Min *op) override {
103 }
visit(const Max * op)104 void visit(const Max *op) override {
105 }
106 };
107
108 // The goal of loop partitioning is to split loops up into a prologue,
109 // a clean steady state, and an epilogue. The next visitor
110 // (FindSimplifications) finds a list of simplifications that can be
111 // applied to produce that clean steady-state version of the loop
112 // body. It tries to simplify selects, mins, and maxes to just their
113 // likely branch. For example:
114 //
115 // select(a, likely(b), c) -> b
116 // select(a, b, 5 + likely(c)) -> 5 + c
117 // max(a, likely(b)) -> b
118 //
119 // These three simplifications are only valid if a is true, false, or
120 // less than b, respectively. So we visit the loop body looking for
121 // these sort of things, record the associated conditions, and try to
122 // solve for a range of the loop variable for which all of our
123 // conditions are true (by solving for each one and then taking the
124 // intersection). That gives us the clean steady state.
125 //
126 // It may be that we can also make some simplifications to the
127 // prologue or epilogue. For example, consider the case:
128 //
129 // select(x > 0, likely(expr_t), expr_f)
130 //
131 // It can simplify to expr_t when x > 0. However, if this is the sole
132 // simplification which gives us a lower bound on x for the steady
133 // state, we can also simplify this select in the prologue to just be
134 // expr_f.
135 //
136 // Now consider this case:
137 //
138 // (select(x > a, likely(expr_t1), expr_f1) +
139 // select(x > b, likely(expr_t2), expr_f2))
140 //
141 // The steady state starts at x == max(a, b), which we get from the
142 // intersection of the intervals derived from each condition: x > a
143 // and x > b. In the steady state, the expression simplifies to
144 // expr_t1 + expr_t2. In the prologue we know that either x <= a or x
145 // <= b, but we don't know which one might be true, so we can't make
146 // any simplifications to the prologue.
147 //
148 // We may also encounter single conditions where we can simplify the
149 // steady-state but not the prologue. Say we're splitting up a loop
150 // over x and we encounter a condition that depends on a variable
151 // introduced in some inner loop:
152 //
153 // for x:
154 // for z from 0 to 10:
155 // ... select(x > z, likely(expr_t), expr_f) ...
156 //
157 // This select definitely simplifies to expr_t when x > 9, because
158 // that's the maximum value z could be, so we'll start the steady
159 // state at x == 10. This means the prologue covers values like x ==
160 // 5, where the select could be either true or false, so we can't make
161 // any simplifications to the prologue.
162 //
163 // There are some simplifications that we won't be able to do. For
164 // example, if we're partitioning the loop over x, and we encounter:
165 //
166 // for x:
167 // for z from 0 to 10:
168 // ... select(z < 5, likely(expr_t), expr_f)
169 //
170 // Restricting the range of x isn't going to simplify that expression
171 // - it doesn't even depend on x. We just make all the simplifications
172 // that we can, and take the intersection of the resulting regions. In
173 // this case, we'll make that simplification later, when we do loop
174 // partitioning over the loop in z. Some cases we'll never
175 // handle. E.g. consider:
176 //
177 // for x:
178 // ... select(a + x*(b + x*(c + x*(d + x*e))) > 0, likely(expr_t), expr_f)
179 //
180 // In order to simplify that we'd have to come up with a formula that
181 // tells us an interval where a quintic is strictly positive. No such
182 // general formula exists (because no formula exists for the roots),
183 // so there's no programmatic way we can partition the loop over x to
184 // make that condition simplify. Finally my Galois theory course pays
185 // off. For failures like this, we just drop the likely tag. So loop
186 // partitioning is best-effort, but it should always work for things
187 // like x > a. A simpler case for which we bail is:
188 //
189 // for x:
190 // ... select(x == 5, expr_t, likely(expr_f))
191 //
192 // This simplifies to the likely case in two disjoint ranges, but
193 // we're only producing one steady state, and we have no reason to
194 // believe one side is better than the other, so we just bail and drop
195 // the likely tag.
196
197 // First we define the struct that represents a single simplification
198 // that can be applied to the steady state of the loop.
// One candidate rewrite of a select/min/max (or likely-tagged if
// condition) found in a loop body. Instances are aggregate-initialized
// by FindSimplifications, so the member order below is significant.
struct Simplification {
    // This condition is sufficient for the simplification to occur.
    Expr condition;
    // The expression we're simplifying. MakeSimplifications matches it
    // by identity (same_as), not by structural equality.
    Expr old_expr;
    // The replacement if the condition is true
    Expr likely_value;
    // The replacement if the condition is false. Not useful
    // unless it's tight.
    Expr unlikely_value;
    // Is the condition necessary (as well as sufficient)?
    bool tight;
    // The interval over which this simplification applies. Comes from solving the condition.
    // It is not part of the aggregate initializer; PartitionLoops fills it in later.
    Interval interval;
};
214
215 class ExprUsesInvalidBuffers : public IRVisitor {
216 using IRVisitor::visit;
217
218 const Scope<> &invalid_buffers;
219
visit(const Load * op)220 void visit(const Load *op) override {
221 if (invalid_buffers.contains(op->name)) {
222 invalid = true;
223 } else {
224 IRVisitor::visit(op);
225 }
226 }
227
228 public:
ExprUsesInvalidBuffers(const Scope<> & buffers)229 ExprUsesInvalidBuffers(const Scope<> &buffers)
230 : invalid_buffers(buffers), invalid(false) {
231 }
232 bool invalid;
233 };
234
235 /** Check if any references to buffers in an expression is invalid. */
expr_uses_invalid_buffers(const Expr & e,const Scope<> & invalid_buffers)236 bool expr_uses_invalid_buffers(const Expr &e, const Scope<> &invalid_buffers) {
237 ExprUsesInvalidBuffers uses(invalid_buffers);
238 e.accept(&uses);
239 return uses.invalid;
240 }
241
242 // Then we define the visitor that hunts for them.
243 class FindSimplifications : public IRVisitor {
244 using IRVisitor::visit;
245
246 Scope<> depends_on_loop_var, depends_on_invalid_buffers;
247 Scope<> buffers;
248
visit(const Allocate * op)249 void visit(const Allocate *op) override {
250 buffers.push(op->name);
251 IRVisitor::visit(op);
252 }
253
new_simplification(Expr condition,Expr old,Expr likely_val,Expr unlikely_val)254 void new_simplification(Expr condition, Expr old, Expr likely_val, Expr unlikely_val) {
255 if (!expr_uses_vars(condition, depends_on_loop_var)) {
256 return;
257 }
258
259 if (expr_uses_vars(condition, depends_on_invalid_buffers) ||
260 expr_uses_invalid_buffers(condition, buffers)) {
261 // The condition refers to buffer allocated in the inner loop.
262 // We should throw away the condition
263 return;
264 }
265 condition = remove_likelies(condition);
266 Simplification s = {condition, std::move(old), std::move(likely_val), std::move(unlikely_val), true};
267 if (s.condition.type().is_vector()) {
268 s.condition = simplify(s.condition);
269 if (const Broadcast *b = s.condition.as<Broadcast>()) {
270 s.condition = b->value;
271 } else {
272 // Devectorize the condition
273 s.condition = and_condition_over_domain(s.condition, Scope<Interval>::empty_scope());
274 s.tight = false;
275 }
276 }
277 internal_assert(s.condition.type().is_scalar()) << s.condition << "\n";
278 simplifications.push_back(s);
279 }
280
visit(const Min * op)281 void visit(const Min *op) override {
282 bool likely_a = has_uncaptured_likely_tag(op->a);
283 bool likely_b = has_uncaptured_likely_tag(op->b);
284
285 // Prefer the side that has an uncaptured top-level likely
286 // call. If neither does, prefer the side that contains any
287 // likely call at all.
288 if (!likely_a && !likely_b) {
289 likely_a = has_likely_tag(op->a);
290 likely_b = has_likely_tag(op->b);
291 }
292
293 // Don't hunt for simplifications in unlikely paths
294 if (!likely_a) {
295 op->b.accept(this);
296 }
297 if (!likely_b) {
298 op->a.accept(this);
299 }
300
301 if (likely_b && !likely_a) {
302 new_simplification(op->b <= op->a, op, op->b, op->a);
303 } else if (likely_a && !likely_b) {
304 new_simplification(op->a <= op->b, op, op->a, op->b);
305 }
306 }
307
visit(const Max * op)308 void visit(const Max *op) override {
309 bool likely_a = has_uncaptured_likely_tag(op->a);
310 bool likely_b = has_uncaptured_likely_tag(op->b);
311
312 if (!likely_a && !likely_b) {
313 likely_a = has_likely_tag(op->a);
314 likely_b = has_likely_tag(op->b);
315 }
316
317 if (!likely_a) {
318 op->b.accept(this);
319 }
320 if (!likely_b) {
321 op->a.accept(this);
322 }
323
324 if (likely_b && !likely_a) {
325 new_simplification(op->b >= op->a, op, op->b, op->a);
326 } else if (likely_a && !likely_b) {
327 new_simplification(op->a >= op->b, op, op->a, op->b);
328 }
329 }
330
visit(const Select * op)331 void visit(const Select *op) override {
332 op->condition.accept(this);
333
334 bool likely_t = has_uncaptured_likely_tag(op->true_value);
335 bool likely_f = has_uncaptured_likely_tag(op->false_value);
336
337 if (!likely_t && !likely_f) {
338 likely_t = has_likely_tag(op->true_value);
339 likely_f = has_likely_tag(op->false_value);
340 }
341
342 if (!likely_t) {
343 op->false_value.accept(this);
344 }
345 if (!likely_f) {
346 op->true_value.accept(this);
347 }
348
349 if (likely_t && !likely_f) {
350 new_simplification(op->condition, op, op->true_value, op->false_value);
351 } else if (likely_f && !likely_t) {
352 new_simplification(!op->condition, op, op->false_value, op->true_value);
353 }
354 }
355
visit(const IfThenElse * op)356 void visit(const IfThenElse *op) override {
357 // For select statements, mins, and maxes, you can mark the
358 // likely branch with likely. For if statements there's no way
359 // to mark the likely stmt. So if the condition of an if
360 // statement is marked as likely, treat it as likely true and
361 // partition accordingly.
362 IRVisitor::visit(op);
363 const Call *call = op->condition.as<Call>();
364 if (call && call->is_intrinsic(Call::likely)) {
365 new_simplification(op->condition, op->condition, const_true(), const_false());
366 }
367 }
368
visit(const For * op)369 void visit(const For *op) override {
370 vector<Simplification> old;
371 old.swap(simplifications);
372 IRVisitor::visit(op);
373
374 // Relax all the new conditions using the loop bounds
375 for (Simplification &s : simplifications) {
376 if (expr_uses_var(s.condition, op->name)) {
377 Scope<Interval> varying;
378 varying.push(op->name, Interval(op->min, op->min + op->extent - 1));
379 Expr relaxed = and_condition_over_domain(s.condition, varying);
380 internal_assert(!expr_uses_var(relaxed, op->name))
381 << "Should not have had used the loop var (" << op->name
382 << ") any longer\n before: " << s.condition << "\n after: "
383 << relaxed << "\n";
384 if (!equal(relaxed, s.condition)) {
385 s.tight = false;
386 }
387 s.condition = relaxed;
388 }
389 }
390
391 simplifications.insert(simplifications.end(), old.begin(), old.end());
392 }
393
394 template<typename LetOrLetStmt>
visit_let(const LetOrLetStmt * op)395 void visit_let(const LetOrLetStmt *op) {
396 ScopedBinding<> bind_varying(expr_uses_vars(op->value, depends_on_loop_var),
397 depends_on_loop_var, op->name);
398 ScopedBinding<> bind_invalid(expr_uses_invalid_buffers(op->value, buffers) ||
399 expr_uses_vars(op->value, depends_on_invalid_buffers),
400 depends_on_invalid_buffers, op->name);
401 vector<Simplification> old;
402 old.swap(simplifications);
403 IRVisitor::visit(op);
404 for (Simplification &s : simplifications) {
405 if (expr_uses_var(s.condition, op->name)) {
406 s.condition = Let::make(op->name, op->value, s.condition);
407 }
408 }
409 simplifications.insert(simplifications.end(), old.begin(), old.end());
410 }
411
visit(const LetStmt * op)412 void visit(const LetStmt *op) override {
413 visit_let(op);
414 }
415
visit(const Let * op)416 void visit(const Let *op) override {
417 visit_let(op);
418 }
419
420 public:
421 vector<Simplification> simplifications;
422
FindSimplifications(const std::string & v)423 FindSimplifications(const std::string &v) {
424 depends_on_loop_var.push(v);
425 }
426 };
427
428 // Blindly apply a list of simplifications.
429 class MakeSimplifications : public IRMutator {
430 using IRMutator::visit;
431
432 const vector<Simplification> &simplifications;
433
434 public:
MakeSimplifications(const vector<Simplification> & s)435 MakeSimplifications(const vector<Simplification> &s)
436 : simplifications(s) {
437 }
438
439 using IRMutator::mutate;
mutate(const Expr & e)440 Expr mutate(const Expr &e) override {
441 for (auto const &s : simplifications) {
442 if (e.same_as(s.old_expr)) {
443 return mutate(s.likely_value);
444 }
445 }
446 return IRMutator::mutate(e);
447 }
448 };
449
450 class ContainsWarpSynchronousLogic : public IRVisitor {
451 public:
452 bool result = false;
453
454 protected:
455 using IRVisitor::visit;
visit(const Call * op)456 void visit(const Call *op) override {
457 if (op->is_intrinsic(Call::gpu_thread_barrier)) {
458 result = true;
459 } else {
460 IRVisitor::visit(op);
461 }
462 }
463
visit(const For * op)464 void visit(const For *op) override {
465 if (op->for_type == ForType::GPULane) {
466 result = true;
467 } else {
468 IRVisitor::visit(op);
469 }
470 }
471
visit(const Load * op)472 void visit(const Load *op) override {
473 }
474 };
475
contains_warp_synchronous_logic(const Stmt & s)476 bool contains_warp_synchronous_logic(const Stmt &s) {
477 ContainsWarpSynchronousLogic c;
478 s.accept(&c);
479 return c.result;
480 }
481
// The mutator that performs the partitioning. For each For loop it
// gathers the simplifications available in the body, solves each
// condition for a range of the loop variable, and then rebuilds the
// loop as a prologue, a simplified steady state, and an epilogue
// (either as separate loops or as if-then-else cases inside one loop).
class PartitionLoops : public IRMutator {
    using IRMutator::visit;

    // True while mutating inside a loop over a GPU variable.
    bool in_gpu_loop = false;

    Stmt visit(const For *op) override {
        Stmt body = op->body;

        // Sets in_gpu_loop for the duration of this visit and restores
        // the previous value when the function returns.
        ScopedValue<bool> old_in_gpu_loop(in_gpu_loop, in_gpu_loop ||
                                                           CodeGen_GPU_Dev::is_gpu_var(op->name));

        // If we're inside GPU kernel, and the body contains thread
        // barriers or warp shuffles, it's not safe to partition loops.
        if (in_gpu_loop && contains_warp_synchronous_logic(op)) {
            return IRMutator::visit(op);
        }

        // We shouldn't partition GLSL loops - they have control-flow
        // constraints.
        if (op->device_api == DeviceAPI::GLSL) {
            return op;
        }

        // Find simplifications in this loop body
        FindSimplifications finder(op->name);
        body.accept(&finder);

        if (finder.simplifications.empty()) {
            // Nothing to partition on at this level; still recurse so
            // that inner loops get a chance.
            return IRMutator::visit(op);
        }

        debug(3) << "\n\n**** Partitioning loop over " << op->name << "\n";

        // min_vals / max_vals collect the distinct lower / upper bounds
        // of the solved intervals. The *_is_tight flags stay true only
        // if every contributing simplification is tight and there is a
        // single distinct bound on that side.
        vector<Expr> min_vals, max_vals;
        vector<Simplification> middle_simps, prologue_simps, epilogue_simps;
        bool lower_bound_is_tight = true, upper_bound_is_tight = true;
        for (auto &s : finder.simplifications) {
            // Solve for the interval over which this simplification is true.
            s.interval = solve_for_inner_interval(s.condition, op->name);
            if (s.tight) {
                // Check if the solve is tight. I.e. the condition is
                // definitely false outside of the interval.
                Interval outer = solve_for_outer_interval(s.condition, op->name);
                s.tight &= equal(outer.min, s.interval.min) && equal(outer.max, s.interval.max);
            }

            debug(3) << "\nSimplification: \n"
                     << "  condition: " << s.condition << "\n"
                     << "  old: " << s.old_expr << "\n"
                     << "  new: " << s.likely_value << "\n"
                     << "  min: " << s.interval.min << "\n"
                     << "  max: " << s.interval.max << "\n"
                     << "  tight: " << s.tight << "\n";

            // Accept all non-empty intervals
            if (!s.interval.is_empty()) {
                if (s.interval.has_lower_bound()) {
                    Expr m = s.interval.min;
                    if (!s.tight) {
                        lower_bound_is_tight = false;
                    }
                    // Only the most recently added bound is compared
                    // against, so this dedups consecutive repeats.
                    if (min_vals.empty()) {
                        min_vals.push_back(m);
                    } else if (equal(m, min_vals.back())) {
                        // We already have this min val
                    } else {
                        // This is a new distinct min val
                        min_vals.push_back(m);
                        lower_bound_is_tight = false;
                    }
                }
                if (s.interval.has_upper_bound()) {
                    Expr m = s.interval.max;
                    if (!s.tight) {
                        upper_bound_is_tight = false;
                    }
                    if (max_vals.empty()) {
                        max_vals.push_back(m);
                    } else if (equal(m, max_vals.back())) {
                        // We already have this max val
                    } else {
                        // This is a new distinct max val
                        max_vals.push_back(m);
                        upper_bound_is_tight = false;
                    }
                }

                // We'll apply this simplification to the
                // steady-state.
                middle_simps.push_back(s);
            }
        }

        // In general we can't simplify the prologue - it may run up
        // to after the epilogue starts for small images. However if
        // we can prove the epilogue starts after the prologue ends,
        // we're OK.
        bool can_simplify_prologue = true;
        for (Expr min_val : min_vals) {
            for (Expr max_val : max_vals) {
                Expr test = simplify(common_subexpression_elimination(min_val - 1 < max_val + 1));
                if (!is_one(test)) {
                    can_simplify_prologue = false;
                }
            }
        }

        // Find simplifications we can apply to the prologue and epilogue.
        for (const auto &s : middle_simps) {
            // If it goes down to minus infinity, we can also
            // apply it to the prologue
            if (can_simplify_prologue &&
                !s.interval.has_lower_bound()) {
                prologue_simps.push_back(s);
            }

            // If it goes up to positive infinity, we can also
            // apply it to the epilogue
            if (!s.interval.has_upper_bound()) {
                epilogue_simps.push_back(s);
            }

            // If our simplifications only contain one lower bound, and
            // it's tight, then the reverse rule can be applied to the
            // prologue.
            if (can_simplify_prologue &&
                s.interval.has_lower_bound() &&
                lower_bound_is_tight) {
                internal_assert(s.tight);
                Simplification s2 = s;
                // This condition is never used (we already solved
                // for the interval), but it's nice for it to be
                // correct.
                s2.condition = !s2.condition;
                std::swap(s2.likely_value, s2.unlikely_value);
                prologue_simps.push_back(s2);
            }
            // Likewise, a single tight upper bound lets the reverse
            // rule be applied to the epilogue.
            if (s.interval.has_upper_bound() &&
                upper_bound_is_tight) {
                internal_assert(s.tight);
                Simplification s2 = s;
                s2.condition = !s2.condition;
                std::swap(s2.likely_value, s2.unlikely_value);
                epilogue_simps.push_back(s2);
            }
        }

        // Simplify each section of the loop.
        Stmt simpler_body = MakeSimplifications(middle_simps).mutate(body);
        Stmt prologue = MakeSimplifications(prologue_simps).mutate(body);
        Stmt epilogue = MakeSimplifications(epilogue_simps).mutate(body);

        // A separate prologue/epilogue is only worth emitting if its
        // body differs from the steady-state body. This comparison is
        // made before the steady state is recursed on below.
        bool make_prologue = !equal(prologue, simpler_body);
        bool make_epilogue = !equal(epilogue, simpler_body);

        // Recurse on the middle section.
        simpler_body = mutate(simpler_body);

        // Construct variables for the bounds of the simplified middle section
        Expr min_steady = op->min, max_steady = op->extent + op->min;
        Expr prologue_val, epilogue_val;
        string prologue_name = unique_name(op->name + ".prologue");
        string epilogue_name = unique_name(op->name + ".epilogue");

        if (make_prologue) {
            // They'll simplify better if you put them in
            // lexicographic order. This puts things like (x+1) and
            // (x+3) next to each other so that the simplifier sees
            // them together and can drop one of them.
            std::sort(min_vals.begin(), min_vals.end(), IRDeepCompare());
            min_vals.push_back(op->min);
            prologue_val = fold_left(min_vals, Max::make);
            // Stop the prologue from running past the end of the loop
            prologue_val = min(prologue_val, op->extent + op->min);
            // prologue_val = print(prologue_val, prologue_name);
            min_steady = Variable::make(Int(32), prologue_name);

            internal_assert(!expr_uses_var(prologue_val, op->name));
        }
        if (make_epilogue) {
            std::sort(max_vals.begin(), max_vals.end(), IRDeepCompare());
            max_vals.push_back(op->min + op->extent - 1);
            // The solved maxima are inclusive; the + 1 turns the result
            // into the first index of the epilogue.
            epilogue_val = fold_left(max_vals, Min::make) + 1;
            // Stop the epilogue from running before the start of the loop/prologue
            if (make_prologue) {
                epilogue_val = max(epilogue_val, prologue_val);
            } else {
                epilogue_val = max(op->min, epilogue_val);
            }
            // epilogue_val = print(epilogue_val, epilogue_name);
            max_steady = Variable::make(Int(32), epilogue_name);

            internal_assert(!expr_uses_var(epilogue_val, op->name));
        }

        Stmt stmt;
        // Bust simple serial for loops up into three.
        if (op->for_type == ForType::Serial && !op->body.as<Acquire>()) {
            stmt = For::make(op->name, min_steady, max_steady - min_steady,
                             op->for_type, op->device_api, simpler_body);

            if (make_prologue) {
                prologue = For::make(op->name, op->min, min_steady - op->min,
                                     op->for_type, op->device_api, prologue);
                stmt = Block::make(prologue, stmt);
            }
            if (make_epilogue) {
                epilogue = For::make(op->name, max_steady, op->min + op->extent - max_steady,
                                     op->for_type, op->device_api, epilogue);
                stmt = Block::make(stmt, epilogue);
            }
        } else {
            // For parallel for loops we could use a Fork node here,
            // but that would introduce the more complicated parallel
            // runtime into code that doesn't use async(), which may
            // interfere with legacy overrides of
            // halide_do_par_for. So for parallel for loops just put
            // an if-then-else in the loop body. It should
            // branch-predict to the steady state pretty well.
            //
            // Simple serial for loops that contain an Acquire node go
            // into the task system as a single entity, but Block
            // nodes do not, so we get a flatter task graph if we do
            // the same trick.
            Expr loop_var = Variable::make(Int(32), op->name);
            stmt = simpler_body;
            if (make_epilogue && make_prologue && equal(prologue, epilogue)) {
                // Prologue and epilogue are identical, so a single
                // range test selects between steady state and the rest.
                stmt = IfThenElse::make(min_steady <= loop_var && loop_var < max_steady, stmt, prologue);
            } else {
                if (make_epilogue) {
                    stmt = IfThenElse::make(loop_var < max_steady, stmt, epilogue);
                }
                if (make_prologue) {
                    stmt = IfThenElse::make(loop_var < min_steady, prologue, stmt);
                }
            }
            stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, stmt);
        }

        // Define the steady-state bound variables around the new
        // statement. When a section is not emitted, the corresponding
        // value is set to the original loop bound so that the
        // empty-steady-state check below still has something to compare.
        if (make_epilogue) {
            // Uncomment to include code that prints the epilogue value
            //epilogue_val = print(epilogue_val, op->name, "epilogue");
            stmt = LetStmt::make(epilogue_name, epilogue_val, stmt);
        } else {
            epilogue_val = op->min + op->extent;
        }
        if (make_prologue) {
            // Uncomment to include code that prints the prologue value
            //prologue_val = print(prologue_val, op->name, "prologue");
            stmt = LetStmt::make(prologue_name, prologue_val, stmt);
        } else {
            prologue_val = op->min;
        }

        if (can_prove(epilogue_val <= prologue_val)) {
            // The steady state is empty. I've made a huge
            // mistake. Try to partition a loop further in.
            return IRMutator::visit(op);
        }

        debug(3) << "Partition loop.\n"
                 << "Old: " << Stmt(op) << "\n"
                 << "New: " << stmt << "\n";

        return stmt;
    }
};
749
750 class ExprContainsLoad : public IRVisitor {
751 using IRVisitor::visit;
752
visit(const Load * op)753 void visit(const Load *op) override {
754 result = true;
755 }
756
757 public:
758 bool result = false;
759 };
760
expr_contains_load(const Expr & e)761 bool expr_contains_load(const Expr &e) {
762 ExprContainsLoad l;
763 e.accept(&l);
764 return l.result;
765 }
766
767 // The loop partitioning logic can introduce if and let statements in
768 // between GPU loop levels. This pass moves them inwards or outwards.
769 class RenormalizeGPULoops : public IRMutator {
770 bool in_gpu_loop = false, in_thread_loop = false;
771
772 using IRMutator::visit;
773
774 // Track all vars that depend on GPU loop indices or loops inside GPU kernels.
775 Scope<> gpu_vars;
776
777 vector<pair<string, Expr>> lifted_lets;
778
visit(const For * op)779 Stmt visit(const For *op) override {
780 if (op->device_api == DeviceAPI::GLSL) {
781 // The partitioner did not enter GLSL loops
782 return op;
783 }
784
785 bool old_in_gpu_loop = in_gpu_loop;
786 Stmt stmt;
787
788 if (in_gpu_loop || CodeGen_GPU_Dev::is_gpu_var(op->name)) {
789 gpu_vars.push(op->name);
790 in_gpu_loop = true;
791 }
792
793 if (ends_with(op->name, "__thread_id_x")) {
794 internal_assert(!in_thread_loop);
795 in_thread_loop = true;
796 stmt = IRMutator::visit(op);
797 in_thread_loop = false;
798 } else {
799 stmt = IRMutator::visit(op);
800 }
801
802 if (in_gpu_loop && !old_in_gpu_loop) {
803 // This was the outermost GPU loop. Dump any lifted lets here.
804 while (lifted_lets.size()) {
805 stmt = LetStmt::make(lifted_lets.back().first,
806 lifted_lets.back().second,
807 stmt);
808 lifted_lets.pop_back();
809 }
810 }
811
812 in_gpu_loop = old_in_gpu_loop;
813 return stmt;
814 }
815
visit(const LetStmt * op)816 Stmt visit(const LetStmt *op) override {
817 if (!in_gpu_loop) {
818 return IRMutator::visit(op);
819 }
820
821 if (!expr_uses_vars(op->value, gpu_vars) && !expr_contains_load(op->value)) {
822 // This let value doesn't depend in the gpu vars. We
823 // should lift it outermost. Note that this might expand
824 // its scope to encompass other uses of the same name, so
825 // we'd better give it a new name.
826 string new_name = unique_name('t');
827 Expr new_var = Variable::make(op->value.type(), new_name);
828 lifted_lets.emplace_back(new_name, op->value);
829 return mutate(substitute(op->name, new_var, op->body));
830 }
831
832 gpu_vars.push(op->name);
833
834 if (in_thread_loop) {
835 return IRMutator::visit(op);
836 }
837
838 Stmt body = mutate(op->body);
839 const For *f = body.as<For>();
840 const Allocate *a = body.as<Allocate>();
841 // Move lets in-between gpu loop levels inwards.
842 if (f && in_gpu_loop && !in_thread_loop) {
843 internal_assert(!expr_uses_var(f->min, op->name) &&
844 !expr_uses_var(f->extent, op->name));
845 Stmt inner = LetStmt::make(op->name, op->value, f->body);
846 inner = For::make(f->name, f->min, f->extent, f->for_type, f->device_api, inner);
847 return mutate(inner);
848 } else if (a && in_gpu_loop && !in_thread_loop) {
849 internal_assert(a->extents.size() == 1);
850 if (expr_uses_var(a->extents[0], op->name)) {
851 // This var depends on the block index, and is used to
852 // define the size of shared memory. Can't move it
853 // inwards or outwards. Codegen will have to deal with
854 // it when it deduces how much shared or warp-level
855 // memory to allocate.
856 return IRMutator::visit(op);
857 } else {
858 Stmt inner = LetStmt::make(op->name, op->value, a->body);
859 inner = Allocate::make(a->name, a->type, a->memory_type, a->extents, a->condition, inner);
860 return mutate(inner);
861 }
862 } else {
863 return IRMutator::visit(op);
864 }
865 }
866
visit(const IfThenElse * op)867 Stmt visit(const IfThenElse *op) override {
868 if (!in_gpu_loop || in_thread_loop) {
869 return IRMutator::visit(op);
870 }
871
872 internal_assert(op->else_case.defined())
873 << "PartitionLoops should only introduce if statements with an else branch\n";
874
875 Stmt then_case = mutate(op->then_case);
876 Stmt else_case = mutate(op->else_case);
877
878 if (equal(then_case, else_case)) {
879 // This can happen if the only difference between the
880 // cases was a let statement that we pulled out of the if.
881 return then_case;
882 }
883
884 const Allocate *allocate_a = then_case.as<Allocate>();
885 const Allocate *allocate_b = else_case.as<Allocate>();
886 const For *for_a = then_case.as<For>();
887 const For *for_b = else_case.as<For>();
888 const LetStmt *let_a = then_case.as<LetStmt>();
889 const LetStmt *let_b = else_case.as<LetStmt>();
890 if (allocate_a && allocate_b) {
891 Stmt inner = IfThenElse::make(op->condition, allocate_a->body, allocate_b->body);
892 inner = Allocate::make(allocate_a->name, allocate_a->type,
893 allocate_a->memory_type, allocate_a->extents,
894 allocate_a->condition, inner);
895 return mutate(inner);
896 } else if (let_a && let_b && let_a->name == let_b->name) {
897 string condition_name = unique_name('t');
898 Expr condition = Variable::make(op->condition.type(), condition_name);
899 Stmt inner = IfThenElse::make(condition, let_a->body, let_b->body);
900 inner = LetStmt::make(let_a->name, select(condition, let_a->value, let_b->value), inner);
901 inner = LetStmt::make(condition_name, op->condition, inner);
902 return mutate(inner);
903 } else if (let_a) {
904 string new_name = unique_name(let_a->name);
905 Stmt inner = let_a->body;
906 inner = substitute(let_a->name, Variable::make(let_a->value.type(), new_name), inner);
907 inner = IfThenElse::make(op->condition, inner, else_case);
908 inner = LetStmt::make(new_name, let_a->value, inner);
909 return mutate(inner);
910 } else if (let_b) {
911 string new_name = unique_name(let_b->name);
912 Stmt inner = let_b->body;
913 inner = substitute(let_b->name, Variable::make(let_b->value.type(), new_name), inner);
914 inner = IfThenElse::make(op->condition, then_case, inner);
915 inner = LetStmt::make(new_name, let_b->value, inner);
916 return mutate(inner);
917 } else if (for_a && for_b &&
918 for_a->name == for_b->name &&
919 for_a->min.same_as(for_b->min) &&
920 for_a->extent.same_as(for_b->extent)) {
921 Stmt inner = IfThenElse::make(op->condition, for_a->body, for_b->body);
922 inner = For::make(for_a->name, for_a->min, for_a->extent, for_a->for_type, for_a->device_api, inner);
923 return mutate(inner);
924 } else {
925 internal_error << "Unexpected construct inside if statement: " << Stmt(op) << "\n";
926 return Stmt();
927 }
928 }
929 };
930
931 // Expand selects of boolean conditions so that the partitioner can
932 // consider them one-at-a-time.
933 class ExpandSelects : public IRMutator {
934 using IRMutator::visit;
935
is_trivial(const Expr & e)936 bool is_trivial(const Expr &e) {
937 return e.as<Variable>() || is_const(e);
938 }
939
visit(const Select * op)940 Expr visit(const Select *op) override {
941 Expr condition = mutate(op->condition);
942 Expr true_value = mutate(op->true_value);
943 Expr false_value = mutate(op->false_value);
944 if (const Or *o = condition.as<Or>()) {
945 if (is_trivial(true_value)) {
946 return mutate(Select::make(o->a, true_value, Select::make(o->b, true_value, false_value)));
947 } else {
948 string var_name = unique_name('t');
949 Expr var = Variable::make(true_value.type(), var_name);
950 Expr expr = mutate(Select::make(o->a, var, Select::make(o->b, var, false_value)));
951 return Let::make(var_name, true_value, expr);
952 }
953 } else if (const And *a = condition.as<And>()) {
954 if (is_trivial(false_value)) {
955 return mutate(Select::make(a->a, Select::make(a->b, true_value, false_value), false_value));
956 } else {
957 string var_name = unique_name('t');
958 Expr var = Variable::make(false_value.type(), var_name);
959 Expr expr = mutate(Select::make(a->a, Select::make(a->b, true_value, var), var));
960 return Let::make(var_name, false_value, expr);
961 }
962 } else if (const Not *n = condition.as<Not>()) {
963 return mutate(Select::make(n->a, false_value, true_value));
964 } else if (condition.same_as(op->condition) &&
965 true_value.same_as(op->true_value) &&
966 false_value.same_as(op->false_value)) {
967 return op;
968 } else {
969 return Select::make(condition, true_value, false_value);
970 }
971 }
972 };
973
974 // Collapse selects back together
975 class CollapseSelects : public IRMutator {
976 using IRMutator::visit;
977
visit(const Select * op)978 Expr visit(const Select *op) override {
979 const Select *t = op->true_value.as<Select>();
980 const Select *f = op->false_value.as<Select>();
981
982 if (t && equal(t->false_value, op->false_value)) {
983 // select(a, select(b, t, f), f) -> select(a && b, t, f)
984 return mutate(select(op->condition && t->condition, t->true_value, op->false_value));
985 } else if (f && equal(op->true_value, f->true_value)) {
986 // select(a, t, select(b, t, f)) -> select(a || b, t, f)
987 return mutate(select(op->condition || f->condition, op->true_value, f->false_value));
988 } else {
989 return IRMutator::visit(op);
990 }
991 }
992 };
993
994 class ContainsLoop : public IRVisitor {
995 using IRVisitor::visit;
visit(const For * op)996 void visit(const For *op) override {
997 result = true;
998 }
999
1000 public:
1001 bool result = false;
1002 };
1003
1004 class LowerLikelyIfInnermost : public IRMutator {
1005 using IRMutator::visit;
1006
1007 bool inside_innermost_loop = false;
1008
visit(const Call * op)1009 Expr visit(const Call *op) override {
1010 if (op->is_intrinsic(Call::likely_if_innermost)) {
1011 internal_assert(op->args.size() == 1);
1012 if (inside_innermost_loop) {
1013 return Call::make(op->type, Call::likely, {mutate(op->args[0])}, Call::PureIntrinsic);
1014 } else {
1015 return mutate(op->args[0]);
1016 }
1017 } else {
1018 return IRMutator::visit(op);
1019 }
1020 }
1021
visit(const For * op)1022 Stmt visit(const For *op) override {
1023 ContainsLoop c;
1024 op->body.accept(&c);
1025 inside_innermost_loop = !c.result;
1026 Stmt stmt = IRMutator::visit(op);
1027 inside_innermost_loop = false;
1028 return stmt;
1029 }
1030 };
1031
1032 } // namespace
1033
has_uncaptured_likely_tag(const Expr & e)1034 bool has_uncaptured_likely_tag(const Expr &e) {
1035 HasUncapturedLikelyTag h;
1036 e.accept(&h);
1037 return h.result;
1038 }
1039
has_likely_tag(const Expr & e)1040 bool has_likely_tag(const Expr &e) {
1041 HasLikelyTag h;
1042 e.accept(&h);
1043 return h.result;
1044 }
1045
partition_loops(Stmt s)1046 Stmt partition_loops(Stmt s) {
1047 s = LowerLikelyIfInnermost().mutate(s);
1048
1049 // Walk inwards to the first loop before doing any more work.
1050 class Mutator : public IRMutator {
1051 using IRMutator::visit;
1052 Stmt visit(const For *op) override {
1053 Stmt s = op;
1054 s = MarkClampedRampsAsLikely().mutate(s);
1055 s = ExpandSelects().mutate(s);
1056 s = PartitionLoops().mutate(s);
1057 s = RenormalizeGPULoops().mutate(s);
1058 s = CollapseSelects().mutate(s);
1059 return s;
1060 }
1061 } mutator;
1062 s = mutator.mutate(s);
1063
1064 s = remove_likelies(s);
1065 return s;
1066 }
1067
1068 } // namespace Internal
1069 } // namespace Halide
1070