#include "BoundsInference.h"
#include "Bounds.h"
#include "ExternFuncArgument.h"
#include "Function.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Inline.h"
#include "Scope.h"
#include "Simplify.h"

#include <algorithm>
#include <iterator>

namespace Halide {
namespace Internal {

using std::map;
using std::pair;
using std::set;
using std::string;
using std::vector;

namespace {

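// Check whether a fully qualified name (e.g. "f.s0.x") refers to the
// unqualified variable 'var' (e.g. "x"). For example,
// var_name_match("f.s0.x", "x") is true, but var_name_match("f.s0.xi", "x") is not.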
bool var_name_match(const string &candidate, const string &var) {
    internal_assert(var.find('.') == string::npos)
        << "var_name_match expects unqualified names for the second argument. "
        << "Name passed: " << var << "\n";
    return (candidate == var) || Internal::ends_with(candidate, "." + var);
}

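// Visitor that detects whether an expression depends on the results of
// bounds inference, i.e. it references a ".min"/".max" symbol or queries
// a buffer's min/max extents.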
class DependsOnBoundsInference : public IRVisitor {
    using IRVisitor::visit;

    void visit(const Variable *var) override {
        if (ends_with(var->name, ".max") ||
            ends_with(var->name, ".min")) {
            result = true;
        }
    }

    void visit(const Call *op) override {
        if (op->name == Call::buffer_get_min ||
            op->name == Call::buffer_get_max) {
            result = true;
        } else {
            IRVisitor::visit(op);
        }
    }

public:
    bool result;
    DependsOnBoundsInference()
        : result(false) {
    }
};

bool depends_on_bounds_inference(const Expr &e) {
    DependsOnBoundsInference d;
    e.accept(&d);
    return d.result;
}

/** Compute the bounds of the value of some variable defined by an
 * inner let stmt or for loop. E.g. for the stmt:
 *
 * for x from 0 to 10:
 *  let y = x + 2;
 *
 * bounds_of_inner_var(y) would return 2 to 12, and
 * bounds_of_inner_var(x) would return 0 to 10.
 */
class BoundsOfInnerVar : public IRVisitor {
public:
    Interval result;
    BoundsOfInnerVar(const string &v)
        : var(v) {
    }

private:
    string var;
    Scope<Interval> scope;

    using IRVisitor::visit;

    void visit(const LetStmt *op) override {
        Interval in = bounds_of_expr_in_scope(op->value, scope);
        if (op->name == var) {
            result = in;
        } else {
            ScopedBinding<Interval> p(scope, op->name, in);
            op->body.accept(this);
        }
    }

    void visit(const For *op) override {
        // At this stage of lowering, loop_min and loop_max
        // conveniently exist in scope.
        Interval in(Variable::make(Int(32), op->name + ".loop_min"),
                    Variable::make(Int(32), op->name + ".loop_max"));

        if (op->name == var) {
            result = in;
        } else {
            ScopedBinding<Interval> p(scope, op->name, in);
            op->body.accept(this);
        }
    }
};

Interval bounds_of_inner_var(const string &var, const Stmt &s) {
    BoundsOfInnerVar b(var);
    s.accept(&b);
    return b.result;
}

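// Return the index of the fused group (within 'fused_groups') that
// contains 'producing_func'.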
size_t find_fused_group_index(const Function &producing_func,
                              const vector<vector<Function>> &fused_groups) {
    const auto &iter = std::find_if(fused_groups.begin(), fused_groups.end(),
                                    [&producing_func](const vector<Function> &group) {
                                        return std::any_of(group.begin(), group.end(),
                                                           [&producing_func](const Function &f) {
                                                               return (f.name() == producing_func.name());
                                                           });
                                    });
    internal_assert(iter != fused_groups.end());
    return iter - fused_groups.begin();
}

// Determine whether the current producing stage is fused with another
// stage (i.e. the consumer stage) at dimension 'var'.
bool is_fused_with_others(const vector<vector<Function>> &fused_groups,
                          const vector<set<FusedPair>> &fused_pairs_in_groups,
                          const Function &producing_func, int producing_stage_index,
                          const string &consumer_name, int consumer_stage,
                          string var) {
    if (producing_func.has_extern_definition()) {
        return false;
    }

    // Find the fused group this producing stage belongs to.
    size_t index = find_fused_group_index(producing_func, fused_groups);

    const vector<Dim> &dims = (producing_stage_index == 0) ?
                                  producing_func.definition().schedule().dims() :
                                  producing_func.update(producing_stage_index - 1).schedule().dims();

    size_t var_index;
    {
        const auto &iter = std::find_if(dims.begin(), dims.end(),
                                        [&var](const Dim &d) { return var_name_match(d.var, var); });
        if (iter == dims.end()) {
            return false;
        }
        var_index = iter - dims.begin();
    }

    // Iterate over the fused pair list to check if the producer stage
    // is fused with the consumer stage at 'var'.
    for (const auto &pair : fused_pairs_in_groups[index]) {
        if (((pair.func_1 == consumer_name) && ((int)pair.stage_1 == consumer_stage)) ||
            ((pair.func_2 == consumer_name) && ((int)pair.stage_2 == consumer_stage))) {
            const auto &iter = std::find_if(dims.begin(), dims.end(),
                                            [&pair](const Dim &d) { return var_name_match(d.var, pair.var_name); });
            if (iter == dims.end()) {
                continue;
            }
            size_t idx = iter - dims.begin();
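            // The dim list is ordered innermost to outermost, and fusion at
            // 'pair.var_name' shares that loop and everything outside it, so
            // 'var' is fused if it sits at or outside that dimension.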
            if (var_index >= idx) {
                return true;
            }
        }
    }
    return false;
}
}  // namespace

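// Mutator that injects let statements defining the bounds (the .min/.max
// symbols) that each function stage requires of its producers, at the
// appropriate loop levels.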
class BoundsInference : public IRMutator {
public:
    const vector<Function> &funcs;
    // Each element in the list is a group of functions whose loops
    // are fused together.
    const vector<vector<Function>> &fused_groups;
    // Contains the list of all pairwise fused function stages for each
    // fused group. The fused groups are indexed in the same way as 'fused_groups'.
    const vector<set<FusedPair>> &fused_pairs_in_groups;
    const FuncValueBounds &func_bounds;
    set<string> in_pipeline, inner_productions, has_extern_consumer;
    const Target target;

    struct CondValue {
        Expr cond;  // Condition on params only (can't depend on loop variable)
        Expr value;

        CondValue(const Expr &c, const Expr &v)
            : cond(c), value(v) {
        }
    };

    struct Stage {
        Function func;
        size_t stage;  // 0 is the pure definition, 1 is the first update
        string name;
        vector<int> consumers;
        map<pair<string, int>, Box> bounds;
        vector<CondValue> exprs;
        set<ReductionVariable, ReductionVariable::Compare> rvars;
        string stage_prefix;
        size_t fused_group_index;

        // Compute the expressions on the left- and right-hand sides.
        // Note that a function definition might have a different LHS or reduction domain
        // (if it's an update def) or RHS per specialization. All specializations
        // of an init definition should have the same LHS.
        // This also pushes all the reduction domains it encounters into the 'rvars'
        // set for later use.
        vector<vector<CondValue>> compute_exprs_helper(const Definition &def, bool is_update) {
            vector<vector<CondValue>> result(2);  // <args, values>

            if (!def.defined()) {
                return result;
            }

            // Default case (no specialization)
            vector<Expr> predicates = def.split_predicate();
            for (const ReductionVariable &rv : def.schedule().rvars()) {
                rvars.insert(rv);
            }

            vector<vector<Expr>> vecs(2);
            if (is_update) {
                vecs[0] = def.args();
            }
            vecs[1] = def.values();

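            // Wrap each arg/value in if_then_else intrinsics guarded by the
            // split predicates, so that the predicates can be used to tighten
            // the bounds inferred from these expressions below.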
            for (size_t i = 0; i < result.size(); ++i) {
                for (const Expr &val : vecs[i]) {
                    if (!predicates.empty()) {
                        Expr cond_val = Call::make(val.type(),
                                                   Internal::Call::if_then_else,
                                                   {likely(predicates[0]), val, make_zero(val.type())},
                                                   Internal::Call::PureIntrinsic);
                        for (size_t j = 1; j < predicates.size(); ++j) {
                            cond_val = Call::make(cond_val.type(),
                                                  Internal::Call::if_then_else,
                                                  {likely(predicates[j]), cond_val, make_zero(cond_val.type())},
                                                  Internal::Call::PureIntrinsic);
                        }
                        result[i].push_back(CondValue(const_true(), cond_val));
                    } else {
                        result[i].push_back(CondValue(const_true(), val));
                    }
                }
            }

            const vector<Specialization> &specializations = def.specializations();
            for (size_t i = specializations.size(); i > 0; i--) {
                Expr s_cond = specializations[i - 1].condition;
                const Definition &s_def = specializations[i - 1].definition;

                // Else case (i.e. specialization condition is false)
                for (auto &vec : result) {
                    for (CondValue &cval : vec) {
                        cval.cond = simplify(!s_cond && cval.cond);
                    }
                }

                // Then case (i.e. specialization condition is true)
                vector<vector<CondValue>> s_result = compute_exprs_helper(s_def, is_update);
                for (auto &vec : s_result) {
                    for (CondValue &cval : vec) {
                        cval.cond = simplify(s_cond && cval.cond);
                    }
                }
                for (size_t i = 0; i < result.size(); i++) {
                    result[i].insert(result[i].end(), s_result[i].begin(), s_result[i].end());
                }
            }

            // Optimization: If the args/values across specializations, including
            // the default case, are the same, we can combine those args/values
            // into one arg/value with a const_true() condition for the purpose
            // of bounds inference.
            for (auto &vec : result) {
                if (vec.size() > 1) {
                    bool all_equal = true;
                    Expr val = vec[0].value;
                    for (size_t i = 1; i < vec.size(); ++i) {
                        if (!equal(val, vec[i].value)) {
                            all_equal = false;
                            break;
                        }
                    }
                    if (all_equal) {
                        debug(4) << "compute_exprs: all values (size: " << vec.size() << ") "
                                 << "(" << val << ") are equal, combine them together\n";
                        internal_assert(val.defined());
                        vec.clear();
                        vec.emplace_back(const_true(), val);
                    }
                }
            }
            return result;
        }

        // Compute the expressions on the left and right-hand sides. This also
        // pushes all reduction domains it encounters into the 'rvars' set
        // for later use.
        void compute_exprs() {
            // We need to clear 'exprs' and 'rvars' first, in case compute_exprs()
            // is called multiple times.
            exprs.clear();
            rvars.clear();

            bool is_update = (stage != 0);
            vector<vector<CondValue>> result;
            if (!is_update) {
                result = compute_exprs_helper(func.definition(), is_update);
            } else {
                const Definition &def = func.update(stage - 1);
                result = compute_exprs_helper(def, is_update);
            }
            internal_assert(result.size() == 2);
            exprs = result[0];

            if (func.extern_definition_proxy_expr().defined()) {
                exprs.emplace_back(const_true(), func.extern_definition_proxy_expr());
            }

            exprs.insert(exprs.end(), result[1].begin(), result[1].end());

            // For the purposes of computation bounds inference, we
            // don't care what sites are loaded, just what sites need
            // to have the correct value in them. So remap all selects
            // to if_then_elses to get tighter bounds.
            class SelectToIfThenElse : public IRMutator {
                using IRMutator::visit;
                Expr visit(const Select *op) override {
                    if (is_pure(op->condition)) {
                        return Call::make(op->type, Call::if_then_else,
                                          {mutate(op->condition),
                                           mutate(op->true_value),
                                           mutate(op->false_value)},
                                          Call::PureIntrinsic);
                    } else {
                        return IRMutator::visit(op);
                    }
                }
            } select_to_if_then_else;

            for (auto &e : exprs) {
                e.value = select_to_if_then_else.mutate(e.value);
            }
        }

        // Check if the dimension at index 'dim_idx' is always pure (i.e. equal to 'dim')
        // in the definition (including in its specializations)
        bool is_dim_always_pure(const Definition &def, const string &dim, int dim_idx) {
            const Variable *var = def.args()[dim_idx].as<Variable>();
            if ((!var) || (var->name != dim)) {
                return false;
            }

            for (const Specialization &s : def.specializations()) {
                bool pure = is_dim_always_pure(s.definition, dim, dim_idx);
                if (!pure) {
                    return false;
                }
            }
            return true;
        }

        // Wrap a statement in let stmts defining the box
        Stmt define_bounds(Stmt s,
                           const Function &producing_func,
                           const string &producing_stage_index,
                           int producing_stage_index_index,
                           const string &loop_level,
                           const vector<vector<Function>> &fused_groups,
                           const vector<set<FusedPair>> &fused_pairs_in_groups,
                           const set<string> &in_pipeline,
                           const set<string> &inner_productions,
                           const set<string> &has_extern_consumer,
                           const Target &target) {

            // Merge all the relevant boxes.
            Box b;

            const vector<string> func_args = func.args();

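            // 'loop_level' is a fully qualified loop name (e.g. "f.s0.x");
            // strip it down to the trailing dimension name.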
            size_t last_dot = loop_level.rfind('.');
            string var = loop_level.substr(last_dot + 1);

            for (const pair<const pair<string, int>, Box> &i : bounds) {
                string func_name = i.first.first;
                int func_stage_index = i.first.second;
                string stage_name = func_name + ".s" + std::to_string(func_stage_index);
                if (stage_name == producing_stage_index ||
                    inner_productions.count(func_name) ||
                    is_fused_with_others(fused_groups, fused_pairs_in_groups,
                                         producing_func, producing_stage_index_index,
                                         func_name, func_stage_index, var)) {
                    merge_boxes(b, i.second);
                }
            }

            internal_assert(b.empty() || b.size() == func_args.size());

            if (!b.empty()) {
                // Optimization: If a dimension is pure in every update
                // step of a func, then there exists a single bound for
                // that dimension, instead of one bound per stage. Let's
                // figure out what those dimensions are, and just have all
                // stages but the last use the bounds for the last stage.
                vector<bool> always_pure_dims(func_args.size(), true);
                for (const Definition &def : func.updates()) {
                    for (size_t j = 0; j < always_pure_dims.size(); j++) {
                        bool pure = is_dim_always_pure(def, func_args[j], j);
                        if (!pure) {
                            always_pure_dims[j] = false;
                        }
                    }
                }

                if (stage < func.updates().size()) {
                    size_t stages = func.updates().size();
                    string last_stage = func.name() + ".s" + std::to_string(stages) + ".";
                    for (size_t i = 0; i < always_pure_dims.size(); i++) {
                        if (always_pure_dims[i]) {
                            const string &dim = func_args[i];
                            Expr min = Variable::make(Int(32), last_stage + dim + ".min");
                            Expr max = Variable::make(Int(32), last_stage + dim + ".max");
                            b[i] = Interval(min, max);
                        }
                    }
                }
            }

            if (func.has_extern_definition() &&
                !func.extern_definition_proxy_expr().defined()) {
                // After we define our bounds required, we need to
                // figure out what we're actually going to compute,
                // and what inputs we need. To do this we:

                // 1) Grab a handle on the bounds query results from one level up

                // 2) Run the bounds query to let it round up the output size.

                // 3) Shift the requested output box back inside of the
                // bounds query result from one loop level up (in case
                // it was rounded up)

                // 4) then run the bounds query again to get the input
                // sizes.

                // Because we're wrapping a stmt, this happens in reverse order.

                // 4)
                s = do_bounds_query(s, in_pipeline, target);

                if (!in_pipeline.empty()) {
                    // 3)
                    string outer_query_name = func.name() + ".outer_bounds_query";
                    Expr outer_query = Variable::make(type_of<struct halide_buffer_t *>(), outer_query_name);
                    string inner_query_name = func.name() + ".o0.bounds_query";
                    Expr inner_query = Variable::make(type_of<struct halide_buffer_t *>(), inner_query_name);
                    for (int i = 0; i < func.dimensions(); i++) {
                        Expr outer_min = Call::make(Int(32), Call::buffer_get_min,
                                                    {outer_query, i}, Call::Extern);
                        Expr outer_max = Call::make(Int(32), Call::buffer_get_max,
                                                    {outer_query, i}, Call::Extern);

                        Expr inner_min = Call::make(Int(32), Call::buffer_get_min,
                                                    {inner_query, i}, Call::Extern);
                        Expr inner_max = Call::make(Int(32), Call::buffer_get_max,
                                                    {inner_query, i}, Call::Extern);

                        // Push 'inner' inside of 'outer'
                        Expr shift = Min::make(0, outer_max - inner_max);
                        Expr new_min = inner_min + shift;
                        Expr new_max = inner_max + shift;

                        // Modify the region to be computed accordingly
                        s = LetStmt::make(func.name() + ".s0." + func_args[i] + ".max", new_max, s);
                        s = LetStmt::make(func.name() + ".s0." + func_args[i] + ".min", new_min, s);
                    }

                    // 2)
                    s = do_bounds_query(s, in_pipeline, target);

                    // 1)
                    s = LetStmt::make(func.name() + ".outer_bounds_query",
                                      Variable::make(type_of<struct halide_buffer_t *>(), func.name() + ".o0.bounds_query"), s);
                } else {
                    // If we're at the outermost loop, there is no
                    // bounds query result from one level up, but we
                    // still need to modify the region to be computed
                    // based on the bounds query result and then do
                    // another bounds query to ask for the required
                    // input size given that.

                    // 2)
                    string inner_query_name = func.name() + ".o0.bounds_query";
                    Expr inner_query = Variable::make(type_of<struct halide_buffer_t *>(), inner_query_name);
                    for (int i = 0; i < func.dimensions(); i++) {
                        Expr new_min = Call::make(Int(32), Call::buffer_get_min,
                                                  {inner_query, i}, Call::Extern);
                        Expr new_max = Call::make(Int(32), Call::buffer_get_max,
                                                  {inner_query, i}, Call::Extern);

                        s = LetStmt::make(func.name() + ".s0." + func_args[i] + ".max", new_max, s);
                        s = LetStmt::make(func.name() + ".s0." + func_args[i] + ".min", new_min, s);
                    }

                    s = do_bounds_query(s, in_pipeline, target);
                }
            }

            if (in_pipeline.count(name) == 0) {
                // Inject any explicit bounds
                string prefix = name + ".s" + std::to_string(stage) + ".";

                LoopLevel compute_at = func.schedule().compute_level();
                LoopLevel store_at = func.schedule().store_level();

                for (size_t i = 0; i < func.schedule().bounds().size(); i++) {
                    Bound bound = func.schedule().bounds()[i];
                    string min_var = prefix + bound.var + ".min";
                    string max_var = prefix + bound.var + ".max";
                    Expr min_required = Variable::make(Int(32), min_var);
                    Expr max_required = Variable::make(Int(32), max_var);

                    if (bound.extent.defined()) {
                        // If the Func is compute_at some inner loop, and
                        // only extent is bounded, then the min could
                        // actually move around, which makes the extent
                        // bound not actually useful for determining the
                        // max required from the point of view of
                        // producers.
                        if (bound.min.defined() ||
                            compute_at.is_root() ||
                            (compute_at.match(loop_level) &&
                             store_at.match(loop_level))) {
                            if (!bound.min.defined()) {
                                bound.min = min_required;
                            }
                            s = LetStmt::make(min_var, bound.min, s);
                            s = LetStmt::make(max_var, bound.min + bound.extent - 1, s);
                        }

                        // Save the unbounded values to use in bounds-checking assertions
                        s = LetStmt::make(min_var + "_unbounded", min_required, s);
                        s = LetStmt::make(max_var + "_unbounded", max_required, s);
                    }

                    if (bound.modulus.defined()) {
                        min_required -= bound.remainder;
                        min_required = (min_required / bound.modulus) * bound.modulus;
                        min_required += bound.remainder;
                        Expr max_plus_one = max_required + 1;
                        max_plus_one -= bound.remainder;
                        max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
                        max_plus_one += bound.remainder;
                        max_required = max_plus_one - 1;
                        s = LetStmt::make(min_var, min_required, s);
                        s = LetStmt::make(max_var, max_required, s);
                    }
                }
            }

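            // Define the .min/.max symbols for each dimension of the
            // merged box required of this stage.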
            for (size_t d = 0; d < b.size(); d++) {
                string arg = name + ".s" + std::to_string(stage) + "." + func_args[d];

                const bool clamp_to_outer_bounds =
                    !in_pipeline.empty() && has_extern_consumer.count(name);
                if (clamp_to_outer_bounds) {
                    // Allocation bounds inference is going to have a
                    // bad time lifting the results of the bounds
                    // queries outwards. Help it out by insisting that
                    // the bounds are clamped to lie within the bounds
                    // one loop level up.
                    Expr outer_min = Variable::make(Int(32), arg + ".outer_min");
                    Expr outer_max = Variable::make(Int(32), arg + ".outer_max");
                    b[d].min = clamp(b[d].min, outer_min, outer_max);
                    b[d].max = clamp(b[d].max, outer_min, outer_max);
                }

                if (b[d].is_single_point()) {
                    s = LetStmt::make(arg + ".min", Variable::make(Int(32), arg + ".max"), s);
                } else {
                    s = LetStmt::make(arg + ".min", b[d].min, s);
                }
                s = LetStmt::make(arg + ".max", b[d].max, s);

                if (clamp_to_outer_bounds) {
                    s = LetStmt::make(arg + ".outer_min", Variable::make(Int(32), arg + ".min"), s);
                    s = LetStmt::make(arg + ".outer_max", Variable::make(Int(32), arg + ".max"), s);
                }
            }

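            // For update stages, also define the bounds of the reduction
            // variables from the reduction domain.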
            if (stage > 0) {
                for (const ReductionVariable &rvar : rvars) {
                    string arg = name + ".s" + std::to_string(stage) + "." + rvar.var;
                    s = LetStmt::make(arg + ".min", rvar.min, s);
                    s = LetStmt::make(arg + ".max", rvar.extent + rvar.min - 1, s);
                }
            }

            return s;
        }

        Stmt do_bounds_query(Stmt s, const set<string> &in_pipeline, const Target &target) {

            const string &extern_name = func.extern_function_name();
            const vector<ExternFuncArgument> &args = func.extern_arguments();

            vector<Expr> bounds_inference_args;

            vector<pair<string, Expr>> lets;

            // Iterate through all of the input args to the extern
            // function, building a suitable argument list for the
            // extern function call. We need a query halide_buffer_t per
            // producer and one per output.

            Expr null_handle = make_zero(Handle());

            vector<pair<Expr, int>> buffers_to_annotate;
            for (size_t j = 0; j < args.size(); j++) {
                if (args[j].is_expr()) {
                    bounds_inference_args.push_back(args[j].expr);
                } else if (args[j].is_func()) {
                    Function input(args[j].func);
                    for (int k = 0; k < input.outputs(); k++) {
                        string name = input.name() + ".o" + std::to_string(k) + ".bounds_query." + func.name();

                        BufferBuilder builder;
                        builder.type = input.output_types()[k];
                        builder.dimensions = input.dimensions();
                        Expr buf = builder.build();

                        lets.emplace_back(name, buf);
                        bounds_inference_args.push_back(Variable::make(type_of<struct halide_buffer_t *>(), name));
                        buffers_to_annotate.emplace_back(bounds_inference_args.back(), input.dimensions());
                    }
                } else if (args[j].is_image_param() || args[j].is_buffer()) {
                    Parameter p = args[j].image_param;
                    Buffer<> b = args[j].buffer;
                    string name = args[j].is_image_param() ? p.name() : b.name();
                    int dims = args[j].is_image_param() ? p.dimensions() : b.dimensions();

                    Expr in_buf = Variable::make(type_of<struct halide_buffer_t *>(), name + ".buffer");

                    // Copy the input buffer into a query buffer to mutate.
                    string query_name = name + ".bounds_query." + func.name();

                    Expr alloca_size = Call::make(Int(32), Call::size_of_halide_buffer_t, {}, Call::Intrinsic);
                    Expr query_buf = Call::make(type_of<struct halide_buffer_t *>(), Call::alloca,
                                                {alloca_size}, Call::Intrinsic);
                    Expr query_shape = Call::make(type_of<struct halide_dimension_t *>(), Call::alloca,
                                                  {(int)(sizeof(halide_dimension_t) * dims)}, Call::Intrinsic);
                    query_buf = Call::make(type_of<struct halide_buffer_t *>(), Call::buffer_init_from_buffer,
                                           {query_buf, query_shape, in_buf}, Call::Extern);

                    lets.emplace_back(query_name, query_buf);
                    Expr buf = Variable::make(type_of<struct halide_buffer_t *>(), query_name, b, p, ReductionDomain());
                    bounds_inference_args.push_back(buf);
                    // Although we expect ImageParams to be properly initialized and sanitized by the caller,
                    // we create a copy with copy_memory (not msan-aware), so we need to annotate it as initialized.
                    buffers_to_annotate.emplace_back(bounds_inference_args.back(), dims);
                } else {
                    internal_error << "Bad ExternFuncArgument type";
                }
            }

            // Make the buffer_ts representing the output. They all
            // use the same size, but have differing types.
            for (int j = 0; j < func.outputs(); j++) {
                BufferBuilder builder;
                builder.type = func.output_types()[j];
                builder.dimensions = func.dimensions();
                for (const string &arg : func.args()) {
                    string prefix = func.name() + ".s" + std::to_string(stage) + "." + arg;
                    Expr min = Variable::make(Int(32), prefix + ".min");
                    Expr max = Variable::make(Int(32), prefix + ".max");
                    builder.mins.push_back(min);
                    builder.extents.push_back(max + 1 - min);
                    builder.strides.emplace_back(0);
                }
                Expr output_buffer_t = builder.build();

                string buf_name = func.name() + ".o" + std::to_string(j) + ".bounds_query";
                bounds_inference_args.push_back(Variable::make(type_of<struct halide_buffer_t *>(), buf_name));
                // Since this is a temporary, internal-only buffer used for bounds inference,
                // we need to mark it as initialized for MSAN as well.
                buffers_to_annotate.emplace_back(bounds_inference_args.back(), func.dimensions());
                lets.emplace_back(buf_name, output_buffer_t);
            }

            Stmt annotate;
            if (target.has_feature(Target::MSAN)) {
                // Mark the buffers as initialized before calling out.
                for (const auto &p : buffers_to_annotate) {
                    Expr buffer = p.first;
                    int dimensions = p.second;
                    // Return type is really 'void', but no way to represent that in our IR.
                    // Precedent (from halide_print, etc) is to use Int(32) and ignore the result.
                    Expr sizeof_buffer_t = cast<uint64_t>(
                        Call::make(Int(32), Call::size_of_halide_buffer_t, {}, Call::Intrinsic));
                    Stmt mark_buffer =
                        Evaluate::make(Call::make(Int(32), "halide_msan_annotate_memory_is_initialized",
                                                  {buffer, sizeof_buffer_t}, Call::Extern));
                    Expr shape = Call::make(type_of<halide_dimension_t *>(), Call::buffer_get_shape, {buffer},
                                            Call::Extern);
                    Expr shape_size = Expr((uint64_t)(sizeof(halide_dimension_t) * dimensions));
                    Stmt mark_shape =
                        Evaluate::make(Call::make(Int(32), "halide_msan_annotate_memory_is_initialized",
                                                  {shape, shape_size}, Call::Extern));

                    mark_buffer = Block::make(mark_buffer, mark_shape);
                    if (annotate.defined()) {
                        annotate = Block::make(annotate, mark_buffer);
                    } else {
                        annotate = mark_buffer;
                    }
                }
            }

            // Make the extern call
            Expr e = func.make_call_to_extern_definition(bounds_inference_args, target);

            // Check if it succeeded
            string result_name = unique_name('t');
            Expr result = Variable::make(Int(32), result_name);
            Expr error = Call::make(Int(32), "halide_error_bounds_inference_call_failed",
                                    {extern_name, result}, Call::Extern);
            Stmt check = AssertStmt::make(EQ::make(result, 0), error);

            check = LetStmt::make(result_name, e, check);

            if (annotate.defined()) {
                check = Block::make(annotate, check);
            }

            // Now inner code is free to extract the fields from the halide_buffer_t
            s = Block::make(check, s);

            // Wrap in let stmts defining the args
            for (size_t i = 0; i < lets.size(); i++) {
                s = LetStmt::make(lets[i].first, lets[i].second, s);
            }

            return s;
        }

        // Populate a scope giving the bounds for variables used by this stage.
        // We need to take into account specializations, which may refer to
        // different reduction variables as well.
        void populate_scope(Scope<Interval> &result) {
            for (const string &farg : func.args()) {
                string arg = name + ".s" + std::to_string(stage) + "." + farg;
                result.push(farg,
                            Interval(Variable::make(Int(32), arg + ".min"),
                                     Variable::make(Int(32), arg + ".max")));
            }
            if (stage > 0) {
                for (const ReductionVariable &rv : rvars) {
                    string arg = name + ".s" + std::to_string(stage) + "." + rv.var;
                    result.push(rv.var, Interval(Variable::make(Int(32), arg + ".min"),
                                                 Variable::make(Int(32), arg + ".max")));
                }
            }

            /*for (size_t i = 0; i < func.definition().schedule().bounds().size(); i++) {
                const Bound &b = func.definition().schedule().bounds()[i];
                result.push(b.var, Interval(b.min, (b.min + b.extent) - 1));
            }*/
        }
    };
    vector<Stage> stages;

    BoundsInference(const vector<Function> &f,
                    const vector<vector<Function>> &fg,
                    const vector<set<FusedPair>> &fp,
                    const vector<Function> &outputs,
                    const FuncValueBounds &fb,
                    const Target &target)
        : funcs(f), fused_groups(fg), fused_pairs_in_groups(fp), func_bounds(fb), target(target) {
        internal_assert(!f.empty());

        // Compute the intrinsic relationships between the stages of
        // the functions.

        // Figure out which functions will be inlined away
        vector<bool> inlined(f.size());
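        // (The last Function in 'f' is an output being realized, so it is
        // never treated as inlined.)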
        for (size_t i = 0; i < inlined.size(); i++) {
            if (i < f.size() - 1 &&
                f[i].schedule().compute_level().is_inlined() &&
                f[i].can_be_inlined()) {
                inlined[i] = true;
            } else {
                inlined[i] = false;
            }
        }

        // First lay out all the stages in their realization order.
        // The functions are already in topologically sorted order, so
        // this is straightforward.
        for (size_t i = 0; i < f.size(); i++) {

            if (inlined[i]) {
                continue;
            }

            Stage s;
            s.func = f[i];
            s.stage = 0;
            s.name = s.func.name();
            s.fused_group_index = find_fused_group_index(s.func, fused_groups);
            s.compute_exprs();
            s.stage_prefix = s.name + ".s0.";
            stages.push_back(s);

            for (size_t j = 0; j < f[i].updates().size(); j++) {
                s.stage = (int)(j + 1);
                s.stage_prefix = s.name + ".s" + std::to_string(s.stage) + ".";
                s.compute_exprs();
                stages.push_back(s);
            }
        }

        // Do any pure inlining (TODO: This is currently slow)
        for (size_t i = f.size(); i > 0; i--) {
            Function func = f[i - 1];
            if (inlined[i - 1]) {
                for (size_t j = 0; j < stages.size(); j++) {
                    Stage &s = stages[j];
                    for (size_t k = 0; k < s.exprs.size(); k++) {
                        CondValue &cond_val = s.exprs[k];
                        internal_assert(cond_val.value.defined());
                        cond_val.value = inline_function(cond_val.value, func);
                    }
                }
            }
        }

        // Remove the inlined stages
        vector<Stage> new_stages;
        for (size_t i = 0; i < stages.size(); i++) {
            if (!stages[i].func.schedule().compute_level().is_inlined() ||
                !stages[i].func.can_be_inlined()) {
                new_stages.push_back(stages[i]);
            }
        }
        new_stages.swap(stages);

        // Dump the stages post-inlining for debugging
        /*
        debug(0) << "Bounds inference stages after inlining: \n";
        for (size_t i = 0; i < stages.size(); i++) {
            debug(0) << " " << i << ") " << stages[i].name << "\n";
        }
        */

        // Then compute relationships between them.
        for (size_t i = 0; i < stages.size(); i++) {

            Stage &consumer = stages[i];

            // Set up symbols representing the bounds over which this
            // stage will be computed.
            Scope<Interval> scope;
            consumer.populate_scope(scope);

            // Compute all the boxes of the producers this consumer
            // uses.
            map<string, Box> boxes;
            if (consumer.func.has_extern_definition() &&
                !consumer.func.extern_definition_proxy_expr().defined()) {

                const vector<ExternFuncArgument> &args = consumer.func.extern_arguments();
                // Stage::define_bounds is going to compute a query
                // halide_buffer_t per producer for bounds inference to
                // use. We just need to extract those values.
                for (size_t j = 0; j < args.size(); j++) {
                    if (args[j].is_func()) {
                        Function f(args[j].func);
                        has_extern_consumer.insert(f.name());
                        string stage_name = f.name() + ".s" + std::to_string(f.updates().size());
                        Box b(f.dimensions());
                        for (int d = 0; d < f.dimensions(); d++) {
                            string buf_name = f.name() + ".o0.bounds_query." + consumer.name;
                            Expr buf = Variable::make(type_of<struct halide_buffer_t *>(), buf_name);
                            Expr min = Call::make(Int(32), Call::buffer_get_min,
                                                  {buf, d}, Call::Extern);
                            Expr max = Call::make(Int(32), Call::buffer_get_max,
                                                  {buf, d}, Call::Extern);
                            b[d] = Interval(min, max);
                        }
                        merge_boxes(boxes[f.name()], b);
                    }
                }
            } else {
                for (const auto &cval : consumer.exprs) {
                    map<string, Box> new_boxes;
                    new_boxes = boxes_required(cval.value, scope, func_bounds);
                    for (auto &i : new_boxes) {
                        // Add the condition on which this value is evaluated to the box before merging
                        Box &box = i.second;
                        box.used = cval.cond;
                        merge_boxes(boxes[i.first], box);
                    }
                }
            }

            // Expand the bounds required of all the producers found
            // (we only need to check stages up to i, because the stages are
            // topologically sorted).
            for (size_t j = 0; j < i; j++) {
                Stage &producer = stages[j];
                // A consumer depends on *all* stages of a producer, not just the last one.
                const Box &b = boxes[producer.func.name()];

                if (!b.empty()) {
                    // Check for unboundedness
                    for (size_t k = 0; k < b.size(); k++) {
                        if (!b[k].is_bounded()) {
                            std::ostringstream err;
                            if (consumer.stage == 0) {
                                err << "The pure definition ";
                            } else {
                                err << "Update definition number " << (consumer.stage - 1);
                            }
                            err << " of Function " << consumer.name
                                << " calls function " << producer.name
                                << " in an unbounded way in dimension " << k << "\n";
                            user_error << err.str();
                        }
                    }

                    // Dump out the region required of each stage for debugging.

                    /*
                    debug(0) << "Box required of " << producer.name
                             << " by " << consumer.name
                             << " stage " << consumer.stage << ":\n";
                    for (size_t k = 0; k < b.size(); k++) {
                        debug(0) << "  " << b[k].min << " ... " << b[k].max << "\n";
                    }
                    debug(0) << "\n";
                    */

                    producer.bounds[{consumer.name, consumer.stage}] = b;
                    producer.consumers.push_back((int)i);
                }
            }
        }

        // The region required of each output is expanded to include the size of the output buffer.
        for (Function output : outputs) {
            Box output_box;
            string buffer_name = output.name();
            if (output.outputs() > 1) {
                // Use the output size of the first output buffer
                buffer_name += ".0";
            }
            for (int d = 0; d < output.dimensions(); d++) {
                Parameter buf = output.output_buffers()[0];
                Expr min = Variable::make(Int(32), buffer_name + ".min." + std::to_string(d), buf);
                Expr extent = Variable::make(Int(32), buffer_name + ".extent." + std::to_string(d), buf);

                // Respect any output min and extent constraints
                Expr min_constraint = buf.min_constraint(d);
                Expr extent_constraint = buf.extent_constraint(d);

                if (min_constraint.defined()) {
                    min = min_constraint;
                }
                if (extent_constraint.defined()) {
                    extent = extent_constraint;
                }

                output_box.push_back(Interval(min, (min + extent) - 1));
            }
            for (size_t i = 0; i < stages.size(); i++) {
                Stage &s = stages[i];
                if (!s.func.same_as(output)) {
                    continue;
                }
                s.bounds[{s.name, s.stage}] = output_box;
            }
        }
    }

    using IRMutator::visit;

visit(const For * op)992     Stmt visit(const For *op) override {
993         // Don't recurse inside loops marked 'Extern', they will be
994         // removed later.
995         if (op->for_type == ForType::Extern) {
996             return op;
997         }
998 
999         set<string> old_inner_productions;
1000         inner_productions.swap(old_inner_productions);
1001 
1002         Stmt body = op->body;
1003 
1004         // Walk inside of any let/if statements that don't depend on
1005         // bounds inference results so that we don't needlessly
1006         // complicate our bounds expressions.
1007         vector<pair<string, Expr>> wrappers;
1008         while (1) {
1009             if (const LetStmt *let = body.as<LetStmt>()) {
1010                 if (depends_on_bounds_inference(let->value)) {
1011                     break;
1012                 }
1013 
1014                 body = let->body;
1015                 wrappers.emplace_back(let->name, let->value);
1016             } else if (const IfThenElse *if_then_else = body.as<IfThenElse>()) {
1017                 if (depends_on_bounds_inference(if_then_else->condition) ||
1018                     if_then_else->else_case.defined()) {
1019                     break;
1020                 }
1021 
1022                 body = if_then_else->then_case;
1023                 wrappers.emplace_back(std::string(), if_then_else->condition);
1024             } else {
1025                 break;
1026             }
1027         }
1028 
1029         // If there are no pipelines at this loop level, we can skip
1030         // most of the work.  Consider 'extern' for loops as pipelines
1031         // (we aren't recursing into these loops above).
1032         bool no_pipelines =
1033             body.as<For>() != nullptr &&
1034             body.as<For>()->for_type != ForType::Extern;
1035 
1036         // Figure out which stage of which function we're producing
1037         int producing = -1;
1038         Function f;
1039         int stage_index = -1;
1040         string stage_name;
1041         for (size_t i = 0; i < stages.size(); i++) {
1042             if (starts_with(op->name, stages[i].stage_prefix)) {
1043                 producing = i;
1044                 f = stages[i].func;
1045                 stage_index = (int)stages[i].stage;
1046                 stage_name = stages[i].name + ".s" + std::to_string(stages[i].stage);
1047                 break;
1048             }
1049         }
1050 
1051         // Figure out how much of it we're producing
1052 
1053         // Note: the case when functions are fused is a little bit tricky, so may need extra care:
1054         // when we're producing some of a Func A, at every loop belonging to A
1055         // you potentially need to define symbols for what box is being computed
1056         // of A (A.x.min, A.x.max ...), because that any other producer Func P nested
1057         // there is going to define its loop bounds in terms of these symbols, to ensure
1058         // it computes enough of itself to satisfy the consumer.
1059         // Now say we compute B with A, and say B consumes P, not A. Bounds inference
1060         // will see the shared loop, and think it belongs to A only. It will define A.x.min and
1061         // friends, but that's not very useful, because P's loops are in terms of B.x.min, B.x.max, etc.
1062         // So without a local definition of those symbols, P will use the one in the outer scope, and
1063         // compute way too much of itself. It'll still be correct, but it's massive over-compute.
1064         // The fix is to realize that in this loop belonging to A, we also potentially need to define
1065         // a box for B, because B belongs to the same fused group as A, so really this loop belongs to A and B.
1066         // We'll get the box using boxes_provided and only filtering for A and B after the fact
1067         // Note that even though the loops are fused, the boxes touched of A and B might be totally different,
1068         // because e.g. B could be double-resolution (as happens when fusing yuv computations), so this
1069         // is not just a matter of giving A's box B's name as an alias.
        map<string, Box> boxes_for_fused_group;
        map<string, Function> stage_name_to_func;
        if (!no_pipelines && producing >= 0 && !f.has_extern_definition()) {
            Scope<Interval> empty_scope;
            size_t last_dot = op->name.rfind('.');
            string var = op->name.substr(last_dot + 1);

            set<pair<string, int>> fused_with_f;
            for (const auto &pair : fused_pairs_in_groups[stages[producing].fused_group_index]) {
                if (!((pair.func_1 == stages[producing].name) && ((int)pair.stage_1 == stage_index)) &&
                    is_fused_with_others(fused_groups, fused_pairs_in_groups,
                                         f, stage_index,
                                         pair.func_1, pair.stage_1, var)) {
                    fused_with_f.insert(make_pair(pair.func_1, pair.stage_1));
                }
                if (!((pair.func_2 == stages[producing].name) && ((int)pair.stage_2 == stage_index)) &&
                    is_fused_with_others(fused_groups, fused_pairs_in_groups,
                                         f, stage_index,
                                         pair.func_2, pair.stage_2, var)) {
                    fused_with_f.insert(make_pair(pair.func_2, pair.stage_2));
                }
            }

            if (fused_with_f.empty()) {
                boxes_for_fused_group[stage_name] = box_provided(body, stages[producing].name, empty_scope, func_bounds);
                stage_name_to_func[stage_name] = f;
                internal_assert((int)boxes_for_fused_group[stage_name].size() == f.dimensions());
            } else {
                auto boxes = boxes_provided(body, empty_scope, func_bounds);
                boxes_for_fused_group[stage_name] = boxes[stages[producing].name];
                stage_name_to_func[stage_name] = f;
                internal_assert((int)boxes_for_fused_group[stage_name].size() == f.dimensions());
                for (const auto &fused : fused_with_f) {
                    string fused_stage_name = fused.first + ".s" + std::to_string(fused.second);
                    boxes_for_fused_group[fused_stage_name] = boxes[fused.first];
                    for (const auto &fn : funcs) {
                        if (fn.name() == fused.first) {
                            stage_name_to_func[fused_stage_name] = fn;
                            break;
                        }
                    }
                }
            }
        }

        // Recurse.
        body = mutate(body);

        if (!no_pipelines) {
            // We only care about the bounds of a func if:
            // A) We're not already in a pipeline over that func AND
            // B.1) There's a production of this func somewhere inside this loop OR
            // B.2) We're downstream (a consumer) of a func for which we care about the bounds.
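            // (For example: if this loop contains a production of some func g, and h
            // consumes g, then bounds are needed for g by B.1 and for h by B.2,
            // unless we are already inside h's ProducerConsumer node, in which case
            // A wins and h's bounds are left to the enclosing level.)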
            vector<bool> bounds_needed(stages.size(), false);
            for (size_t i = 0; i < stages.size(); i++) {
                if (inner_productions.count(stages[i].name)) {
                    bounds_needed[i] = true;
                }

                if (in_pipeline.count(stages[i].name)) {
                    bounds_needed[i] = false;
                }

                if (bounds_needed[i]) {
                    for (size_t j = 0; j < stages[i].consumers.size(); j++) {
                        bounds_needed[stages[i].consumers[j]] = true;
                    }
                    body = stages[i].define_bounds(
                        body, f, stage_name, stage_index, op->name, fused_groups,
                        fused_pairs_in_groups, in_pipeline, inner_productions,
                        has_extern_consumer, target);
                }
            }

            // Finally, define the production bounds for the thing
            // we're producing.
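            // (Illustrative: for a stage named f.s0 with pure var x, this wraps the
            // body in lets along the lines of
            //     let f.s0.x.max = <box max>   (or an alias of .min for a single point)
            //     let f.s0.x.min = <box min>
            // with one such pair per dimension of the box computed above.)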
            if (producing >= 0 && !inner_productions.empty()) {
                for (const auto &b : boxes_for_fused_group) {
                    const vector<string> &f_args = stage_name_to_func[b.first].args();
                    const auto &box = b.second;
                    internal_assert(f_args.size() == box.size());
                    for (size_t i = 0; i < box.size(); i++) {
                        internal_assert(box[i].is_bounded());
                        string var = b.first + "." + f_args[i];

                        if (box[i].is_single_point()) {
                            body = LetStmt::make(var + ".max", Variable::make(Int(32), var + ".min"), body);
                        } else {
                            body = LetStmt::make(var + ".max", box[i].max, body);
                        }

                        body = LetStmt::make(var + ".min", box[i].min, body);
                    }
                }
            }

            // And the current bounds on its reduction variables, and
            // variables from extern for loops.
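            // (Illustrative: an update stage f.s1 over an RVar r similarly gets
            //     let f.s1.r.min = ...
            //     let f.s1.r.max = ...
            // taken from whatever bounds the enclosing loops currently give r.)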
            if (producing >= 0) {
                const Stage &s = stages[producing];
                vector<string> vars;
                if (s.func.has_extern_definition()) {
                    vars = s.func.args();
                }
                if (stages[producing].stage > 0) {
                    for (const ReductionVariable &rv : s.rvars) {
                        vars.push_back(rv.var);
                    }
                }
                for (const string &i : vars) {
                    string var = s.stage_prefix + i;
                    Interval in = bounds_of_inner_var(var, body);
                    if (in.is_bounded()) {
                        // bounds_of_inner_var doesn't understand
                        // GuardWithIf, but we know split rvars never
                        // have inner bounds that exceed the outer
                        // ones.
                        if (!s.rvars.empty()) {
                            in.min = max(in.min, Variable::make(Int(32), var + ".min"));
                            in.max = min(in.max, Variable::make(Int(32), var + ".max"));
                        }

                        body = LetStmt::make(var + ".min", in.min, body);
                        body = LetStmt::make(var + ".max", in.max, body);
                    } else {
                        // If it's not found, we're already in the
                        // scope of the injected let. The let was
                        // probably lifted to an outer level.
                        Expr val = Variable::make(Int(32), var);
                        body = LetStmt::make(var + ".min", val, body);
                        body = LetStmt::make(var + ".max", val, body);
                    }
                }
            }
        }

        inner_productions.insert(old_inner_productions.begin(),
                                 old_inner_productions.end());

        // Rewrap the let/if statements
        for (size_t i = wrappers.size(); i > 0; i--) {
            const auto &p = wrappers[i - 1];
            if (p.first.empty()) {
                body = IfThenElse::make(p.second, body);
            } else {
                body = LetStmt::make(p.first, p.second, body);
            }
        }

        return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
    }

    Stmt visit(const ProducerConsumer *p) override {
        in_pipeline.insert(p->name);
        Stmt stmt = IRMutator::visit(p);
        in_pipeline.erase(p->name);
        inner_productions.insert(p->name);
        return stmt;
    }
};

Stmt bounds_inference(Stmt s,
                      const vector<Function> &outputs,
                      const vector<string> &order,
                      const vector<vector<string>> &fused_groups,
                      const map<string, Function> &env,
                      const FuncValueBounds &func_bounds,
                      const Target &target) {

    vector<Function> funcs(order.size());
    for (size_t i = 0; i < order.size(); i++) {
        funcs[i] = env.find(order[i])->second;
    }

    // Each element in 'fused_func_groups' is a group of functions
    // whose loops should be fused together.
    vector<vector<Function>> fused_func_groups;
    for (const vector<string> &group : fused_groups) {
        vector<Function> fs;
        for (const string &fname : group) {
            fs.push_back(env.find(fname)->second);
        }
        fused_func_groups.push_back(fs);
    }

    // For each fused group, collect the pairwise fused function stages.
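    // (For example, if g is scheduled with g.compute_with(f, y), the pure
    // definitions of f and g form one fused pair recorded against the loop
    // level y; update stages scheduled this way contribute their own pairs.)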
    vector<set<FusedPair>> fused_pairs_in_groups;
    for (const vector<string> &group : fused_groups) {
        set<FusedPair> pairs;
        for (const string &fname : group) {
            Function f = env.find(fname)->second;
            if (!f.has_extern_definition()) {
                std::copy(f.definition().schedule().fused_pairs().begin(),
                          f.definition().schedule().fused_pairs().end(),
                          std::inserter(pairs, pairs.end()));

                for (size_t i = 0; i < f.updates().size(); ++i) {
                    std::copy(f.updates()[i].schedule().fused_pairs().begin(),
                              f.updates()[i].schedule().fused_pairs().end(),
                              std::inserter(pairs, pairs.end()));
                }
            }
        }
        fused_pairs_in_groups.push_back(pairs);
    }

    // Add a note in the IR for where assertions on input images
    // should go. Those are handled by a later lowering pass.
    Expr marker = Call::make(Int(32), Call::add_image_checks_marker, {}, Call::Intrinsic);
    s = Block::make(Evaluate::make(marker), s);

    // Add a synthetic outermost loop to act as 'root'.
    s = For::make("<outermost>", 0, 1, ForType::Serial, DeviceAPI::None, s);
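    // (This loop only exists to give bounds inference a root-level For to hang
    // the outermost bounds on; it is stripped off again below when the mutated
    // loop body is returned.)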

    s = BoundsInference(funcs, fused_func_groups, fused_pairs_in_groups,
                        outputs, func_bounds, target)
            .mutate(s);
    return s.as<For>()->body;
}

}  // namespace Internal
}  // namespace Halide