1 #include <algorithm>
2 #include <numeric>
3 #include <utility>
4 
5 #include "CSE.h"
6 #include "CodeGen_GPU_Dev.h"
7 #include "ExprUsesVar.h"
8 #include "IREquality.h"
9 #include "IRMutator.h"
10 #include "IROperator.h"
11 #include "PartitionLoops.h"
12 #include "Simplify.h"
13 #include "Solve.h"
14 #include "Substitute.h"
15 #include "Var.h"
16 
17 namespace Halide {
18 namespace Internal {
19 
20 using std::pair;
21 using std::string;
22 using std::vector;
23 
24 namespace {
25 
26 // Loop partitioning only applies to things marked as 'likely'. Loads
27 // through hand-written boundary conditions will produce clamped
28 // ramps, which will turn into gathers. This pass injects likely
29 // intrinsics so that these clamped ramps are picked up by loop
30 // partitioning.
31 class MarkClampedRampsAsLikely : public IRMutator {
32     using IRMutator::visit;
visit(const Min * op)33     Expr visit(const Min *op) override {
34         if (in_index && op->a.as<Ramp>()) {
35             // No point recursing into the ramp - it can't contain
36             // another ramp.
37             return min(likely(op->a), mutate(op->b));
38         } else if (in_index && op->b.as<Ramp>()) {
39             return min(mutate(op->a), likely(op->b));
40         } else {
41             return IRMutator::visit(op);
42         }
43     }
44 
visit(const Max * op)45     Expr visit(const Max *op) override {
46         if (in_index && op->a.as<Ramp>()) {
47             return max(likely(op->a), mutate(op->b));
48         } else if (in_index && op->b.as<Ramp>()) {
49             return max(mutate(op->a), likely(op->b));
50         } else {
51             return IRMutator::visit(op);
52         }
53     }
54 
visit(const Load * op)55     Expr visit(const Load *op) override {
56         bool old_in_index = in_index;
57         in_index = true;
58         Expr expr = IRMutator::visit(op);
59         in_index = old_in_index;
60         return expr;
61     }
62 
visit(const Store * op)63     Stmt visit(const Store *op) override {
64         bool old_in_index = in_index;
65         in_index = true;
66         Expr index = mutate(op->index);
67         in_index = old_in_index;
68         Expr value = mutate(op->value);
69         Expr predicate = mutate(op->predicate);
70         if (predicate.same_as(op->predicate) && index.same_as(op->index) && value.same_as(op->value)) {
71             return op;
72         } else {
73             return Store::make(op->name, value, index, op->param, predicate, op->alignment);
74         }
75     }
76 
77     bool in_index = false;
78 };
79 
80 // Check if an expression or statement uses a likely tag
81 class HasLikelyTag : public IRVisitor {
82 protected:
83     using IRVisitor::visit;
visit(const Call * op)84     void visit(const Call *op) override {
85         if (op->is_intrinsic(Call::likely)) {
86             result = true;
87         } else {
88             IRVisitor::visit(op);
89         }
90     }
91 
92 public:
93     bool result = false;
94 };
95 
96 class HasUncapturedLikelyTag : public HasLikelyTag {
97     using HasLikelyTag::visit;
98 
99     // Any likelies buried inside the following ops are captured the by respective ops
visit(const Select * op)100     void visit(const Select *op) override {
101     }
visit(const Min * op)102     void visit(const Min *op) override {
103     }
visit(const Max * op)104     void visit(const Max *op) override {
105     }
106 };
107 
108 // The goal of loop partitioning is to split loops up into a prologue,
109 // a clean steady state, and an epilogue. The next visitor
110 // (FindSimplifications) finds a list of simplifications that can be
111 // applied to produce that clean steady-state version of the loop
112 // body. It tries to simplify selects, mins, and maxes to just their
113 // likely branch. For example:
114 //
115 //   select(a, likely(b), c)     -> b
116 //   select(a, b, 5 + likely(c)) -> 5 + c
117 //   max(a, likely(b))           -> b
118 //
// These three simplifications are only valid if a is true, a is false,
// or a <= b, respectively. So we visit the loop body looking for
121 // these sort of things, record the associated conditions, and try to
122 // solve for a range of the loop variable for which all of our
123 // conditions are true (by solving for each one and then taking the
124 // intersection). That gives us the clean steady state.
125 //
126 // It may be that we can also make some simplifications to the
127 // prologue or epilogue. For example, consider the case:
128 //
129 //   select(x > 0, likely(expr_t), expr_f)
130 //
131 // It can simplify to expr_t when x > 0. However, if this is the sole
132 // simplification which gives us a lower bound on x for the steady
133 // state, we can also simplify this select in the prologue to just be
134 // expr_f.
135 //
136 // Now consider this case:
137 //
138 //   (select(x > a, likely(expr_t1), expr_f1) +
139 //    select(x > b, likely(expr_t2), expr_f2))
140 //
141 // The steady state starts at x == max(a, b), which we get from the
142 // intersection of the intervals derived from each condition: x > a
143 // and x > b. In the steady state, the expression simplifies to
144 // expr_t1 + expr_t2. In the prologue we know that either x <= a or x
145 // <= b, but we don't know which one might be true, so we can't make
146 // any simplifications to the prologue.
147 //
148 // We may also encounter single conditions where we can simplify the
149 // steady-state but not the prologue. Say we're splitting up a loop
150 // over x and we encounter a condition that depends on a variable
151 // introduced in some inner loop:
152 //
153 // for x:
154 //   for z from 0 to 10:
155 //     ... select(x > z, likely(expr_t), expr_f) ...
156 //
157 // This select definitely simplifies to expr_t when x > 9, because
158 // that's the maximum value z could be, so we'll start the steady
159 // state at x == 10. This means the prologue covers values like x ==
160 // 5, where the select could be either true or false, so we can't make
161 // any simplifications to the prologue.
162 //
163 // There are some simplifications that we won't be able to do. For
164 // example, if we're partitioning the loop over x, and we encounter:
165 //
166 // for x:
167 //   for z from 0 to 10:
168 //     ... select(z < 5, likely(expr_t), expr_f)
169 //
170 // Restricting the range of x isn't going to simplify that expression
171 // - it doesn't even depend on x. We just make all the simplifications
172 // that we can, and take the intersection of the resulting regions. In
173 // this case, we'll make that simplification later, when we do loop
174 // partitioning over the loop in z. Some cases we'll never
175 // handle. E.g. consider:
176 //
177 // for x:
178 //   ... select(a + x*(b + x*(c + x*(d + x*e))) > 0, likely(expr_t), expr_f)
179 //
180 // In order to simplify that we'd have to come up with a formula that
181 // tells us an interval where a quintic is strictly positive. No such
182 // general formula exists (because no formula exists for the roots),
183 // so there's no programmatic way we can partition the loop over x to
184 // make that condition simplify. Finally my Galois theory course pays
185 // off. For failures like this, we just drop the likely tag. So loop
186 // partitioning is best-effort, but it should always work for things
187 // like x > a. A simpler case for which we bail is:
188 //
189 // for x:
190 //   ... select(x == 5, expr_t, likely(expr_f))
191 //
192 // This simplifies to the likely case in two disjoint ranges, but
193 // we're only producing one steady state, and we have no reason to
194 // believe one side is better than the other, so we just bail and drop
195 // the likely tag.
196 
197 // First we define the struct that represents a single simplification
198 // that can be applied to the steady state of the loop.
// NOTE: Simplification is created with aggregate brace-initialization in
// FindSimplifications::new_simplification, so the member order below is
// part of its construction contract. `interval` is left value-initialized
// there and filled in later by PartitionLoops.
struct Simplification {
    // This condition is sufficient for the simplification to occur.
    Expr condition;
    // The expression we're simplifying. MakeSimplifications matches it by
    // node identity (same_as), not by structural equality.
    Expr old_expr;
    // The replacement if the condition is true
    Expr likely_value;
    // The replacement if the condition is false. Not useful
    // unless it's tight.
    Expr unlikely_value;
    // Is the condition necessary (as well as sufficient)?
    bool tight;
    // The interval over which this simplification applies. Comes from solving the condition.
    Interval interval;
};
214 
215 class ExprUsesInvalidBuffers : public IRVisitor {
216     using IRVisitor::visit;
217 
218     const Scope<> &invalid_buffers;
219 
visit(const Load * op)220     void visit(const Load *op) override {
221         if (invalid_buffers.contains(op->name)) {
222             invalid = true;
223         } else {
224             IRVisitor::visit(op);
225         }
226     }
227 
228 public:
ExprUsesInvalidBuffers(const Scope<> & buffers)229     ExprUsesInvalidBuffers(const Scope<> &buffers)
230         : invalid_buffers(buffers), invalid(false) {
231     }
232     bool invalid;
233 };
234 
/** Check whether an expression loads from any buffer named in invalid_buffers. */
expr_uses_invalid_buffers(const Expr & e,const Scope<> & invalid_buffers)236 bool expr_uses_invalid_buffers(const Expr &e, const Scope<> &invalid_buffers) {
237     ExprUsesInvalidBuffers uses(invalid_buffers);
238     e.accept(&uses);
239     return uses.invalid;
240 }
241 
242 // Then we define the visitor that hunts for them.
// Walks the body of the loop being partitioned and records, in
// `simplifications`, each min/max/select/if that could collapse to its
// likely branch, together with the condition under which that is valid.
class FindSimplifications : public IRVisitor {
    using IRVisitor::visit;

    // depends_on_loop_var: the partitioned loop's variable (pushed in the
    // constructor) plus any let-bound name whose value uses a name in it.
    // depends_on_invalid_buffers: let-bound names whose value loads from a
    // buffer in `buffers` or uses a name already in this scope.
    Scope<> depends_on_loop_var, depends_on_invalid_buffers;
    // Names of buffers allocated inside the visited body. Conditions that
    // load from them are discarded in new_simplification.
    Scope<> buffers;

    void visit(const Allocate *op) override {
        // NOTE(review): the name is pushed but never popped, so it stays
        // marked for the rest of the traversal, even after this Allocate's
        // scope ends. That only makes the check more conservative.
        buffers.push(op->name);
        IRVisitor::visit(op);
    }

    // Record that `old` may be replaced by `likely_val` when `condition`
    // holds (`unlikely_val` is the replacement when it does not).
    // Conditions are dropped when:
    // - they don't mention the loop variable (directly or via lets), or
    // - they touch buffers allocated inside the body.
    // likely() tags are stripped from the stored condition.
    // A vector condition is simplified. If the result is a Broadcast, its
    // scalar value is used. Otherwise it is reduced with
    // and_condition_over_domain and the simplification is marked not tight.
    void new_simplification(Expr condition, Expr old, Expr likely_val, Expr unlikely_val) {
        if (!expr_uses_vars(condition, depends_on_loop_var)) {
            return;
        }

        if (expr_uses_vars(condition, depends_on_invalid_buffers) ||
            expr_uses_invalid_buffers(condition, buffers)) {
            // The condition refers to buffer allocated in the inner loop.
            // We should throw away the condition
            return;
        }
        condition = remove_likelies(condition);
        Simplification s = {condition, std::move(old), std::move(likely_val), std::move(unlikely_val), true};
        if (s.condition.type().is_vector()) {
            s.condition = simplify(s.condition);
            if (const Broadcast *b = s.condition.as<Broadcast>()) {
                s.condition = b->value;
            } else {
                // Devectorize the condition
                s.condition = and_condition_over_domain(s.condition, Scope<Interval>::empty_scope());
                s.tight = false;
            }
        }
        internal_assert(s.condition.type().is_scalar()) << s.condition << "\n";
        simplifications.push_back(s);
    }

    // min(a, b) collapses to its likely operand when that operand is the
    // smaller one, e.g. min(a, likely(b)) -> b when b <= a.
    void visit(const Min *op) override {
        bool likely_a = has_uncaptured_likely_tag(op->a);
        bool likely_b = has_uncaptured_likely_tag(op->b);

        // Prefer the side that has an uncaptured top-level likely
        // call. If neither does, prefer the side that contains any
        // likely call at all.
        if (!likely_a && !likely_b) {
            likely_a = has_likely_tag(op->a);
            likely_b = has_likely_tag(op->b);
        }

        // Don't hunt for simplifications in unlikely paths
        if (!likely_a) {
            op->b.accept(this);
        }
        if (!likely_b) {
            op->a.accept(this);
        }

        // Only record a simplification when exactly one side is likely.
        if (likely_b && !likely_a) {
            new_simplification(op->b <= op->a, op, op->b, op->a);
        } else if (likely_a && !likely_b) {
            new_simplification(op->a <= op->b, op, op->a, op->b);
        }
    }

    // Mirror image of Min: max(a, b) collapses to its likely operand when
    // that operand is the larger one.
    void visit(const Max *op) override {
        bool likely_a = has_uncaptured_likely_tag(op->a);
        bool likely_b = has_uncaptured_likely_tag(op->b);

        if (!likely_a && !likely_b) {
            likely_a = has_likely_tag(op->a);
            likely_b = has_likely_tag(op->b);
        }

        if (!likely_a) {
            op->b.accept(this);
        }
        if (!likely_b) {
            op->a.accept(this);
        }

        if (likely_b && !likely_a) {
            new_simplification(op->b >= op->a, op, op->b, op->a);
        } else if (likely_a && !likely_b) {
            new_simplification(op->a >= op->b, op, op->a, op->b);
        }
    }

    // select(c, t, f) collapses to t when c holds (t likely), or to f when
    // !c holds (f likely). The condition is always searched; a branch is
    // searched only if the other branch is not likely.
    void visit(const Select *op) override {
        op->condition.accept(this);

        bool likely_t = has_uncaptured_likely_tag(op->true_value);
        bool likely_f = has_uncaptured_likely_tag(op->false_value);

        if (!likely_t && !likely_f) {
            likely_t = has_likely_tag(op->true_value);
            likely_f = has_likely_tag(op->false_value);
        }

        if (!likely_t) {
            op->false_value.accept(this);
        }
        if (!likely_f) {
            op->true_value.accept(this);
        }

        if (likely_t && !likely_f) {
            new_simplification(op->condition, op, op->true_value, op->false_value);
        } else if (likely_f && !likely_t) {
            new_simplification(!op->condition, op, op->false_value, op->true_value);
        }
    }

    void visit(const IfThenElse *op) override {
        // For select statements, mins, and maxes, you can mark the
        // likely branch with likely. For if statements there's no way
        // to mark the likely stmt. So if the condition of an if
        // statement is marked as likely, treat it as likely true and
        // partition accordingly.
        IRVisitor::visit(op);
        const Call *call = op->condition.as<Call>();
        if (call && call->is_intrinsic(Call::likely)) {
            // The likely() condition node itself is replaced by true.
            new_simplification(op->condition, op->condition, const_true(), const_false());
        }
    }

    // Conditions found inside a nested loop may mention that loop's
    // variable. Each such condition is replaced by one that holds for
    // every value of the variable in [min, min + extent - 1]. If that
    // changes the condition, the simplification is no longer tight.
    void visit(const For *op) override {
        vector<Simplification> old;
        old.swap(simplifications);
        IRVisitor::visit(op);

        // Relax all the new conditions using the loop bounds
        for (Simplification &s : simplifications) {
            if (expr_uses_var(s.condition, op->name)) {
                Scope<Interval> varying;
                varying.push(op->name, Interval(op->min, op->min + op->extent - 1));
                Expr relaxed = and_condition_over_domain(s.condition, varying);
                internal_assert(!expr_uses_var(relaxed, op->name))
                    << "Should not have had used the loop var (" << op->name
                    << ") any longer\n  before: " << s.condition << "\n  after: "
                    << relaxed << "\n";
                if (!equal(relaxed, s.condition)) {
                    s.tight = false;
                }
                s.condition = relaxed;
            }
        }

        simplifications.insert(simplifications.end(), old.begin(), old.end());
    }

    // Shared handling for Let and LetStmt. For the duration of the body it:
    // - marks the bound name as loop-dependent and/or buffer-dependent,
    //   based on its value;
    // - visits the body;
    // - wraps every new condition that mentions the name in a Let that
    //   re-binds it, so the condition stays closed.
    template<typename LetOrLetStmt>
    void visit_let(const LetOrLetStmt *op) {
        ScopedBinding<> bind_varying(expr_uses_vars(op->value, depends_on_loop_var),
                                     depends_on_loop_var, op->name);
        ScopedBinding<> bind_invalid(expr_uses_invalid_buffers(op->value, buffers) ||
                                         expr_uses_vars(op->value, depends_on_invalid_buffers),
                                     depends_on_invalid_buffers, op->name);
        vector<Simplification> old;
        old.swap(simplifications);
        IRVisitor::visit(op);
        for (Simplification &s : simplifications) {
            if (expr_uses_var(s.condition, op->name)) {
                s.condition = Let::make(op->name, op->value, s.condition);
            }
        }
        simplifications.insert(simplifications.end(), old.begin(), old.end());
    }

    void visit(const LetStmt *op) override {
        visit_let(op);
    }

    void visit(const Let *op) override {
        visit_let(op);
    }

public:
    // Simplifications collected so far, in discovery order.
    vector<Simplification> simplifications;

    // `v` is the name of the loop variable being partitioned.
    FindSimplifications(const std::string &v) {
        depends_on_loop_var.push(v);
    }
};
427 
428 // Blindly apply a list of simplifications.
429 class MakeSimplifications : public IRMutator {
430     using IRMutator::visit;
431 
432     const vector<Simplification> &simplifications;
433 
434 public:
MakeSimplifications(const vector<Simplification> & s)435     MakeSimplifications(const vector<Simplification> &s)
436         : simplifications(s) {
437     }
438 
439     using IRMutator::mutate;
mutate(const Expr & e)440     Expr mutate(const Expr &e) override {
441         for (auto const &s : simplifications) {
442             if (e.same_as(s.old_expr)) {
443                 return mutate(s.likely_value);
444             }
445         }
446         return IRMutator::mutate(e);
447     }
448 };
449 
450 class ContainsWarpSynchronousLogic : public IRVisitor {
451 public:
452     bool result = false;
453 
454 protected:
455     using IRVisitor::visit;
visit(const Call * op)456     void visit(const Call *op) override {
457         if (op->is_intrinsic(Call::gpu_thread_barrier)) {
458             result = true;
459         } else {
460             IRVisitor::visit(op);
461         }
462     }
463 
visit(const For * op)464     void visit(const For *op) override {
465         if (op->for_type == ForType::GPULane) {
466             result = true;
467         } else {
468             IRVisitor::visit(op);
469         }
470     }
471 
visit(const Load * op)472     void visit(const Load *op) override {
473     }
474 };
475 
contains_warp_synchronous_logic(const Stmt & s)476 bool contains_warp_synchronous_logic(const Stmt &s) {
477     ContainsWarpSynchronousLogic c;
478     s.accept(&c);
479     return c.result;
480 }
481 
482 class PartitionLoops : public IRMutator {
483     using IRMutator::visit;
484 
485     bool in_gpu_loop = false;
486 
visit(const For * op)487     Stmt visit(const For *op) override {
488         Stmt body = op->body;
489 
490         ScopedValue<bool> old_in_gpu_loop(in_gpu_loop, in_gpu_loop ||
491                                                            CodeGen_GPU_Dev::is_gpu_var(op->name));
492 
493         // If we're inside GPU kernel, and the body contains thread
494         // barriers or warp shuffles, it's not safe to partition loops.
495         if (in_gpu_loop && contains_warp_synchronous_logic(op)) {
496             return IRMutator::visit(op);
497         }
498 
499         // We shouldn't partition GLSL loops - they have control-flow
500         // constraints.
501         if (op->device_api == DeviceAPI::GLSL) {
502             return op;
503         }
504 
505         // Find simplifications in this loop body
506         FindSimplifications finder(op->name);
507         body.accept(&finder);
508 
509         if (finder.simplifications.empty()) {
510             return IRMutator::visit(op);
511         }
512 
513         debug(3) << "\n\n**** Partitioning loop over " << op->name << "\n";
514 
515         vector<Expr> min_vals, max_vals;
516         vector<Simplification> middle_simps, prologue_simps, epilogue_simps;
517         bool lower_bound_is_tight = true, upper_bound_is_tight = true;
518         for (auto &s : finder.simplifications) {
519             // Solve for the interval over which this simplification is true.
520             s.interval = solve_for_inner_interval(s.condition, op->name);
521             if (s.tight) {
522                 // Check if the solve is tight. I.e. the condition is
523                 // definitely false outside of the interval.
524                 Interval outer = solve_for_outer_interval(s.condition, op->name);
525                 s.tight &= equal(outer.min, s.interval.min) && equal(outer.max, s.interval.max);
526             }
527 
528             debug(3) << "\nSimplification: \n"
529                      << "  condition: " << s.condition << "\n"
530                      << "  old: " << s.old_expr << "\n"
531                      << "  new: " << s.likely_value << "\n"
532                      << "  min: " << s.interval.min << "\n"
533                      << "  max: " << s.interval.max << "\n"
534                      << "  tight: " << s.tight << "\n";
535 
536             // Accept all non-empty intervals
537             if (!s.interval.is_empty()) {
538                 if (s.interval.has_lower_bound()) {
539                     Expr m = s.interval.min;
540                     if (!s.tight) {
541                         lower_bound_is_tight = false;
542                     }
543                     if (min_vals.empty()) {
544                         min_vals.push_back(m);
545                     } else if (equal(m, min_vals.back())) {
546                         // We already have this min val
547                     } else {
548                         // This is a new distinct min val
549                         min_vals.push_back(m);
550                         lower_bound_is_tight = false;
551                     }
552                 }
553                 if (s.interval.has_upper_bound()) {
554                     Expr m = s.interval.max;
555                     if (!s.tight) {
556                         upper_bound_is_tight = false;
557                     }
558                     if (max_vals.empty()) {
559                         max_vals.push_back(m);
560                     } else if (equal(m, max_vals.back())) {
561                         // We already have this max val
562                     } else {
563                         // This is a new distinct max val
564                         max_vals.push_back(m);
565                         upper_bound_is_tight = false;
566                     }
567                 }
568 
569                 // We'll apply this simplification to the
570                 // steady-state.
571                 middle_simps.push_back(s);
572             }
573         }
574 
575         // In general we can't simplify the prologue - it may run up
576         // to after the epilogue starts for small images. However if
577         // we can prove the epilogue starts after the prologue ends,
578         // we're OK.
579         bool can_simplify_prologue = true;
580         for (Expr min_val : min_vals) {
581             for (Expr max_val : max_vals) {
582                 Expr test = simplify(common_subexpression_elimination(min_val - 1 < max_val + 1));
583                 if (!is_one(test)) {
584                     can_simplify_prologue = false;
585                 }
586             }
587         }
588 
589         // Find simplifications we can apply to the prologue and epilogue.
590         for (const auto &s : middle_simps) {
591             // If it goes down to minus infinity, we can also
592             // apply it to the prologue
593             if (can_simplify_prologue &&
594                 !s.interval.has_lower_bound()) {
595                 prologue_simps.push_back(s);
596             }
597 
598             // If it goes up to positive infinity, we can also
599             // apply it to the epilogue
600             if (!s.interval.has_upper_bound()) {
601                 epilogue_simps.push_back(s);
602             }
603 
604             // If our simplifications only contain one lower bound, and
605             // it's tight, then the reverse rule can be applied to the
606             // prologue.
607             if (can_simplify_prologue &&
608                 s.interval.has_lower_bound() &&
609                 lower_bound_is_tight) {
610                 internal_assert(s.tight);
611                 Simplification s2 = s;
612                 // This condition is never used (we already solved
613                 // for the interval), but it's nice for it to be
614                 // correct.
615                 s2.condition = !s2.condition;
616                 std::swap(s2.likely_value, s2.unlikely_value);
617                 prologue_simps.push_back(s2);
618             }
619             if (s.interval.has_upper_bound() &&
620                 upper_bound_is_tight) {
621                 internal_assert(s.tight);
622                 Simplification s2 = s;
623                 s2.condition = !s2.condition;
624                 std::swap(s2.likely_value, s2.unlikely_value);
625                 epilogue_simps.push_back(s2);
626             }
627         }
628 
629         // Simplify each section of the loop.
630         Stmt simpler_body = MakeSimplifications(middle_simps).mutate(body);
631         Stmt prologue = MakeSimplifications(prologue_simps).mutate(body);
632         Stmt epilogue = MakeSimplifications(epilogue_simps).mutate(body);
633 
634         bool make_prologue = !equal(prologue, simpler_body);
635         bool make_epilogue = !equal(epilogue, simpler_body);
636 
637         // Recurse on the middle section.
638         simpler_body = mutate(simpler_body);
639 
640         // Construct variables for the bounds of the simplified middle section
641         Expr min_steady = op->min, max_steady = op->extent + op->min;
642         Expr prologue_val, epilogue_val;
643         string prologue_name = unique_name(op->name + ".prologue");
644         string epilogue_name = unique_name(op->name + ".epilogue");
645 
646         if (make_prologue) {
647             // They'll simplify better if you put them in
648             // lexicographic order. This puts things like (x+1) and
649             // (x+3) next to each other so that the simplifier sees
650             // them together and can drop one of them.
651             std::sort(min_vals.begin(), min_vals.end(), IRDeepCompare());
652             min_vals.push_back(op->min);
653             prologue_val = fold_left(min_vals, Max::make);
654             // Stop the prologue from running past the end of the loop
655             prologue_val = min(prologue_val, op->extent + op->min);
656             // prologue_val = print(prologue_val, prologue_name);
657             min_steady = Variable::make(Int(32), prologue_name);
658 
659             internal_assert(!expr_uses_var(prologue_val, op->name));
660         }
661         if (make_epilogue) {
662             std::sort(max_vals.begin(), max_vals.end(), IRDeepCompare());
663             max_vals.push_back(op->min + op->extent - 1);
664             epilogue_val = fold_left(max_vals, Min::make) + 1;
665             // Stop the epilogue from running before the start of the loop/prologue
666             if (make_prologue) {
667                 epilogue_val = max(epilogue_val, prologue_val);
668             } else {
669                 epilogue_val = max(op->min, epilogue_val);
670             }
671             // epilogue_val = print(epilogue_val, epilogue_name);
672             max_steady = Variable::make(Int(32), epilogue_name);
673 
674             internal_assert(!expr_uses_var(epilogue_val, op->name));
675         }
676 
677         Stmt stmt;
678         // Bust simple serial for loops up into three.
679         if (op->for_type == ForType::Serial && !op->body.as<Acquire>()) {
680             stmt = For::make(op->name, min_steady, max_steady - min_steady,
681                              op->for_type, op->device_api, simpler_body);
682 
683             if (make_prologue) {
684                 prologue = For::make(op->name, op->min, min_steady - op->min,
685                                      op->for_type, op->device_api, prologue);
686                 stmt = Block::make(prologue, stmt);
687             }
688             if (make_epilogue) {
689                 epilogue = For::make(op->name, max_steady, op->min + op->extent - max_steady,
690                                      op->for_type, op->device_api, epilogue);
691                 stmt = Block::make(stmt, epilogue);
692             }
693         } else {
694             // For parallel for loops we could use a Fork node here,
695             // but that would introduce the more complicated parallel
696             // runtime into code that doesn't use async(), which may
697             // interfere with legacy overrides of
698             // halide_do_par_for. So for parallel for loops just put
699             // an if-then-else in the loop body. It should
700             // branch-predict to the steady state pretty well.
701             //
702             // Simple serial for loops that contain an Acquire node go
703             // into the task system as a single entity, but Block
704             // nodes do not, so we get a flatter task graph if we do
705             // the same trick.
706             Expr loop_var = Variable::make(Int(32), op->name);
707             stmt = simpler_body;
708             if (make_epilogue && make_prologue && equal(prologue, epilogue)) {
709                 stmt = IfThenElse::make(min_steady <= loop_var && loop_var < max_steady, stmt, prologue);
710             } else {
711                 if (make_epilogue) {
712                     stmt = IfThenElse::make(loop_var < max_steady, stmt, epilogue);
713                 }
714                 if (make_prologue) {
715                     stmt = IfThenElse::make(loop_var < min_steady, prologue, stmt);
716                 }
717             }
718             stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, stmt);
719         }
720 
721         if (make_epilogue) {
722             // Uncomment to include code that prints the epilogue value
723             //epilogue_val = print(epilogue_val, op->name, "epilogue");
724             stmt = LetStmt::make(epilogue_name, epilogue_val, stmt);
725         } else {
726             epilogue_val = op->min + op->extent;
727         }
728         if (make_prologue) {
729             // Uncomment to include code that prints the prologue value
730             //prologue_val = print(prologue_val, op->name, "prologue");
731             stmt = LetStmt::make(prologue_name, prologue_val, stmt);
732         } else {
733             prologue_val = op->min;
734         }
735 
736         if (can_prove(epilogue_val <= prologue_val)) {
737             // The steady state is empty. I've made a huge
738             // mistake. Try to partition a loop further in.
739             return IRMutator::visit(op);
740         }
741 
742         debug(3) << "Partition loop.\n"
743                  << "Old: " << Stmt(op) << "\n"
744                  << "New: " << stmt << "\n";
745 
746         return stmt;
747     }
748 };
749 
750 class ExprContainsLoad : public IRVisitor {
751     using IRVisitor::visit;
752 
visit(const Load * op)753     void visit(const Load *op) override {
754         result = true;
755     }
756 
757 public:
758     bool result = false;
759 };
760 
expr_contains_load(const Expr & e)761 bool expr_contains_load(const Expr &e) {
762     ExprContainsLoad l;
763     e.accept(&l);
764     return l.result;
765 }
766 
// The loop partitioning logic can introduce if and let statements in
// between GPU loop levels. This pass moves them inwards or outwards:
// - A LetStmt inside a GPU loop whose value uses no GPU-dependent
//   variable and contains no Load is renamed and lifted to just outside
//   the outermost GPU loop.
// - Other LetStmts sitting directly above a For or Allocate are pushed
//   inside it.
// - IfThenElse nodes are pushed inside an Allocate, LetStmt, or For found
//   at the head of their branches.
// Inside a loop whose name ends in "__thread_id_x", lets that are not
// lifted and if statements are left where they are.
class RenormalizeGPULoops : public IRMutator {
    // in_gpu_loop: currently inside a loop for which
    // CodeGen_GPU_Dev::is_gpu_var is true, or inside any loop nested in one.
    // in_thread_loop: currently inside a loop whose name ends in
    // "__thread_id_x".
    bool in_gpu_loop = false, in_thread_loop = false;

    using IRMutator::visit;

    // Track all vars that depend on GPU loop indices or loops inside GPU kernels.
    // Names are pushed but never popped, so they stay in scope for the rest
    // of the traversal. That can only make lifting more conservative.
    // NOTE(review): confirm the missing pops are intended.
    Scope<> gpu_vars;

    // Lets lifted out of the current GPU kernel (new name, value), in the
    // order they were encountered. They are emitted around the outermost
    // GPU loop by visit(const For *).
    vector<pair<string, Expr>> lifted_lets;

    Stmt visit(const For *op) override {
        if (op->device_api == DeviceAPI::GLSL) {
            // The partitioner did not enter GLSL loops
            return op;
        }

        bool old_in_gpu_loop = in_gpu_loop;
        Stmt stmt;

        // A GPU loop, or any loop nested inside one, binds a variable that
        // lets must not be lifted past.
        if (in_gpu_loop || CodeGen_GPU_Dev::is_gpu_var(op->name)) {
            gpu_vars.push(op->name);
            in_gpu_loop = true;
        }

        if (ends_with(op->name, "__thread_id_x")) {
            // Thread loops are not expected to nest inside one another.
            internal_assert(!in_thread_loop);
            in_thread_loop = true;
            stmt = IRMutator::visit(op);
            in_thread_loop = false;
        } else {
            stmt = IRMutator::visit(op);
        }

        if (in_gpu_loop && !old_in_gpu_loop) {
            // This was the outermost GPU loop. Dump any lifted lets here.
            // Wrapping from the back leaves the first-lifted let outermost,
            // so a later lifted value can refer to an earlier lifted name.
            while (lifted_lets.size()) {
                stmt = LetStmt::make(lifted_lets.back().first,
                                     lifted_lets.back().second,
                                     stmt);
                lifted_lets.pop_back();
            }
        }

        in_gpu_loop = old_in_gpu_loop;
        return stmt;
    }

    Stmt visit(const LetStmt *op) override {
        if (!in_gpu_loop) {
            return IRMutator::visit(op);
        }

        // NOTE(review): values containing a Load are presumably kept in
        // place because hoisting a load could move it past writes or out of
        // a guarding condition. Confirm.
        if (!expr_uses_vars(op->value, gpu_vars) && !expr_contains_load(op->value)) {
            // This let value doesn't depend in the gpu vars. We
            // should lift it outermost. Note that this might expand
            // its scope to encompass other uses of the same name, so
            // we'd better give it a new name.
            string new_name = unique_name('t');
            Expr new_var = Variable::make(op->value.type(), new_name);
            lifted_lets.emplace_back(new_name, op->value);
            return mutate(substitute(op->name, new_var, op->body));
        }

        // This let stays inside the kernel, so anything that uses it is
        // GPU-dependent too.
        gpu_vars.push(op->name);

        if (in_thread_loop) {
            return IRMutator::visit(op);
        }

        Stmt body = mutate(op->body);
        const For *f = body.as<For>();
        const Allocate *a = body.as<Allocate>();
        // Move lets in-between gpu loop levels inwards.
        // (in_gpu_loop is always true and in_thread_loop always false at
        // this point, given the early returns above.)
        if (f && in_gpu_loop && !in_thread_loop) {
            // The inner loop's bounds are evaluated outside it, so they
            // must not depend on the let being moved inside.
            internal_assert(!expr_uses_var(f->min, op->name) &&
                            !expr_uses_var(f->extent, op->name));
            Stmt inner = LetStmt::make(op->name, op->value, f->body);
            inner = For::make(f->name, f->min, f->extent, f->for_type, f->device_api, inner);
            return mutate(inner);
        } else if (a && in_gpu_loop && !in_thread_loop) {
            internal_assert(a->extents.size() == 1);
            if (expr_uses_var(a->extents[0], op->name)) {
                // This var depends on the block index, and is used to
                // define the size of shared memory. Can't move it
                // inwards or outwards. Codegen will have to deal with
                // it when it deduces how much shared or warp-level
                // memory to allocate.
                return IRMutator::visit(op);
            } else {
                Stmt inner = LetStmt::make(op->name, op->value, a->body);
                inner = Allocate::make(a->name, a->type, a->memory_type, a->extents, a->condition, inner);
                return mutate(inner);
            }
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const IfThenElse *op) override {
        if (!in_gpu_loop || in_thread_loop) {
            return IRMutator::visit(op);
        }

        internal_assert(op->else_case.defined())
            << "PartitionLoops should only introduce if statements with an else branch\n";

        Stmt then_case = mutate(op->then_case);
        Stmt else_case = mutate(op->else_case);

        if (equal(then_case, else_case)) {
            // This can happen if the only difference between the
            // cases was a let statement that we pulled out of the if.
            return then_case;
        }

        // Push the if inside whatever construct heads the branches. Each
        // restructured statement is mutated again so the if keeps sinking.
        const Allocate *allocate_a = then_case.as<Allocate>();
        const Allocate *allocate_b = else_case.as<Allocate>();
        const For *for_a = then_case.as<For>();
        const For *for_b = else_case.as<For>();
        const LetStmt *let_a = then_case.as<LetStmt>();
        const LetStmt *let_b = else_case.as<LetStmt>();
        if (allocate_a && allocate_b) {
            // The merged allocation takes all its attributes from the then
            // branch. NOTE(review): allocate_b's name/type/extents are not
            // compared. Confirm both branches always allocate identically.
            Stmt inner = IfThenElse::make(op->condition, allocate_a->body, allocate_b->body);
            inner = Allocate::make(allocate_a->name, allocate_a->type,
                                   allocate_a->memory_type, allocate_a->extents,
                                   allocate_a->condition, inner);
            return mutate(inner);
        } else if (let_a && let_b && let_a->name == let_b->name) {
            // Same name in both branches: bind the condition once and
            // select between the two values.
            string condition_name = unique_name('t');
            Expr condition = Variable::make(op->condition.type(), condition_name);
            Stmt inner = IfThenElse::make(condition, let_a->body, let_b->body);
            inner = LetStmt::make(let_a->name, select(condition, let_a->value, let_b->value), inner);
            inner = LetStmt::make(condition_name, op->condition, inner);
            return mutate(inner);
        } else if (let_a) {
            // Hoist the then-branch let above the if under a fresh name so
            // it cannot capture uses of the same name in the else branch.
            string new_name = unique_name(let_a->name);
            Stmt inner = let_a->body;
            inner = substitute(let_a->name, Variable::make(let_a->value.type(), new_name), inner);
            inner = IfThenElse::make(op->condition, inner, else_case);
            inner = LetStmt::make(new_name, let_a->value, inner);
            return mutate(inner);
        } else if (let_b) {
            // Same as above, for a let heading only the else branch.
            string new_name = unique_name(let_b->name);
            Stmt inner = let_b->body;
            inner = substitute(let_b->name, Variable::make(let_b->value.type(), new_name), inner);
            inner = IfThenElse::make(op->condition, then_case, inner);
            inner = LetStmt::make(new_name, let_b->value, inner);
            return mutate(inner);
        } else if (for_a && for_b &&
                   for_a->name == for_b->name &&
                   for_a->min.same_as(for_b->min) &&
                   for_a->extent.same_as(for_b->extent)) {
            // Both branches are the same loop: move the if inside it.
            // NOTE(review): same_as compares node identity, so structurally
            // equal bounds held in different nodes fall through to the
            // internal_error below.
            Stmt inner = IfThenElse::make(op->condition, for_a->body, for_b->body);
            inner = For::make(for_a->name, for_a->min, for_a->extent, for_a->for_type, for_a->device_api, inner);
            return mutate(inner);
        } else {
            internal_error << "Unexpected construct inside if statement: " << Stmt(op) << "\n";
            return Stmt();
        }
    }
};
930 
931 // Expand selects of boolean conditions so that the partitioner can
932 // consider them one-at-a-time.
933 class ExpandSelects : public IRMutator {
934     using IRMutator::visit;
935 
is_trivial(const Expr & e)936     bool is_trivial(const Expr &e) {
937         return e.as<Variable>() || is_const(e);
938     }
939 
visit(const Select * op)940     Expr visit(const Select *op) override {
941         Expr condition = mutate(op->condition);
942         Expr true_value = mutate(op->true_value);
943         Expr false_value = mutate(op->false_value);
944         if (const Or *o = condition.as<Or>()) {
945             if (is_trivial(true_value)) {
946                 return mutate(Select::make(o->a, true_value, Select::make(o->b, true_value, false_value)));
947             } else {
948                 string var_name = unique_name('t');
949                 Expr var = Variable::make(true_value.type(), var_name);
950                 Expr expr = mutate(Select::make(o->a, var, Select::make(o->b, var, false_value)));
951                 return Let::make(var_name, true_value, expr);
952             }
953         } else if (const And *a = condition.as<And>()) {
954             if (is_trivial(false_value)) {
955                 return mutate(Select::make(a->a, Select::make(a->b, true_value, false_value), false_value));
956             } else {
957                 string var_name = unique_name('t');
958                 Expr var = Variable::make(false_value.type(), var_name);
959                 Expr expr = mutate(Select::make(a->a, Select::make(a->b, true_value, var), var));
960                 return Let::make(var_name, false_value, expr);
961             }
962         } else if (const Not *n = condition.as<Not>()) {
963             return mutate(Select::make(n->a, false_value, true_value));
964         } else if (condition.same_as(op->condition) &&
965                    true_value.same_as(op->true_value) &&
966                    false_value.same_as(op->false_value)) {
967             return op;
968         } else {
969             return Select::make(condition, true_value, false_value);
970         }
971     }
972 };
973 
974 // Collapse selects back together
975 class CollapseSelects : public IRMutator {
976     using IRMutator::visit;
977 
visit(const Select * op)978     Expr visit(const Select *op) override {
979         const Select *t = op->true_value.as<Select>();
980         const Select *f = op->false_value.as<Select>();
981 
982         if (t && equal(t->false_value, op->false_value)) {
983             // select(a, select(b, t, f), f) -> select(a && b, t, f)
984             return mutate(select(op->condition && t->condition, t->true_value, op->false_value));
985         } else if (f && equal(op->true_value, f->true_value)) {
986             // select(a, t, select(b, t, f)) -> select(a || b, t, f)
987             return mutate(select(op->condition || f->condition, op->true_value, f->false_value));
988         } else {
989             return IRMutator::visit(op);
990         }
991     }
992 };
993 
994 class ContainsLoop : public IRVisitor {
995     using IRVisitor::visit;
visit(const For * op)996     void visit(const For *op) override {
997         result = true;
998     }
999 
1000 public:
1001     bool result = false;
1002 };
1003 
1004 class LowerLikelyIfInnermost : public IRMutator {
1005     using IRMutator::visit;
1006 
1007     bool inside_innermost_loop = false;
1008 
visit(const Call * op)1009     Expr visit(const Call *op) override {
1010         if (op->is_intrinsic(Call::likely_if_innermost)) {
1011             internal_assert(op->args.size() == 1);
1012             if (inside_innermost_loop) {
1013                 return Call::make(op->type, Call::likely, {mutate(op->args[0])}, Call::PureIntrinsic);
1014             } else {
1015                 return mutate(op->args[0]);
1016             }
1017         } else {
1018             return IRMutator::visit(op);
1019         }
1020     }
1021 
visit(const For * op)1022     Stmt visit(const For *op) override {
1023         ContainsLoop c;
1024         op->body.accept(&c);
1025         inside_innermost_loop = !c.result;
1026         Stmt stmt = IRMutator::visit(op);
1027         inside_innermost_loop = false;
1028         return stmt;
1029     }
1030 };
1031 
1032 }  // namespace
1033 
has_uncaptured_likely_tag(const Expr & e)1034 bool has_uncaptured_likely_tag(const Expr &e) {
1035     HasUncapturedLikelyTag h;
1036     e.accept(&h);
1037     return h.result;
1038 }
1039 
has_likely_tag(const Expr & e)1040 bool has_likely_tag(const Expr &e) {
1041     HasLikelyTag h;
1042     e.accept(&h);
1043     return h.result;
1044 }
1045 
partition_loops(Stmt s)1046 Stmt partition_loops(Stmt s) {
1047     s = LowerLikelyIfInnermost().mutate(s);
1048 
1049     // Walk inwards to the first loop before doing any more work.
1050     class Mutator : public IRMutator {
1051         using IRMutator::visit;
1052         Stmt visit(const For *op) override {
1053             Stmt s = op;
1054             s = MarkClampedRampsAsLikely().mutate(s);
1055             s = ExpandSelects().mutate(s);
1056             s = PartitionLoops().mutate(s);
1057             s = RenormalizeGPULoops().mutate(s);
1058             s = CollapseSelects().mutate(s);
1059             return s;
1060         }
1061     } mutator;
1062     s = mutator.mutate(s);
1063 
1064     s = remove_likelies(s);
1065     return s;
1066 }
1067 
1068 }  // namespace Internal
1069 }  // namespace Halide
1070