1 #include "BoundsInference.h"
2 #include "Bounds.h"
3 #include "ExternFuncArgument.h"
4 #include "Function.h"
5 #include "IREquality.h"
6 #include "IRMutator.h"
7 #include "IROperator.h"
8 #include "Inline.h"
9 #include "Scope.h"
10 #include "Simplify.h"
11
12 #include <algorithm>
13 #include <iterator>
14
15 namespace Halide {
16 namespace Internal {
17
18 using std::map;
19 using std::pair;
20 using std::set;
21 using std::string;
22 using std::vector;
23
24 namespace {
25
var_name_match(const string & candidate,const string & var)26 bool var_name_match(const string &candidate, const string &var) {
27 internal_assert(var.find('.') == string::npos)
28 << "var_name_match expects unqualified names for the second argument. "
29 << "Name passed: " << var << "\n";
30 return (candidate == var) || Internal::ends_with(candidate, "." + var);
31 }
32
33 class DependsOnBoundsInference : public IRVisitor {
34 using IRVisitor::visit;
35
visit(const Variable * var)36 void visit(const Variable *var) override {
37 if (ends_with(var->name, ".max") ||
38 ends_with(var->name, ".min")) {
39 result = true;
40 }
41 }
42
visit(const Call * op)43 void visit(const Call *op) override {
44 if (op->name == Call::buffer_get_min ||
45 op->name == Call::buffer_get_max) {
46 result = true;
47 } else {
48 IRVisitor::visit(op);
49 }
50 }
51
52 public:
53 bool result;
DependsOnBoundsInference()54 DependsOnBoundsInference()
55 : result(false) {
56 }
57 };
58
depends_on_bounds_inference(const Expr & e)59 bool depends_on_bounds_inference(const Expr &e) {
60 DependsOnBoundsInference d;
61 e.accept(&d);
62 return d.result;
63 }
64
65 /** Compute the bounds of the value of some variable defined by an
66 * inner let stmt or for loop. E.g. for the stmt:
67 *
68 *
69 * for x from 0 to 10:
70 * let y = x + 2;
71 *
72 * bounds_of_inner_var(y) would return 2 to 12, and
73 * bounds_of_inner_var(x) would return 0 to 10.
74 */
75 class BoundsOfInnerVar : public IRVisitor {
76 public:
77 Interval result;
BoundsOfInnerVar(const string & v)78 BoundsOfInnerVar(const string &v)
79 : var(v) {
80 }
81
82 private:
83 string var;
84 Scope<Interval> scope;
85
86 using IRVisitor::visit;
87
visit(const LetStmt * op)88 void visit(const LetStmt *op) override {
89 Interval in = bounds_of_expr_in_scope(op->value, scope);
90 if (op->name == var) {
91 result = in;
92 } else {
93 ScopedBinding<Interval> p(scope, op->name, in);
94 op->body.accept(this);
95 }
96 }
97
visit(const For * op)98 void visit(const For *op) override {
99 // At this stage of lowering, loop_min and loop_max
100 // conveniently exist in scope.
101 Interval in(Variable::make(Int(32), op->name + ".loop_min"),
102 Variable::make(Int(32), op->name + ".loop_max"));
103
104 if (op->name == var) {
105 result = in;
106 } else {
107 ScopedBinding<Interval> p(scope, op->name, in);
108 op->body.accept(this);
109 }
110 }
111 };
112
bounds_of_inner_var(const string & var,const Stmt & s)113 Interval bounds_of_inner_var(const string &var, const Stmt &s) {
114 BoundsOfInnerVar b(var);
115 s.accept(&b);
116 return b.result;
117 }
118
find_fused_group_index(const Function & producing_func,const vector<vector<Function>> & fused_groups)119 size_t find_fused_group_index(const Function &producing_func,
120 const vector<vector<Function>> &fused_groups) {
121 const auto &iter = std::find_if(fused_groups.begin(), fused_groups.end(),
122 [&producing_func](const vector<Function> &group) {
123 return std::any_of(group.begin(), group.end(),
124 [&producing_func](const Function &f) {
125 return (f.name() == producing_func.name());
126 });
127 });
128 internal_assert(iter != fused_groups.end());
129 return iter - fused_groups.begin();
130 }
131
132 // Determine if the current producing stage is fused with other
133 // stage (i.e. the consumer stage) at dimension 'var'.
is_fused_with_others(const vector<vector<Function>> & fused_groups,const vector<set<FusedPair>> & fused_pairs_in_groups,const Function & producing_func,int producing_stage_index,const string & consumer_name,int consumer_stage,string var)134 bool is_fused_with_others(const vector<vector<Function>> &fused_groups,
135 const vector<set<FusedPair>> &fused_pairs_in_groups,
136 const Function &producing_func, int producing_stage_index,
137 const string &consumer_name, int consumer_stage,
138 string var) {
139 if (producing_func.has_extern_definition()) {
140 return false;
141 }
142
143 // Find the fused group this producing stage belongs to.
144 size_t index = find_fused_group_index(producing_func, fused_groups);
145
146 const vector<Dim> &dims = (producing_stage_index == 0) ? producing_func.definition().schedule().dims() : producing_func.update(producing_stage_index - 1).schedule().dims();
147
148 size_t var_index;
149 {
150 const auto &iter = std::find_if(dims.begin(), dims.end(),
151 [&var](const Dim &d) { return var_name_match(d.var, var); });
152 if (iter == dims.end()) {
153 return false;
154 }
155 var_index = iter - dims.begin();
156 }
157
158 // Iterate over the fused pair list to check if the producer stage
159 // is fused with the consumer stage at 'var'
160 for (const auto &pair : fused_pairs_in_groups[index]) {
161 if (((pair.func_1 == consumer_name) && ((int)pair.stage_1 == consumer_stage)) ||
162 ((pair.func_2 == consumer_name) && ((int)pair.stage_2 == consumer_stage))) {
163 const auto &iter = std::find_if(dims.begin(), dims.end(),
164 [&pair](const Dim &d) { return var_name_match(d.var, pair.var_name); });
165 if (iter == dims.end()) {
166 continue;
167 }
168 size_t idx = iter - dims.begin();
169 if (var_index >= idx) {
170 return true;
171 }
172 }
173 }
174 return false;
175 }
176 } // namespace
177
178 class BoundsInference : public IRMutator {
179 public:
180 const vector<Function> &funcs;
181 // Each element in the list indicates a group of functions which loops
182 // are fused together.
183 const vector<vector<Function>> &fused_groups;
184 // Contain list of all pairwise fused function stages for each fused group.
185 // The fused group is indexed in the same way as 'fused_groups'.
186 const vector<set<FusedPair>> &fused_pairs_in_groups;
187 const FuncValueBounds &func_bounds;
188 set<string> in_pipeline, inner_productions, has_extern_consumer;
189 const Target target;
190
// A value guarded by the condition under which it applies. Used so that
// bounds inference can weight boxes by the specialization/predicate
// conditions under which each expression is actually evaluated.
struct CondValue {
    Expr cond;  // Condition on params only (can't depend on loop variable)
    Expr value; // Expression whose bounds matter when 'cond' holds

    CondValue(const Expr &c, const Expr &v)
        : cond(c), value(v) {
    }
};
199
200 struct Stage {
201 Function func;
202 size_t stage; // 0 is the pure definition, 1 is the first update
203 string name;
204 vector<int> consumers;
205 map<pair<string, int>, Box> bounds;
206 vector<CondValue> exprs;
207 set<ReductionVariable, ReductionVariable::Compare> rvars;
208 string stage_prefix;
209 size_t fused_group_index;
210
211 // Computed expressions on the left and right-hand sides.
212 // Note that a function definition might have different LHS or reduction domain
213 // (if it's an update def) or RHS per specialization. All specializations
214 // of an init definition should have the same LHS.
215 // This also pushes all the reduction domains it encounters into the 'rvars'
216 // set for later use.
compute_exprs_helperHalide::Internal::BoundsInference::Stage217 vector<vector<CondValue>> compute_exprs_helper(const Definition &def, bool is_update) {
218 vector<vector<CondValue>> result(2); // <args, values>
219
220 if (!def.defined()) {
221 return result;
222 }
223
224 // Default case (no specialization)
225 vector<Expr> predicates = def.split_predicate();
226 for (const ReductionVariable &rv : def.schedule().rvars()) {
227 rvars.insert(rv);
228 }
229
230 vector<vector<Expr>> vecs(2);
231 if (is_update) {
232 vecs[0] = def.args();
233 }
234 vecs[1] = def.values();
235
236 for (size_t i = 0; i < result.size(); ++i) {
237 for (const Expr &val : vecs[i]) {
238 if (!predicates.empty()) {
239 Expr cond_val = Call::make(val.type(),
240 Internal::Call::if_then_else,
241 {likely(predicates[0]), val, make_zero(val.type())},
242 Internal::Call::PureIntrinsic);
243 for (size_t i = 1; i < predicates.size(); ++i) {
244 cond_val = Call::make(cond_val.type(),
245 Internal::Call::if_then_else,
246 {likely(predicates[i]), cond_val, make_zero(cond_val.type())},
247 Internal::Call::PureIntrinsic);
248 }
249 result[i].push_back(CondValue(const_true(), cond_val));
250 } else {
251 result[i].push_back(CondValue(const_true(), val));
252 }
253 }
254 }
255
256 const vector<Specialization> &specializations = def.specializations();
257 for (size_t i = specializations.size(); i > 0; i--) {
258 Expr s_cond = specializations[i - 1].condition;
259 const Definition &s_def = specializations[i - 1].definition;
260
261 // Else case (i.e. specialization condition is false)
262 for (auto &vec : result) {
263 for (CondValue &cval : vec) {
264 cval.cond = simplify(!s_cond && cval.cond);
265 }
266 }
267
268 // Then case (i.e. specialization condition is true)
269 vector<vector<CondValue>> s_result = compute_exprs_helper(s_def, is_update);
270 for (auto &vec : s_result) {
271 for (CondValue &cval : vec) {
272 cval.cond = simplify(s_cond && cval.cond);
273 }
274 }
275 for (size_t i = 0; i < result.size(); i++) {
276 result[i].insert(result[i].end(), s_result[i].begin(), s_result[i].end());
277 }
278 }
279
280 // Optimization: If the args/values across specializations including
281 // the default case, are the same, we can combine those args/values
282 // into one arg/value with a const_true() condition for the purpose
283 // of bounds inference.
284 for (auto &vec : result) {
285 if (vec.size() > 1) {
286 bool all_equal = true;
287 Expr val = vec[0].value;
288 for (size_t i = 1; i < vec.size(); ++i) {
289 if (!equal(val, vec[i].value)) {
290 all_equal = false;
291 break;
292 }
293 }
294 if (all_equal) {
295 debug(4) << "compute_exprs: all values (size: " << vec.size() << ") "
296 << "(" << val << ") are equal, combine them together\n";
297 internal_assert(val.defined());
298 vec.clear();
299 vec.emplace_back(const_true(), val);
300 }
301 }
302 }
303 return result;
304 }
305
// Computed expressions on the left and right-hand sides. This also
// pushes all reduction domains it encounters into the 'rvars' set
// for later use. Populates 'exprs' with: args of the definition (for
// update stages), the extern proxy expression if there is one, then
// the values.
void compute_exprs() {
    // We need to clear 'exprs' and 'rvars' first, in case compute_exprs()
    // is called multiple times.
    exprs.clear();
    rvars.clear();

    bool is_update = (stage != 0);
    vector<vector<CondValue>> result;
    if (!is_update) {
        result = compute_exprs_helper(func.definition(), is_update);
    } else {
        const Definition &def = func.update(stage - 1);
        result = compute_exprs_helper(def, is_update);
    }
    internal_assert(result.size() == 2);  // <args, values>
    exprs = result[0];

    if (func.extern_definition_proxy_expr().defined()) {
        // Extern stages with a proxy expression contribute it to bounds
        // inference as if it were one of the computed values.
        exprs.emplace_back(const_true(), func.extern_definition_proxy_expr());
    }

    exprs.insert(exprs.end(), result[1].begin(), result[1].end());

    // For the purposes of computation bounds inference, we
    // don't care what sites are loaded, just what sites need
    // to have the correct value in them. So remap all selects
    // to if_then_elses to get tighter bounds.
    class SelectToIfThenElse : public IRMutator {
        using IRMutator::visit;
        Expr visit(const Select *op) override {
            // Only pure conditions may be turned into if_then_else.
            if (is_pure(op->condition)) {
                return Call::make(op->type, Call::if_then_else,
                                  {mutate(op->condition),
                                   mutate(op->true_value),
                                   mutate(op->false_value)},
                                  Call::PureIntrinsic);
            } else {
                return IRMutator::visit(op);
            }
        }
    } select_to_if_then_else;

    for (auto &e : exprs) {
        e.value = select_to_if_then_else.mutate(e.value);
    }
}
355
356 // Check if the dimension at index 'dim_idx' is always pure (i.e. equal to 'dim')
357 // in the definition (including in its specializations)
is_dim_always_pureHalide::Internal::BoundsInference::Stage358 bool is_dim_always_pure(const Definition &def, const string &dim, int dim_idx) {
359 const Variable *var = def.args()[dim_idx].as<Variable>();
360 if ((!var) || (var->name != dim)) {
361 return false;
362 }
363
364 for (const Specialization &s : def.specializations()) {
365 bool pure = is_dim_always_pure(s.definition, dim, dim_idx);
366 if (!pure) {
367 return false;
368 }
369 }
370 return true;
371 }
372
// Wrap a statement in let stmts defining the box of this stage that is
// required while producing the given stage at the given loop level:
// "<name>.s<stage>.<arg>.min"/".max" symbols for each dimension, plus
// explicit schedule bounds, extern bounds queries, and (for update
// stages) the reduction variable bounds.
Stmt define_bounds(Stmt s,
                   const Function &producing_func,
                   const string &producing_stage_index,  // qualified stage name, e.g. "f.s0"
                   int producing_stage_index_index,      // numeric stage index of the producer
                   const string &loop_level,
                   const vector<vector<Function>> &fused_groups,
                   const vector<set<FusedPair>> &fused_pairs_in_groups,
                   const set<string> &in_pipeline,
                   const set<string> &inner_productions,
                   const set<string> &has_extern_consumer,
                   const Target &target) {

    // Merge all the relevant boxes.
    Box b;

    const vector<string> func_args = func.args();

    // The innermost loop var name: the text after the last '.' of the
    // loop level string.
    size_t last_dot = loop_level.rfind('.');
    string var = loop_level.substr(last_dot + 1);

    // 'bounds' is keyed by (func name, stage index).
    // NOTE(review): presumably these are the consumers of this stage —
    // confirm where 'bounds' is populated. Only merge boxes from entries
    // that are relevant here: the stage being produced, productions
    // nested inside this loop, or stages fused with the producer at
    // this loop var.
    for (const pair<const pair<string, int>, Box> &i : bounds) {
        string func_name = i.first.first;
        int func_stage_index = i.first.second;
        string stage_name = func_name + ".s" + std::to_string(func_stage_index);
        if (stage_name == producing_stage_index ||
            inner_productions.count(func_name) ||
            is_fused_with_others(fused_groups, fused_pairs_in_groups,
                                 producing_func, producing_stage_index_index,
                                 func_name, func_stage_index, var)) {
            merge_boxes(b, i.second);
        }
    }

    // A merged box has one Interval per func dimension (or none at all).
    internal_assert(b.empty() || b.size() == func_args.size());

    if (!b.empty()) {
        // Optimization: If a dimension is pure in every update
        // step of a func, then there exists a single bound for
        // that dimension, instead of one bound per stage. Let's
        // figure out what those dimensions are, and just have all
        // stages but the last use the bounds for the last stage.
        vector<bool> always_pure_dims(func_args.size(), true);
        for (const Definition &def : func.updates()) {
            for (size_t j = 0; j < always_pure_dims.size(); j++) {
                bool pure = is_dim_always_pure(def, func_args[j], j);
                if (!pure) {
                    always_pure_dims[j] = false;
                }
            }
        }

        if (stage < func.updates().size()) {
            // Not the last stage: defer pure dims to the last stage's symbols.
            size_t stages = func.updates().size();
            string last_stage = func.name() + ".s" + std::to_string(stages) + ".";
            for (size_t i = 0; i < always_pure_dims.size(); i++) {
                if (always_pure_dims[i]) {
                    const string &dim = func_args[i];
                    Expr min = Variable::make(Int(32), last_stage + dim + ".min");
                    Expr max = Variable::make(Int(32), last_stage + dim + ".max");
                    b[i] = Interval(min, max);
                }
            }
        }
    }

    if (func.has_extern_definition() &&
        !func.extern_definition_proxy_expr().defined()) {
        // After we define our bounds required, we need to
        // figure out what we're actually going to compute,
        // and what inputs we need. To do this we:

        // 1) Grab a handle on the bounds query results from one level up

        // 2) Run the bounds query to let it round up the output size.

        // 3) Shift the requested output box back inside of the
        // bounds query result from one loop level up (in case
        // it was rounded up)

        // 4) then run the bounds query again to get the input
        // sizes.

        // Because we're wrapping a stmt, this happens in reverse order.

        // 4)
        s = do_bounds_query(s, in_pipeline, target);

        if (!in_pipeline.empty()) {
            // 3)
            string outer_query_name = func.name() + ".outer_bounds_query";
            Expr outer_query = Variable::make(type_of<struct halide_buffer_t *>(), outer_query_name);
            string inner_query_name = func.name() + ".o0.bounds_query";
            Expr inner_query = Variable::make(type_of<struct halide_buffer_t *>(), inner_query_name);
            for (int i = 0; i < func.dimensions(); i++) {
                Expr outer_min = Call::make(Int(32), Call::buffer_get_min,
                                            {outer_query, i}, Call::Extern);
                Expr outer_max = Call::make(Int(32), Call::buffer_get_max,
                                            {outer_query, i}, Call::Extern);

                Expr inner_min = Call::make(Int(32), Call::buffer_get_min,
                                            {inner_query, i}, Call::Extern);
                Expr inner_max = Call::make(Int(32), Call::buffer_get_max,
                                            {inner_query, i}, Call::Extern);

                // Push 'inner' inside of 'outer'
                Expr shift = Min::make(0, outer_max - inner_max);
                Expr new_min = inner_min + shift;
                Expr new_max = inner_max + shift;

                // Modify the region to be computed accordingly
                s = LetStmt::make(func.name() + ".s0." + func_args[i] + ".max", new_max, s);
                s = LetStmt::make(func.name() + ".s0." + func_args[i] + ".min", new_min, s);
            }

            // 2)
            s = do_bounds_query(s, in_pipeline, target);

            // 1)
            s = LetStmt::make(func.name() + ".outer_bounds_query",
                              Variable::make(type_of<struct halide_buffer_t *>(), func.name() + ".o0.bounds_query"), s);
        } else {
            // If we're at the outermost loop, there is no
            // bounds query result from one level up, but we
            // still need to modify the region to be computed
            // based on the bounds query result and then do
            // another bounds query to ask for the required
            // input size given that.

            // 2)
            string inner_query_name = func.name() + ".o0.bounds_query";
            Expr inner_query = Variable::make(type_of<struct halide_buffer_t *>(), inner_query_name);
            for (int i = 0; i < func.dimensions(); i++) {
                Expr new_min = Call::make(Int(32), Call::buffer_get_min,
                                          {inner_query, i}, Call::Extern);
                Expr new_max = Call::make(Int(32), Call::buffer_get_max,
                                          {inner_query, i}, Call::Extern);

                s = LetStmt::make(func.name() + ".s0." + func_args[i] + ".max", new_max, s);
                s = LetStmt::make(func.name() + ".s0." + func_args[i] + ".min", new_min, s);
            }

            s = do_bounds_query(s, in_pipeline, target);
        }
    }

    if (in_pipeline.count(name) == 0) {
        // Inject any explicit bounds
        string prefix = name + ".s" + std::to_string(stage) + ".";

        LoopLevel compute_at = func.schedule().compute_level();
        LoopLevel store_at = func.schedule().store_level();

        for (size_t i = 0; i < func.schedule().bounds().size(); i++) {
            Bound bound = func.schedule().bounds()[i];
            string min_var = prefix + bound.var + ".min";
            string max_var = prefix + bound.var + ".max";
            Expr min_required = Variable::make(Int(32), min_var);
            Expr max_required = Variable::make(Int(32), max_var);

            if (bound.extent.defined()) {
                // If the Func is compute_at some inner loop, and
                // only extent is bounded, then the min could
                // actually move around, which makes the extent
                // bound not actually useful for determining the
                // max required from the point of view of
                // producers.
                if (bound.min.defined() ||
                    compute_at.is_root() ||
                    (compute_at.match(loop_level) &&
                     store_at.match(loop_level))) {
                    if (!bound.min.defined()) {
                        bound.min = min_required;
                    }
                    s = LetStmt::make(min_var, bound.min, s);
                    s = LetStmt::make(max_var, bound.min + bound.extent - 1, s);
                }

                // Save the unbounded values to use in bounds-checking assertions
                s = LetStmt::make(min_var + "_unbounded", min_required, s);
                s = LetStmt::make(max_var + "_unbounded", max_required, s);
            }

            if (bound.modulus.defined()) {
                // Round min down and max up so that the interval is
                // aligned to the requested modulus/remainder.
                min_required -= bound.remainder;
                min_required = (min_required / bound.modulus) * bound.modulus;
                min_required += bound.remainder;
                Expr max_plus_one = max_required + 1;
                max_plus_one -= bound.remainder;
                max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
                max_plus_one += bound.remainder;
                max_required = max_plus_one - 1;
                s = LetStmt::make(min_var, min_required, s);
                s = LetStmt::make(max_var, max_required, s);
            }
        }
    }

    // Define the .min/.max symbols for each dimension from the merged box.
    for (size_t d = 0; d < b.size(); d++) {
        string arg = name + ".s" + std::to_string(stage) + "." + func_args[d];

        const bool clamp_to_outer_bounds =
            !in_pipeline.empty() && has_extern_consumer.count(name);
        if (clamp_to_outer_bounds) {
            // Allocation bounds inference is going to have a
            // bad time lifting the results of the bounds
            // queries outwards. Help it out by insisting that
            // the bounds are clamped to lie within the bounds
            // one loop level up.
            Expr outer_min = Variable::make(Int(32), arg + ".outer_min");
            Expr outer_max = Variable::make(Int(32), arg + ".outer_max");
            b[d].min = clamp(b[d].min, outer_min, outer_max);
            b[d].max = clamp(b[d].max, outer_min, outer_max);
        }

        if (b[d].is_single_point()) {
            // Alias .min to .max so downstream code can see they agree.
            s = LetStmt::make(arg + ".min", Variable::make(Int(32), arg + ".max"), s);
        } else {
            s = LetStmt::make(arg + ".min", b[d].min, s);
        }
        s = LetStmt::make(arg + ".max", b[d].max, s);

        if (clamp_to_outer_bounds) {
            s = LetStmt::make(arg + ".outer_min", Variable::make(Int(32), arg + ".min"), s);
            s = LetStmt::make(arg + ".outer_max", Variable::make(Int(32), arg + ".max"), s);
        }
    }

    // Update stages also need bounds for their reduction variables,
    // which come directly from the reduction domain.
    if (stage > 0) {
        for (const ReductionVariable &rvar : rvars) {
            string arg = name + ".s" + std::to_string(stage) + "." + rvar.var;
            s = LetStmt::make(arg + ".min", rvar.min, s);
            s = LetStmt::make(arg + ".max", rvar.extent + rvar.min - 1, s);
        }
    }

    return s;
}
611
// Wrap 's' in a bounds-query invocation of this Func's extern
// definition: build query halide_buffer_ts for each producer input and
// each output, call the extern function, and assert that it returned 0.
// NOTE(review): 'in_pipeline' is not referenced in this body — confirm
// whether it is still needed in the signature.
Stmt do_bounds_query(Stmt s, const set<string> &in_pipeline, const Target &target) {

    const string &extern_name = func.extern_function_name();
    const vector<ExternFuncArgument> &args = func.extern_arguments();

    vector<Expr> bounds_inference_args;

    // (name, value) pairs that become LetStmts wrapped around the result,
    // defining the query buffers.
    vector<pair<string, Expr>> lets;

    // Iterate through all of the input args to the extern
    // function building a suitable argument list for the
    // extern function call. We need a query halide_buffer_t per
    // producer and a query halide_buffer_t for the output

    Expr null_handle = make_zero(Handle());

    // Buffers that must be marked initialized under MSAN, paired with
    // their dimensionality.
    vector<pair<Expr, int>> buffers_to_annotate;
    for (size_t j = 0; j < args.size(); j++) {
        if (args[j].is_expr()) {
            // Scalar args are passed through unchanged.
            bounds_inference_args.push_back(args[j].expr);
        } else if (args[j].is_func()) {
            // A Func input: one fresh query buffer per output of that Func.
            Function input(args[j].func);
            for (int k = 0; k < input.outputs(); k++) {
                string name = input.name() + ".o" + std::to_string(k) + ".bounds_query." + func.name();

                BufferBuilder builder;
                builder.type = input.output_types()[k];
                builder.dimensions = input.dimensions();
                Expr buf = builder.build();

                lets.emplace_back(name, buf);
                bounds_inference_args.push_back(Variable::make(type_of<struct halide_buffer_t *>(), name));
                buffers_to_annotate.emplace_back(bounds_inference_args.back(), input.dimensions());
            }
        } else if (args[j].is_image_param() || args[j].is_buffer()) {
            Parameter p = args[j].image_param;
            Buffer<> b = args[j].buffer;
            string name = args[j].is_image_param() ? p.name() : b.name();
            int dims = args[j].is_image_param() ? p.dimensions() : b.dimensions();

            Expr in_buf = Variable::make(type_of<struct halide_buffer_t *>(), name + ".buffer");

            // Copy the input buffer into a query buffer to mutate.
            string query_name = name + ".bounds_query." + func.name();

            // Stack-allocate space for the buffer struct and its shape,
            // then initialize them from the real input buffer.
            Expr alloca_size = Call::make(Int(32), Call::size_of_halide_buffer_t, {}, Call::Intrinsic);
            Expr query_buf = Call::make(type_of<struct halide_buffer_t *>(), Call::alloca,
                                        {alloca_size}, Call::Intrinsic);
            Expr query_shape = Call::make(type_of<struct halide_dimension_t *>(), Call::alloca,
                                          {(int)(sizeof(halide_dimension_t) * dims)}, Call::Intrinsic);
            query_buf = Call::make(type_of<struct halide_buffer_t *>(), Call::buffer_init_from_buffer,
                                   {query_buf, query_shape, in_buf}, Call::Extern);

            lets.emplace_back(query_name, query_buf);
            Expr buf = Variable::make(type_of<struct halide_buffer_t *>(), query_name, b, p, ReductionDomain());
            bounds_inference_args.push_back(buf);
            // Although we expect ImageParams to be properly initialized and sanitized by the caller,
            // we create a copy with copy_memory (not msan-aware), so we need to annotate it as initialized.
            buffers_to_annotate.emplace_back(bounds_inference_args.back(), dims);
        } else {
            internal_error << "Bad ExternFuncArgument type";
        }
    }

    // Make the buffer_ts representing the output. They all
    // use the same size, but have differing types.
    for (int j = 0; j < func.outputs(); j++) {
        BufferBuilder builder;
        builder.type = func.output_types()[j];
        builder.dimensions = func.dimensions();
        for (const string &arg : func.args()) {
            // Mins/extents come from this stage's .min/.max symbols,
            // defined by define_bounds one level out.
            string prefix = func.name() + ".s" + std::to_string(stage) + "." + arg;
            Expr min = Variable::make(Int(32), prefix + ".min");
            Expr max = Variable::make(Int(32), prefix + ".max");
            builder.mins.push_back(min);
            builder.extents.push_back(max + 1 - min);
            builder.strides.emplace_back(0);
        }
        Expr output_buffer_t = builder.build();

        string buf_name = func.name() + ".o" + std::to_string(j) + ".bounds_query";
        bounds_inference_args.push_back(Variable::make(type_of<struct halide_buffer_t *>(), buf_name));
        // Since this is a temporary, internal-only buffer used for bounds inference,
        // we need to mark it
        buffers_to_annotate.emplace_back(bounds_inference_args.back(), func.dimensions());
        lets.emplace_back(buf_name, output_buffer_t);
    }

    Stmt annotate;
    if (target.has_feature(Target::MSAN)) {
        // Mark the buffers as initialized before calling out.
        for (const auto &p : buffers_to_annotate) {
            Expr buffer = p.first;
            int dimensions = p.second;
            // Return type is really 'void', but no way to represent that in our IR.
            // Precedent (from halide_print, etc) is to use Int(32) and ignore the result.
            Expr sizeof_buffer_t = cast<uint64_t>(
                Call::make(Int(32), Call::size_of_halide_buffer_t, {}, Call::Intrinsic));
            Stmt mark_buffer =
                Evaluate::make(Call::make(Int(32), "halide_msan_annotate_memory_is_initialized",
                                          {buffer, sizeof_buffer_t}, Call::Extern));
            // The shape array is allocated separately; annotate it too.
            Expr shape = Call::make(type_of<halide_dimension_t *>(), Call::buffer_get_shape, {buffer},
                                    Call::Extern);
            Expr shape_size = Expr((uint64_t)(sizeof(halide_dimension_t) * dimensions));
            Stmt mark_shape =
                Evaluate::make(Call::make(Int(32), "halide_msan_annotate_memory_is_initialized",
                                          {shape, shape_size}, Call::Extern));

            mark_buffer = Block::make(mark_buffer, mark_shape);
            if (annotate.defined()) {
                annotate = Block::make(annotate, mark_buffer);
            } else {
                annotate = mark_buffer;
            }
        }
    }

    // Make the extern call
    Expr e = func.make_call_to_extern_definition(bounds_inference_args, target);

    // Check if it succeeded
    string result_name = unique_name('t');
    Expr result = Variable::make(Int(32), result_name);
    Expr error = Call::make(Int(32), "halide_error_bounds_inference_call_failed",
                            {extern_name, result}, Call::Extern);
    Stmt check = AssertStmt::make(EQ::make(result, 0), error);

    check = LetStmt::make(result_name, e, check);

    if (annotate.defined()) {
        check = Block::make(annotate, check);
    }

    // Now inner code is free to extract the fields from the halide_buffer_t
    s = Block::make(check, s);

    // Wrap in let stmts defining the args
    for (size_t i = 0; i < lets.size(); i++) {
        s = LetStmt::make(lets[i].first, lets[i].second, s);
    }

    return s;
}
755
756 // A scope giving the bounds for variables used by this stage.
757 // We need to take into account specializations which may refer to
758 // different reduction variables as well.
populate_scopeHalide::Internal::BoundsInference::Stage759 void populate_scope(Scope<Interval> &result) {
760 for (const string &farg : func.args()) {
761 string arg = name + ".s" + std::to_string(stage) + "." + farg;
762 result.push(farg,
763 Interval(Variable::make(Int(32), arg + ".min"),
764 Variable::make(Int(32), arg + ".max")));
765 }
766 if (stage > 0) {
767 for (const ReductionVariable &rv : rvars) {
768 string arg = name + ".s" + std::to_string(stage) + "." + rv.var;
769 result.push(rv.var, Interval(Variable::make(Int(32), arg + ".min"),
770 Variable::make(Int(32), arg + ".max")));
771 }
772 }
773
774 /*for (size_t i = 0; i < func.definition().schedule().bounds().size(); i++) {
775 const Bound &b = func.definition().schedule().bounds()[i];
776 result.push(b.var, Interval(b.min, (b.min + b.extent) - 1));
777 }*/
778 }
779 };
780 vector<Stage> stages;
781
BoundsInference(const vector<Function> & f,const vector<vector<Function>> & fg,const vector<set<FusedPair>> & fp,const vector<Function> & outputs,const FuncValueBounds & fb,const Target & target)782 BoundsInference(const vector<Function> &f,
783 const vector<vector<Function>> &fg,
784 const vector<set<FusedPair>> &fp,
785 const vector<Function> &outputs,
786 const FuncValueBounds &fb,
787 const Target &target)
788 : funcs(f), fused_groups(fg), fused_pairs_in_groups(fp), func_bounds(fb), target(target) {
789 internal_assert(!f.empty());
790
791 // Compute the intrinsic relationships between the stages of
792 // the functions.
793
794 // Figure out which functions will be inlined away
795 vector<bool> inlined(f.size());
796 for (size_t i = 0; i < inlined.size(); i++) {
797 if (i < f.size() - 1 &&
798 f[i].schedule().compute_level().is_inlined() &&
799 f[i].can_be_inlined()) {
800 inlined[i] = true;
801 } else {
802 inlined[i] = false;
803 }
804 }
805
806 // First lay out all the stages in their realization order.
807 // The functions are already in topologically sorted order, so
808 // this is straight-forward.
809 for (size_t i = 0; i < f.size(); i++) {
810
811 if (inlined[i]) continue;
812
813 Stage s;
814 s.func = f[i];
815 s.stage = 0;
816 s.name = s.func.name();
817 s.fused_group_index = find_fused_group_index(s.func, fused_groups);
818 s.compute_exprs();
819 s.stage_prefix = s.name + ".s0.";
820 stages.push_back(s);
821
822 for (size_t j = 0; j < f[i].updates().size(); j++) {
823 s.stage = (int)(j + 1);
824 s.stage_prefix = s.name + ".s" + std::to_string(s.stage) + ".";
825 s.compute_exprs();
826 stages.push_back(s);
827 }
828 }
829
830 // Do any pure inlining (TODO: This is currently slow)
831 for (size_t i = f.size(); i > 0; i--) {
832 Function func = f[i - 1];
833 if (inlined[i - 1]) {
834 for (size_t j = 0; j < stages.size(); j++) {
835 Stage &s = stages[j];
836 for (size_t k = 0; k < s.exprs.size(); k++) {
837 CondValue &cond_val = s.exprs[k];
838 internal_assert(cond_val.value.defined());
839 cond_val.value = inline_function(cond_val.value, func);
840 }
841 }
842 }
843 }
844
845 // Remove the inlined stages
846 vector<Stage> new_stages;
847 for (size_t i = 0; i < stages.size(); i++) {
848 if (!stages[i].func.schedule().compute_level().is_inlined() ||
849 !stages[i].func.can_be_inlined()) {
850 new_stages.push_back(stages[i]);
851 }
852 }
853 new_stages.swap(stages);
854
855 // Dump the stages post-inlining for debugging
856 /*
857 debug(0) << "Bounds inference stages after inlining: \n";
858 for (size_t i = 0; i < stages.size(); i++) {
859 debug(0) << " " << i << ") " << stages[i].name << "\n";
860 }
861 */
862
863 // Then compute relationships between them.
864 for (size_t i = 0; i < stages.size(); i++) {
865
866 Stage &consumer = stages[i];
867
868 // Set up symbols representing the bounds over which this
869 // stage will be computed.
870 Scope<Interval> scope;
871 consumer.populate_scope(scope);
872
873 // Compute all the boxes of the producers this consumer
874 // uses.
875 map<string, Box> boxes;
876 if (consumer.func.has_extern_definition() &&
877 !consumer.func.extern_definition_proxy_expr().defined()) {
878
879 const vector<ExternFuncArgument> &args = consumer.func.extern_arguments();
880 // Stage::define_bounds is going to compute a query
881 // halide_buffer_t per producer for bounds inference to
882 // use. We just need to extract those values.
883 for (size_t j = 0; j < args.size(); j++) {
884 if (args[j].is_func()) {
885 Function f(args[j].func);
886 has_extern_consumer.insert(f.name());
887 string stage_name = f.name() + ".s" + std::to_string(f.updates().size());
888 Box b(f.dimensions());
889 for (int d = 0; d < f.dimensions(); d++) {
890 string buf_name = f.name() + ".o0.bounds_query." + consumer.name;
891 Expr buf = Variable::make(type_of<struct halide_buffer_t *>(), buf_name);
892 Expr min = Call::make(Int(32), Call::buffer_get_min,
893 {buf, d}, Call::Extern);
894 Expr max = Call::make(Int(32), Call::buffer_get_max,
895 {buf, d}, Call::Extern);
896 b[d] = Interval(min, max);
897 }
898 merge_boxes(boxes[f.name()], b);
899 }
900 }
901 } else {
902 for (const auto &cval : consumer.exprs) {
903 map<string, Box> new_boxes;
904 new_boxes = boxes_required(cval.value, scope, func_bounds);
905 for (auto &i : new_boxes) {
906 // Add the condition on which this value is evaluated to the box before merging
907 Box &box = i.second;
908 box.used = cval.cond;
909 merge_boxes(boxes[i.first], box);
910 }
911 }
912 }
913
914 // Expand the bounds required of all the producers found
915 // (and we are checking until i, because stages are topologically sorted).
916 for (size_t j = 0; j < i; j++) {
917 Stage &producer = stages[j];
918 // A consumer depends on *all* stages of a producer, not just the last one.
919 const Box &b = boxes[producer.func.name()];
920
921 if (!b.empty()) {
922 // Check for unboundedness
923 for (size_t k = 0; k < b.size(); k++) {
924 if (!b[k].is_bounded()) {
925 std::ostringstream err;
926 if (consumer.stage == 0) {
927 err << "The pure definition ";
928 } else {
929 err << "Update definition number " << (consumer.stage - 1);
930 }
931 err << " of Function " << consumer.name
932 << " calls function " << producer.name
933 << " in an unbounded way in dimension " << k << "\n";
934 user_error << err.str();
935 }
936 }
937
938 // Dump out the region required of each stage for debugging.
939
940 /*
941 debug(0) << "Box required of " << producer.name
942 << " by " << consumer.name
943 << " stage " << consumer.stage << ":\n";
944 for (size_t k = 0; k < b.size(); k++) {
945 debug(0) << " " << b[k].min << " ... " << b[k].max << "\n";
946 }
947 debug(0) << "\n";
948 */
949
950 producer.bounds[{consumer.name, consumer.stage}] = b;
951 producer.consumers.push_back((int)i);
952 }
953 }
954 }
955
956 // The region required of the each output is expanded to include the size of the output buffer.
957 for (Function output : outputs) {
958 Box output_box;
959 string buffer_name = output.name();
960 if (output.outputs() > 1) {
961 // Use the output size of the first output buffer
962 buffer_name += ".0";
963 }
964 for (int d = 0; d < output.dimensions(); d++) {
965 Parameter buf = output.output_buffers()[0];
966 Expr min = Variable::make(Int(32), buffer_name + ".min." + std::to_string(d), buf);
967 Expr extent = Variable::make(Int(32), buffer_name + ".extent." + std::to_string(d), buf);
968
969 // Respect any output min and extent constraints
970 Expr min_constraint = buf.min_constraint(d);
971 Expr extent_constraint = buf.extent_constraint(d);
972
973 if (min_constraint.defined()) {
974 min = min_constraint;
975 }
976 if (extent_constraint.defined()) {
977 extent = extent_constraint;
978 }
979
980 output_box.push_back(Interval(min, (min + extent) - 1));
981 }
982 for (size_t i = 0; i < stages.size(); i++) {
983 Stage &s = stages[i];
984 if (!s.func.same_as(output)) continue;
985 s.bounds[{s.name, s.stage}] = output_box;
986 }
987 }
988 }
989
990 using IRMutator::visit;
991
    // Inject the results of bounds inference into the body of this for
    // loop: .min/.max let bindings for every producer stage computed
    // inside it, plus bounds for the current stage's reduction
    // variables. Loops marked ForType::Extern are placeholders and are
    // returned untouched.
    Stmt visit(const For *op) override {
        // Don't recurse inside loops marked 'Extern', they will be
        // removed later.
        if (op->for_type == ForType::Extern) {
            return op;
        }

        // Track productions discovered inside *this* loop separately
        // from those of the enclosing scope; the outer set is merged
        // back in near the end of this method.
        set<string> old_inner_productions;
        inner_productions.swap(old_inner_productions);

        Stmt body = op->body;

        // Walk inside of any let/if statements that don't depend on
        // bounds inference results so that we don't needlessly
        // complicate our bounds expressions.
        // 'wrappers' records each peeled wrapper so it can be
        // reinstated after mutation: (name, value) for a let, or an
        // empty name paired with the condition for an if.
        vector<pair<string, Expr>> wrappers;
        while (1) {
            if (const LetStmt *let = body.as<LetStmt>()) {
                if (depends_on_bounds_inference(let->value)) {
                    break;
                }

                body = let->body;
                wrappers.emplace_back(let->name, let->value);
            } else if (const IfThenElse *if_then_else = body.as<IfThenElse>()) {
                if (depends_on_bounds_inference(if_then_else->condition) ||
                    if_then_else->else_case.defined()) {
                    break;
                }

                body = if_then_else->then_case;
                wrappers.emplace_back(std::string(), if_then_else->condition);
            } else {
                break;
            }
        }

        // If there are no pipelines at this loop level, we can skip
        // most of the work. Consider 'extern' for loops as pipelines
        // (we aren't recursing into these loops above).
        bool no_pipelines =
            body.as<For>() != nullptr &&
            body.as<For>()->for_type != ForType::Extern;

        // Figure out which stage of which function we're producing
        // (by matching this loop's name against each stage's prefix).
        int producing = -1;
        Function f;
        int stage_index = -1;
        string stage_name;
        for (size_t i = 0; i < stages.size(); i++) {
            if (starts_with(op->name, stages[i].stage_prefix)) {
                producing = i;
                f = stages[i].func;
                stage_index = (int)stages[i].stage;
                stage_name = stages[i].name + ".s" + std::to_string(stages[i].stage);
                break;
            }
        }

        // Figure out how much of it we're producing

        // Note: the case when functions are fused is a little bit tricky, so may need extra care:
        // when we're producing some of a Func A, at every loop belonging to A
        // you potentially need to define symbols for what box is being computed
        // of A (A.x.min, A.x.max ...), because that any other producer Func P nested
        // there is going to define its loop bounds in terms of these symbols, to ensure
        // it computes enough of itself to satisfy the consumer.
        // Now say we compute B with A, and say B consumes P, not A. Bounds inference
        // will see the shared loop, and think it belongs to A only. It will define A.x.min and
        // friends, but that's not very useful, because P's loops are in terms of B.x.min, B.x.max, etc.
        // So without a local definition of those symbols, P will use the one in the outer scope, and
        // compute way too much of itself. It'll still be correct, but it's massive over-compute.
        // The fix is to realize that in this loop belonging to A, we also potentially need to define
        // a box for B, because B belongs to the same fused group as A, so really this loop belongs to A and B.
        // We'll get the box using boxes_provided and only filtering for A and B after the fact
        // Note that even though the loops are fused, the boxes touched of A and B might be totally different,
        // because e.g. B could be double-resolution (as happens when fusing yuv computations), so this
        // is not just a matter of giving A's box B's name as an alias.
        map<string, Box> boxes_for_fused_group;
        map<string, Function> stage_name_to_func;
        if (!no_pipelines && producing >= 0 && !f.has_extern_definition()) {
            Scope<Interval> empty_scope;
            // The loop variable is the final dot-separated component of
            // the loop name.
            size_t last_dot = op->name.rfind('.');
            string var = op->name.substr(last_dot + 1);

            // Collect the (func, stage) pairs fused with f at this loop
            // level, excluding f's own stage.
            set<pair<string, int>> fused_with_f;
            for (const auto &pair : fused_pairs_in_groups[stages[producing].fused_group_index]) {
                if (!((pair.func_1 == stages[producing].name) && ((int)pair.stage_1 == stage_index)) && is_fused_with_others(fused_groups, fused_pairs_in_groups,
                                                                                                                            f, stage_index,
                                                                                                                            pair.func_1, pair.stage_1, var)) {
                    fused_with_f.insert(make_pair(pair.func_1, pair.stage_1));
                }
                if (!((pair.func_2 == stages[producing].name) && ((int)pair.stage_2 == stage_index)) && is_fused_with_others(fused_groups, fused_pairs_in_groups,
                                                                                                                            f, stage_index,
                                                                                                                            pair.func_2, pair.stage_2, var)) {
                    fused_with_f.insert(make_pair(pair.func_2, pair.stage_2));
                }
            }

            if (fused_with_f.empty()) {
                // Common case: only f's own box is needed.
                boxes_for_fused_group[stage_name] = box_provided(body, stages[producing].name, empty_scope, func_bounds);
                stage_name_to_func[stage_name] = f;
                internal_assert((int)boxes_for_fused_group[stage_name].size() == f.dimensions());
            } else {
                // Fused case: compute all provided boxes once, then pick
                // out f's box and the box of each fused stage.
                auto boxes = boxes_provided(body, empty_scope, func_bounds);
                boxes_for_fused_group[stage_name] = boxes[stages[producing].name];
                stage_name_to_func[stage_name] = f;
                internal_assert((int)boxes_for_fused_group[stage_name].size() == f.dimensions());
                for (const auto &fused : fused_with_f) {
                    string fused_stage_name = fused.first + ".s" + std::to_string(fused.second);
                    boxes_for_fused_group[fused_stage_name] = boxes[fused.first];
                    for (const auto &fn : funcs) {
                        if (fn.name() == fused.first) {
                            stage_name_to_func[fused_stage_name] = fn;
                            break;
                        }
                    }
                }
            }
        }

        // Recurse.
        body = mutate(body);

        if (!no_pipelines) {
            // We only care about the bounds of a func if:
            // A) We're not already in a pipeline over that func AND
            // B.1) There's a production of this func somewhere inside this loop OR
            // B.2) We're downstream (a consumer) of a func for which we care about the bounds.
            vector<bool> bounds_needed(stages.size(), false);
            for (size_t i = 0; i < stages.size(); i++) {
                if (inner_productions.count(stages[i].name)) {
                    bounds_needed[i] = true;
                }

                if (in_pipeline.count(stages[i].name)) {
                    bounds_needed[i] = false;
                }

                if (bounds_needed[i]) {
                    // Propagate need to consumers (stage indices stored
                    // in stages[i].consumers), then wrap 'body' in this
                    // stage's bounds definitions.
                    for (size_t j = 0; j < stages[i].consumers.size(); j++) {
                        bounds_needed[stages[i].consumers[j]] = true;
                    }
                    body = stages[i].define_bounds(
                        body, f, stage_name, stage_index, op->name, fused_groups,
                        fused_pairs_in_groups, in_pipeline, inner_productions,
                        has_extern_consumer, target);
                }
            }

            // Finally, define the production bounds for the thing
            // we're producing.
            if (producing >= 0 && !inner_productions.empty()) {
                for (const auto &b : boxes_for_fused_group) {
                    const vector<string> &f_args = stage_name_to_func[b.first].args();
                    const auto &box = b.second;
                    internal_assert(f_args.size() == box.size());
                    for (size_t i = 0; i < box.size(); i++) {
                        internal_assert(box[i].is_bounded());
                        string var = b.first + "." + f_args[i];

                        if (box[i].is_single_point()) {
                            // Point the max at the min variable instead
                            // of duplicating the expression.
                            body = LetStmt::make(var + ".max", Variable::make(Int(32), var + ".min"), body);
                        } else {
                            body = LetStmt::make(var + ".max", box[i].max, body);
                        }

                        // .min is bound outside .max, so .max may refer
                        // to it.
                        body = LetStmt::make(var + ".min", box[i].min, body);
                    }
                }
            }

            // And the current bounds on its reduction variables, and
            // variables from extern for loops.
            if (producing >= 0) {
                const Stage &s = stages[producing];
                vector<string> vars;
                if (s.func.has_extern_definition()) {
                    vars = s.func.args();
                }
                if (stages[producing].stage > 0) {
                    for (const ReductionVariable &rv : s.rvars) {
                        vars.push_back(rv.var);
                    }
                }
                for (const string &i : vars) {
                    string var = s.stage_prefix + i;
                    Interval in = bounds_of_inner_var(var, body);
                    if (in.is_bounded()) {
                        // bounds_of_inner_var doesn't understand
                        // GuardWithIf, but we know split rvars never
                        // have inner bounds that exceed the outer
                        // ones.
                        if (!s.rvars.empty()) {
                            in.min = max(in.min, Variable::make(Int(32), var + ".min"));
                            in.max = min(in.max, Variable::make(Int(32), var + ".max"));
                        }

                        body = LetStmt::make(var + ".min", in.min, body);
                        body = LetStmt::make(var + ".max", in.max, body);
                    } else {
                        // If it's not found, we're already in the
                        // scope of the injected let. The let was
                        // probably lifted to an outer level.
                        Expr val = Variable::make(Int(32), var);
                        body = LetStmt::make(var + ".min", val, body);
                        body = LetStmt::make(var + ".max", val, body);
                    }
                }
            }
        }

        // Merge the enclosing scope's productions back in alongside
        // those found in this loop.
        inner_productions.insert(old_inner_productions.begin(),
                                 old_inner_productions.end());

        // Rewrap the let/if statements
        for (size_t i = wrappers.size(); i > 0; i--) {
            const auto &p = wrappers[i - 1];
            if (p.first.empty()) {
                body = IfThenElse::make(p.second, body);
            } else {
                body = LetStmt::make(p.first, p.second, body);
            }
        }

        return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
    }
1219
visit(const ProducerConsumer * p)1220 Stmt visit(const ProducerConsumer *p) override {
1221 in_pipeline.insert(p->name);
1222 Stmt stmt = IRMutator::visit(p);
1223 in_pipeline.erase(p->name);
1224 inner_productions.insert(p->name);
1225 return stmt;
1226 }
1227 };
1228
bounds_inference(Stmt s,const vector<Function> & outputs,const vector<string> & order,const vector<vector<string>> & fused_groups,const map<string,Function> & env,const FuncValueBounds & func_bounds,const Target & target)1229 Stmt bounds_inference(Stmt s,
1230 const vector<Function> &outputs,
1231 const vector<string> &order,
1232 const vector<vector<string>> &fused_groups,
1233 const map<string, Function> &env,
1234 const FuncValueBounds &func_bounds,
1235 const Target &target) {
1236
1237 vector<Function> funcs(order.size());
1238 for (size_t i = 0; i < order.size(); i++) {
1239 funcs[i] = env.find(order[i])->second;
1240 }
1241
1242 // Each element in 'fused_func_groups' indicates a group of functions
1243 // which loops should be fused together.
1244 vector<vector<Function>> fused_func_groups;
1245 for (const vector<string> &group : fused_groups) {
1246 vector<Function> fs;
1247 for (const string &fname : group) {
1248 fs.push_back(env.find(fname)->second);
1249 }
1250 fused_func_groups.push_back(fs);
1251 }
1252
1253 // For each fused group, collect the pairwise fused function stages.
1254 vector<set<FusedPair>> fused_pairs_in_groups;
1255 for (const vector<string> &group : fused_groups) {
1256 set<FusedPair> pairs;
1257 for (const string &fname : group) {
1258 Function f = env.find(fname)->second;
1259 if (!f.has_extern_definition()) {
1260 std::copy(f.definition().schedule().fused_pairs().begin(),
1261 f.definition().schedule().fused_pairs().end(),
1262 std::inserter(pairs, pairs.end()));
1263
1264 for (size_t i = 0; i < f.updates().size(); ++i) {
1265 std::copy(f.updates()[i].schedule().fused_pairs().begin(),
1266 f.updates()[i].schedule().fused_pairs().end(),
1267 std::inserter(pairs, pairs.end()));
1268 }
1269 }
1270 }
1271 fused_pairs_in_groups.push_back(pairs);
1272 }
1273
1274 // Add a note in the IR for where assertions on input images
1275 // should go. Those are handled by a later lowering pass.
1276 Expr marker = Call::make(Int(32), Call::add_image_checks_marker, {}, Call::Intrinsic);
1277 s = Block::make(Evaluate::make(marker), s);
1278
1279 // Add a synthetic outermost loop to act as 'root'.
1280 s = For::make("<outermost>", 0, 1, ForType::Serial, DeviceAPI::None, s);
1281
1282 s = BoundsInference(funcs, fused_func_groups, fused_pairs_in_groups,
1283 outputs, func_bounds, target)
1284 .mutate(s);
1285 return s.as<For>()->body;
1286 }
1287
1288 } // namespace Internal
1289 } // namespace Halide
1290