1 #include <algorithm>
2 #include <iostream>
3 #include <string.h>
4 #include <utility>
5 
6 #ifdef _MSC_VER
7 #include <intrin.h>
8 #endif
9 
10 #include "ApplySplit.h"
11 #include "Argument.h"
12 #include "Associativity.h"
13 #include "CodeGen_LLVM.h"
14 #include "Debug.h"
15 #include "ExprUsesVar.h"
16 #include "Func.h"
17 #include "Function.h"
18 #include "IR.h"
19 #include "IREquality.h"
20 #include "IRMutator.h"
21 #include "IROperator.h"
22 #include "IRPrinter.h"
23 #include "ImageParam.h"
24 #include "LLVM_Output.h"
25 #include "Lower.h"
26 #include "Param.h"
27 #include "PrintLoopNest.h"
28 #include "Simplify.h"
29 #include "Solve.h"
30 #include "Substitute.h"
31 #include "Util.h"
32 
33 namespace Halide {
34 
35 using std::map;
36 using std::ofstream;
37 using std::pair;
38 using std::string;
39 using std::vector;
40 
41 using namespace Internal;
42 
43 namespace {
44 
/** Format the names of the given dims as one line of the form
 * "Vars: a b c\n". Each element of 'dims' must expose a streamable
 * member named 'var'. */
template<typename DimType>
std::string dump_dim_list(const std::vector<DimType> &dims) {
    std::ostringstream listing;
    listing << "Vars:";
    for (const DimType &dim : dims) {
        listing << " " << dim.var;
    }
    listing << "\n";
    return listing.str();
}
55 
56 }  // namespace
57 
Func(const string & name)58 Func::Func(const string &name)
59     : func(unique_name(name)) {
60 }
61 
Func()62 Func::Func()
63     : func(make_entity_name(this, "Halide:.*:Func", 'f')) {
64 }
65 
Func(const Expr & e)66 Func::Func(const Expr &e)
67     : func(make_entity_name(this, "Halide:.*:Func", 'f')) {
68     (*this)(_) = e;
69 }
70 
Func(Function f)71 Func::Func(Function f)
72     : func(std::move(f)) {
73 }
74 
name() const75 const string &Func::name() const {
76     return func.name();
77 }
78 
79 /** Get the pure arguments. */
args() const80 std::vector<Var> Func::args() const {
81     const std::vector<std::string> arg_names = func.args();
82     std::vector<Var> args;
83     args.reserve(arg_names.size());
84     for (const auto &arg_name : arg_names) {
85         args.emplace_back(arg_name);
86     }
87     return args;
88 }
89 
90 /** The right-hand-side value of the pure definition of this
91  * function. An error if the Func has no definition, or is defined as
92  * a Tuple. */
value() const93 Expr Func::value() const {
94     user_assert(defined())
95         << "Can't call Func::value() on an undefined Func. To check if a Func is defined, call Func::defined()\n";
96     user_assert(func.outputs() == 1)
97         << "Can't call Func::value() on Func \"" << name() << "\", because it has multiple values.\n";
98     return func.values()[0];
99 }
100 
101 /** The values returned by a Func, in Tuple form. */
values() const102 Tuple Func::values() const {
103     user_assert(defined())
104         << "Can't call Func::values() on an undefined Func. To check if a Func is defined, call Func::defined().\n";
105     return Tuple(func.values());
106 }
107 
108 /** Get the left-hand-side of the update definition. An empty
109  * vector if there's no update definition. */
update_args(int idx) const110 const std::vector<Expr> &Func::update_args(int idx) const {
111     user_assert(has_update_definition())
112         << "Can't call Func::update_args() on Func \"" << name()
113         << "\" as it has no update definition. "
114         << "Use Func::has_update_definition() to check for the existence of an update definition.\n";
115     user_assert(idx < num_update_definitions())
116         << "Update definition index out of bounds.\n";
117     return func.update(idx).args();
118 }
119 
120 /** Get the right-hand-side of the update definition. An error if
121  * there is no update definition. */
update_value(int idx) const122 Expr Func::update_value(int idx) const {
123     user_assert(has_update_definition())
124         << "Can't call Func::update_args() on Func \"" << name() << "\" as it has no update definition. "
125         << "Use Func::has_update_definition() to check for the existence of an update definition.\n";
126     user_assert(idx < num_update_definitions())
127         << "Update definition index out of bounds.\n";
128     user_assert(func.update(idx).values().size() == 1)
129         << "Can't call Func::update_value() on Func \"" << name() << "\", because it has multiple values.\n";
130     return func.update(idx).values()[0];
131 }
132 
133 /** The update values returned by a Func, in Tuple form. */
update_values(int idx) const134 Tuple Func::update_values(int idx) const {
135     user_assert(has_update_definition())
136         << "Can't call Func::update_args() on Func \"" << name() << "\" as it has no update definition. "
137         << "Use Func::has_update_definition() to check for the existence of an update definition.\n";
138     user_assert(idx < num_update_definitions())
139         << "Update definition index out of bounds.\n";
140     return Tuple(func.update(idx).values());
141 }
142 
143 /** Get the RVars of the reduction domain for the update definition. Returns an
144  * empty vector if there's no update definition, or if the update definition has
145  * no domain. Note that the RVars returned are floating RVars, i.e. they don't
146  * actually have pointer to the reduction domain. */
rvars(int idx) const147 vector<RVar> Func::rvars(int idx) const {
148     user_assert(has_update_definition())
149         << "Can't call Func::update_args() on Func \"" << name() << "\" as it has no update definition. "
150         << "Use Func::has_update_definition() to check for the existence of an update definition.\n";
151     user_assert(idx < num_update_definitions())
152         << "Update definition index out of bounds.\n";
153     const std::vector<ReductionVariable> rvars = func.update(idx).schedule().rvars();
154     std::vector<RVar> rvs(rvars.size());
155     for (size_t i = 0; i < rvars.size(); i++) {
156         rvs[i] = RVar(rvars[i].var);
157     }
158     return rvs;
159 }
160 
defined() const161 bool Func::defined() const {
162     return func.has_pure_definition() || func.has_extern_definition();
163 }
164 
165 /** Is this function a reduction? */
has_update_definition() const166 bool Func::has_update_definition() const {
167     return func.has_update_definition();
168 }
169 
170 /** How many update definitions are there? */
num_update_definitions() const171 int Func::num_update_definitions() const {
172     return static_cast<int>(func.updates().size());
173 }
174 
175 /** Is this function external? */
is_extern() const176 bool Func::is_extern() const {
177     return func.has_extern_definition();
178 }
179 
180 /** Add an extern definition for this Func. */
define_extern(const std::string & function_name,const std::vector<ExternFuncArgument> & args,const std::vector<Type> & types,const std::vector<Var> & arguments,NameMangling mangling,DeviceAPI device_api)181 void Func::define_extern(const std::string &function_name,
182                          const std::vector<ExternFuncArgument> &args,
183                          const std::vector<Type> &types,
184                          const std::vector<Var> &arguments,
185                          NameMangling mangling, DeviceAPI device_api) {
186     func.define_extern(function_name, args, types, arguments, mangling,
187                        device_api);
188 }
189 
190 /** Get the types of the buffers returned by an extern definition. */
output_types() const191 const std::vector<Type> &Func::output_types() const {
192     return func.output_types();
193 }
194 
195 /** Get the number of outputs this function has. */
outputs() const196 int Func::outputs() const {
197     return func.outputs();
198 }
199 
200 /** Get the name of the extern function called for an extern
201  * definition. */
extern_function_name() const202 const std::string &Func::extern_function_name() const {
203     return func.extern_function_name();
204 }
205 
dimensions() const206 int Func::dimensions() const {
207     if (!defined()) return 0;
208     return func.dimensions();
209 }
210 
operator ()(vector<Var> args) const211 FuncRef Func::operator()(vector<Var> args) const {
212     int placeholder_pos, count;
213     std::tie(placeholder_pos, count) = add_implicit_vars(args);
214     return FuncRef(func, args, placeholder_pos, count);
215 }
216 
operator ()(vector<Expr> args) const217 FuncRef Func::operator()(vector<Expr> args) const {
218     int placeholder_pos, count;
219     std::tie(placeholder_pos, count) = add_implicit_vars(args);
220     return FuncRef(func, args, placeholder_pos, count);
221 }
222 
add_implicit_vars(vector<Var> & args) const223 std::pair<int, int> Func::add_implicit_vars(vector<Var> &args) const {
224     int placeholder_pos = -1;
225     int count = 0;
226     std::vector<Var>::iterator iter = args.begin();
227 
228     while (iter != args.end() && !iter->same_as(_)) {
229         iter++;
230     }
231     if (iter != args.end()) {
232         placeholder_pos = (int)(iter - args.begin());
233         int i = 0;
234         iter = args.erase(iter);
235         while ((int)args.size() < dimensions()) {
236             Internal::debug(2) << "Adding implicit var " << i << " to call to " << name() << "\n";
237             iter = args.insert(iter, Var::implicit(i++));
238             iter++;
239             count++;
240         }
241     }
242 
243     if (defined() && args.size() != (size_t)dimensions()) {
244         user_error << "Func \"" << name() << "\" was called with "
245                    << args.size() << " arguments, but was defined with " << dimensions() << "\n";
246     }
247 
248     return {placeholder_pos, count};
249 }
250 
add_implicit_vars(vector<Expr> & args) const251 std::pair<int, int> Func::add_implicit_vars(vector<Expr> &args) const {
252     int placeholder_pos = -1;
253     int count = 0;
254     std::vector<Expr>::iterator iter = args.begin();
255     while (iter != args.end()) {
256         const Variable *var = iter->as<Variable>();
257         if (var && var->name == Var(_).name())
258             break;
259         iter++;
260     }
261     if (iter != args.end()) {
262         placeholder_pos = (int)(iter - args.begin());
263         int i = 0;
264         iter = args.erase(iter);
265         while ((int)args.size() < dimensions()) {
266             Internal::debug(2) << "Adding implicit var " << i << " to call to " << name() << "\n";
267             iter = args.insert(iter, Var::implicit(i++));
268             iter++;
269             count++;
270         }
271     }
272 
273     if (defined() && args.size() != (size_t)dimensions()) {
274         user_error << "Func \"" << name() << "\" was called with "
275                    << args.size() << " arguments, but was defined with " << dimensions() << "\n";
276     }
277 
278     return {placeholder_pos, count};
279 }
280 
281 namespace {
var_name_match(const string & candidate,const string & var)282 bool var_name_match(const string &candidate, const string &var) {
283     internal_assert(var.find('.') == string::npos)
284         << "var_name_match expects unqualified names for the second argument. "
285         << "Name passed: " << var << "\n";
286     if (candidate == var) return true;
287     return Internal::ends_with(candidate, "." + var);
288 }
289 }  // namespace
290 
name() const291 std::string Stage::name() const {
292     std::string stage_name = (stage_index == 0) ? function.name() : function.name() + ".update(" + std::to_string(stage_index - 1) + ")";
293     return stage_name;
294 }
295 
set_dim_type(const VarOrRVar & var,ForType t)296 void Stage::set_dim_type(const VarOrRVar &var, ForType t) {
297     bool found = false;
298     vector<Dim> &dims = definition.schedule().dims();
299     for (size_t i = 0; i < dims.size(); i++) {
300         if (var_name_match(dims[i].var, var.name())) {
301             found = true;
302             dims[i].for_type = t;
303 
304             // If it's an rvar and the for type is parallel, we need to
305             // validate that this doesn't introduce a race condition,
306             // unless it is flagged explicitly or is a associative atomic operation.
307             if (!dims[i].is_pure() && var.is_rvar && is_parallel(t)) {
308                 if (!definition.schedule().allow_race_conditions() &&
309                     definition.schedule().atomic()) {
310                     if (!definition.schedule().override_atomic_associativity_test()) {
311                         // We only allow allow associative atomic operations
312                         const string &func_name = function.name();
313                         vector<Expr> &args = definition.args();
314                         vector<Expr> &values = definition.values();
315 
316                         // Check whether the operator is associative and determine the operator and
317                         // its identity for each value in the definition if it is a Tuple
318                         const auto &prover_result = prove_associativity(func_name, args, values);
319 
320                         user_assert(prover_result.associative())
321                             << "Failed to call atomic() on " << name()
322                             << " since it can't prove associativity of the operator.\n";
323                         internal_assert(prover_result.size() == values.size());
324                     }
325                 }
326                 user_assert(definition.schedule().allow_race_conditions() ||
327                             definition.schedule().atomic())
328                     << "In schedule for " << name()
329                     << ", marking var " << var.name()
330                     << " as parallel or vectorized may introduce a race"
331                     << " condition resulting in incorrect output."
332                     << " It is possible to parallelize this by using the"
333                     << " atomic() method if the operation is associative,"
334                     << " or set override_associativity_test to true in the atomic method "
335                     << " if you are certain that the operation is associative."
336                     << " It is also possible to override this error using"
337                     << " the allow_race_conditions() method. Use allow_race_conditions()"
338                     << " with great caution, and only when you are willing"
339                     << " to accept non-deterministic output, or you can prove"
340                     << " that any race conditions in this code do not change"
341                     << " the output, or you can prove that there are actually"
342                     << " no race conditions, and that Halide is being too cautious.\n";
343             }
344         } else if (t == ForType::Vectorized) {
345             user_assert(dims[i].for_type != ForType::Vectorized)
346                 << "In schedule for " << name()
347                 << ", can't vectorize across " << var.name()
348                 << " because Func is already vectorized across " << dims[i].var << "\n";
349         }
350     }
351 
352     if (!found) {
353         user_error << "In schedule for " << name()
354                    << ", could not find dimension "
355                    << var.name()
356                    << " to mark as " << t
357                    << " in vars for function\n"
358                    << dump_argument_list();
359     }
360 }
361 
set_dim_device_api(const VarOrRVar & var,DeviceAPI device_api)362 void Stage::set_dim_device_api(const VarOrRVar &var, DeviceAPI device_api) {
363     bool found = false;
364     vector<Dim> &dims = definition.schedule().dims();
365     for (size_t i = 0; i < dims.size(); i++) {
366         if (var_name_match(dims[i].var, var.name())) {
367             found = true;
368             dims[i].device_api = device_api;
369         }
370     }
371 
372     if (!found) {
373         user_error << "In schedule for " << name()
374                    << ", could not find dimension "
375                    << var.name()
376                    << " to set to device API " << static_cast<int>(device_api)
377                    << " in vars for function\n"
378                    << dump_argument_list();
379     }
380 }
381 
dump_argument_list() const382 std::string Stage::dump_argument_list() const {
383     return dump_dim_list(definition.schedule().dims());
384 }
385 
386 namespace {
387 
388 class SubstituteSelfReference : public IRMutator {
389     using IRMutator::visit;
390 
391     const string func;
392     const Function substitute;
393     const vector<Var> new_args;
394 
visit(const Call * c)395     Expr visit(const Call *c) override {
396         Expr expr = IRMutator::visit(c);
397         c = expr.as<Call>();
398         internal_assert(c);
399 
400         if ((c->call_type == Call::Halide) && (func == c->name)) {
401             debug(4) << "...Replace call to Func \"" << c->name << "\" with "
402                      << "\"" << substitute.name() << "\"\n";
403             vector<Expr> args;
404             args.insert(args.end(), c->args.begin(), c->args.end());
405             args.insert(args.end(), new_args.begin(), new_args.end());
406             expr = Call::make(substitute, args, c->value_index);
407         }
408         return expr;
409     }
410 
411 public:
SubstituteSelfReference(const string & func,const Function & substitute,const vector<Var> & new_args)412     SubstituteSelfReference(const string &func, const Function &substitute,
413                             const vector<Var> &new_args)
414         : func(func), substitute(substitute), new_args(new_args) {
415         internal_assert(substitute.get_contents().defined());
416     }
417 };
418 
419 /** Substitute all self-reference calls to 'func' with 'substitute' which
420  * args (LHS) is the old args (LHS) plus 'new_args' in that order.
421  * Expect this method to be called on the value (RHS) of an update definition. */
substitute_self_reference(Expr val,const string & func,const Function & substitute,const vector<Var> & new_args)422 Expr substitute_self_reference(Expr val, const string &func, const Function &substitute,
423                                const vector<Var> &new_args) {
424     SubstituteSelfReference subs(func, substitute, new_args);
425     val = subs.mutate(val);
426     return val;
427 }
428 
429 // Substitute the occurrence of 'name' in 'exprs' with 'value'.
substitute_var_in_exprs(const string & name,const Expr & value,vector<Expr> & exprs)430 void substitute_var_in_exprs(const string &name, const Expr &value, vector<Expr> &exprs) {
431     for (auto &expr : exprs) {
432         expr = substitute(name, value, expr);
433     }
434 }
435 
apply_split_result(const vector<pair<string,Expr>> & bounds_let_stmts,const vector<ApplySplitResult> & splits_result,vector<Expr> & predicates,vector<Expr> & args,vector<Expr> & values)436 void apply_split_result(const vector<pair<string, Expr>> &bounds_let_stmts,
437                         const vector<ApplySplitResult> &splits_result,
438                         vector<Expr> &predicates, vector<Expr> &args,
439                         vector<Expr> &values) {
440 
441     for (const auto &res : splits_result) {
442         if (res.is_substitution() || res.is_let()) {
443             // Apply substitutions to the list of predicates, args, and values.
444             // Make sure we substitute in all the let stmts as well since we are
445             // not going to add them to the exprs.
446             substitute_var_in_exprs(res.name, res.value, predicates);
447             substitute_var_in_exprs(res.name, res.value, args);
448             substitute_var_in_exprs(res.name, res.value, values);
449         } else {
450             internal_assert(res.is_predicate());
451             predicates.push_back(res.value);
452         }
453     }
454 
455     // Make sure we substitute in all the let stmts from 'bounds_let_stmts'
456     // since we are not going to add them to the exprs.
457     for (const auto &let : bounds_let_stmts) {
458         substitute_var_in_exprs(let.first, let.second, predicates);
459         substitute_var_in_exprs(let.first, let.second, args);
460         substitute_var_in_exprs(let.first, let.second, values);
461     }
462 }
463 
464 /** Apply split directives on the reduction variables. Remove the old RVar from
465  * the list and add the split result (inner and outer RVars) to the list. Add
466  * new predicates corresponding to the TailStrategy to the RDom predicate list. */
apply_split(const Split & s,vector<ReductionVariable> & rvars,vector<Expr> & predicates,vector<Expr> & args,vector<Expr> & values,map<string,Expr> & dim_extent_alignment)467 bool apply_split(const Split &s, vector<ReductionVariable> &rvars,
468                  vector<Expr> &predicates, vector<Expr> &args,
469                  vector<Expr> &values, map<string, Expr> &dim_extent_alignment) {
470     internal_assert(s.is_split());
471     const auto it = std::find_if(rvars.begin(), rvars.end(),
472                                  [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); });
473 
474     Expr old_max, old_min, old_extent;
475 
476     if (it != rvars.end()) {
477         debug(4) << "  Splitting " << it->var << " into " << s.outer << " and " << s.inner << "\n";
478 
479         old_max = simplify(it->min + it->extent - 1);
480         old_min = it->min;
481         old_extent = it->extent;
482 
483         it->var = s.inner;
484         it->min = 0;
485         it->extent = s.factor;
486 
487         rvars.insert(it + 1, {s.outer, 0, simplify((old_extent - 1 + s.factor) / s.factor)});
488 
489         vector<ApplySplitResult> splits_result = apply_split(s, true, "", dim_extent_alignment);
490         vector<pair<string, Expr>> bounds_let_stmts = compute_loop_bounds_after_split(s, "");
491         apply_split_result(bounds_let_stmts, splits_result, predicates, args, values);
492 
493         return true;
494     }
495     return false;
496 }
497 
498 /** Apply fuse directives on the reduction variables. Remove the
499  * fused RVars from the list and add the fused RVar to the list. */
apply_fuse(const Split & s,vector<ReductionVariable> & rvars,vector<Expr> & predicates,vector<Expr> & args,vector<Expr> & values,map<string,Expr> & dim_extent_alignment)500 bool apply_fuse(const Split &s, vector<ReductionVariable> &rvars,
501                 vector<Expr> &predicates, vector<Expr> &args,
502                 vector<Expr> &values, map<string, Expr> &dim_extent_alignment) {
503     internal_assert(s.is_fuse());
504     const auto &iter_outer = std::find_if(rvars.begin(), rvars.end(),
505                                           [&s](const ReductionVariable &rv) { return (s.outer == rv.var); });
506     const auto &iter_inner = std::find_if(rvars.begin(), rvars.end(),
507                                           [&s](const ReductionVariable &rv) { return (s.inner == rv.var); });
508 
509     Expr inner_min, inner_extent, outer_min, outer_extent;
510     if ((iter_outer != rvars.end()) && (iter_inner != rvars.end())) {
511         debug(4) << "  Fusing " << s.outer << " and " << s.inner << " into " << s.old_var << "\n";
512 
513         inner_min = iter_inner->min;
514         inner_extent = iter_inner->extent;
515         outer_min = iter_outer->min;
516         outer_extent = iter_outer->extent;
517 
518         Expr extent = iter_outer->extent * iter_inner->extent;
519         iter_outer->var = s.old_var;
520         iter_outer->min = 0;
521         iter_outer->extent = extent;
522         rvars.erase(iter_inner);
523 
524         vector<ApplySplitResult> splits_result = apply_split(s, true, "", dim_extent_alignment);
525         vector<pair<string, Expr>> bounds_let_stmts = compute_loop_bounds_after_split(s, "");
526         apply_split_result(bounds_let_stmts, splits_result, predicates, args, values);
527 
528         return true;
529     }
530     return false;
531 }
532 
533 /** Apply purify directives on the reduction variables and predicates. Purify
534  * replace a RVar with a Var, thus, the RVar needs to be removed from the list.
535  * Any reference to the RVar in the predicates will be replaced with reference
536  * to a Var. */
apply_purify(const Split & s,vector<ReductionVariable> & rvars,vector<Expr> & predicates,vector<Expr> & args,vector<Expr> & values,map<string,Expr> & dim_extent_alignment)537 bool apply_purify(const Split &s, vector<ReductionVariable> &rvars,
538                   vector<Expr> &predicates, vector<Expr> &args,
539                   vector<Expr> &values, map<string, Expr> &dim_extent_alignment) {
540     internal_assert(s.is_purify());
541     const auto &iter = std::find_if(rvars.begin(), rvars.end(),
542                                     [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); });
543     if (iter != rvars.end()) {
544         debug(4) << "  Purify RVar " << iter->var << " into Var " << s.outer
545                  << ", deleting it from the rvars list\n";
546         rvars.erase(iter);
547 
548         vector<ApplySplitResult> splits_result = apply_split(s, true, "", dim_extent_alignment);
549         vector<pair<string, Expr>> bounds_let_stmts = compute_loop_bounds_after_split(s, "");
550         apply_split_result(bounds_let_stmts, splits_result, predicates, args, values);
551 
552         return true;
553     }
554     return false;
555 }
556 
557 /** Apply rename directives on the reduction variables. */
apply_rename(const Split & s,vector<ReductionVariable> & rvars,vector<Expr> & predicates,vector<Expr> & args,vector<Expr> & values,map<string,Expr> & dim_extent_alignment)558 bool apply_rename(const Split &s, vector<ReductionVariable> &rvars,
559                   vector<Expr> &predicates, vector<Expr> &args,
560                   vector<Expr> &values, map<string, Expr> &dim_extent_alignment) {
561     internal_assert(s.is_rename());
562     const auto &iter = std::find_if(rvars.begin(), rvars.end(),
563                                     [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); });
564     if (iter != rvars.end()) {
565         debug(4) << "  Renaming " << iter->var << " into " << s.outer << "\n";
566         iter->var = s.outer;
567 
568         vector<ApplySplitResult> splits_result = apply_split(s, true, "", dim_extent_alignment);
569         vector<pair<string, Expr>> bounds_let_stmts = compute_loop_bounds_after_split(s, "");
570         apply_split_result(bounds_let_stmts, splits_result, predicates, args, values);
571 
572         return true;
573     }
574     return false;
575 }
576 
577 /** Apply scheduling directives (e.g. split, fuse, etc.) on the reduction
578  * variables. */
apply_split_directive(const Split & s,vector<ReductionVariable> & rvars,vector<Expr> & predicates,vector<Expr> & args,vector<Expr> & values)579 bool apply_split_directive(const Split &s, vector<ReductionVariable> &rvars,
580                            vector<Expr> &predicates, vector<Expr> &args,
581                            vector<Expr> &values) {
582     map<string, Expr> dim_extent_alignment;
583     for (const ReductionVariable &rv : rvars) {
584         dim_extent_alignment[rv.var] = rv.extent;
585     }
586 
587     vector<pair<string, Expr>> rvar_bounds;
588     for (const ReductionVariable &rv : rvars) {
589         rvar_bounds.emplace_back(rv.var + ".loop_min", rv.min);
590         rvar_bounds.emplace_back(rv.var + ".loop_max", simplify(rv.min + rv.extent - 1));
591         rvar_bounds.emplace_back(rv.var + ".loop_extent", rv.extent);
592     }
593 
594     bool found = false;
595     if (s.is_split()) {
596         found = apply_split(s, rvars, predicates, args, values, dim_extent_alignment);
597     } else if (s.is_fuse()) {
598         found = apply_fuse(s, rvars, predicates, args, values, dim_extent_alignment);
599     } else if (s.is_purify()) {
600         found = apply_purify(s, rvars, predicates, args, values, dim_extent_alignment);
601     } else {
602         found = apply_rename(s, rvars, predicates, args, values, dim_extent_alignment);
603     }
604 
605     if (found) {
606         for (const auto &let : rvar_bounds) {
607             substitute_var_in_exprs(let.first, let.second, predicates);
608             substitute_var_in_exprs(let.first, let.second, args);
609             substitute_var_in_exprs(let.first, let.second, values);
610         }
611     }
612     return found;
613 }
614 
615 }  // anonymous namespace
616 
rfactor(const RVar & r,const Var & v)617 Func Stage::rfactor(const RVar &r, const Var &v) {
618     return rfactor({{r, v}});
619 }
620 
rfactor(vector<pair<RVar,Var>> preserved)621 Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
622     user_assert(!definition.is_init()) << "rfactor() must be called on an update definition\n";
623 
624     const string &func_name = function.name();
625     vector<Expr> &args = definition.args();
626     vector<Expr> &values = definition.values();
627 
628     // Check whether the operator is associative and determine the operator and
629     // its identity for each value in the definition if it is a Tuple
630     const auto &prover_result = prove_associativity(func_name, args, values);
631 
632     user_assert(prover_result.associative())
633         << "Failed to call rfactor() on " << name()
634         << " since it can't prove associativity of the operator\n";
635     internal_assert(prover_result.size() == values.size());
636 
637     vector<Split> &splits = definition.schedule().splits();
638     vector<Dim> &dims = definition.schedule().dims();
639     vector<ReductionVariable> &rvars = definition.schedule().rvars();
640     vector<Expr> predicates = definition.split_predicate();
641 
642     Scope<string> scope;  // Contains list of RVars lifted to the intermediate Func
643     vector<string> rvars_removed;
644 
645     vector<bool> is_rfactored(dims.size(), false);
646     for (const pair<RVar, Var> &i : preserved) {
647         const RVar &rv = i.first;
648         const Var &v = i.second;
649         {
650             // Check that the RVar are in the dims list
651             const auto &iter = std::find_if(dims.begin(), dims.end(),
652                                             [&rv](const Dim &dim) { return var_name_match(dim.var, rv.name()); });
653             user_assert((iter != dims.end()) && (*iter).is_rvar())
654                 << "In schedule for " << name()
655                 << ", can't perform rfactor() on " << rv.name()
656                 << " since it is not in the reduction domain\n"
657                 << dump_argument_list();
658             is_rfactored[iter - dims.begin()] = true;
659         }
660         {
661             // Check that the new pure Vars we used to rename the RVar aren't already in the dims list
662             const auto &iter = std::find_if(dims.begin(), dims.end(),
663                                             [&v](const Dim &dim) { return var_name_match(dim.var, v.name()); });
664             user_assert(iter == dims.end())
665                 << "In schedule for " << name()
666                 << ", can't rename the rvars " << rv.name() << " into " << v.name()
667                 << ", since it is already used in this Func's schedule elsewhere.\n"
668                 << dump_argument_list();
669         }
670     }
671 
672     // If the operator is associative but non-commutative, rfactor() on inner
673     // dimensions (excluding the outer dimensions) is not valid.
674     if (!prover_result.commutative()) {
675         int last_rvar = -1;
676         for (int i = dims.size() - 1; i >= 0; --i) {
677             if ((last_rvar != -1) && is_rfactored[i]) {
678                 user_assert(is_rfactored[last_rvar])
679                     << "In schedule for " << name()
680                     << ", can't rfactor an inner dimension " << dims[i].var
681                     << " without rfactoring the outer dimensions, since the "
682                     << "operator is non-commutative.\n"
683                     << dump_argument_list();
684             }
685             if (dims[i].is_rvar()) {
686                 last_rvar = i;
687             }
688         }
689     }
690 
691     // We need to apply the split directives on the reduction vars, so that we can
692     // correctly lift the RVars not in 'rvars_kept' and distribute the RVars to the
693     // intermediate and merge Funcs.
694     {
695         vector<Split> temp;
696         for (const Split &s : splits) {
697             // If it's already applied, we should remove it from the split list.
698             if (!apply_split_directive(s, rvars, predicates, args, values)) {
699                 temp.push_back(s);
700             }
701         }
702         splits = temp;
703     }
704 
705     // Reduction domain of the intermediate update definition
706     vector<ReductionVariable> intm_rvars;
707     for (const auto &rv : rvars) {
708         const auto &iter = std::find_if(preserved.begin(), preserved.end(),
709                                         [&rv](const pair<RVar, Var> &pair) { return var_name_match(rv.var, pair.first.name()); });
710         if (iter == preserved.end()) {
711             intm_rvars.push_back(rv);
712             scope.push(rv.var, rv.var);
713         }
714     }
715     RDom intm_rdom(intm_rvars);
716 
717     // Sort the Rvars kept and their Vars replacement based on the RVars of
718     // the reduction domain AFTER applying the split directives, so that we
719     // can have a consistent args order for the update definition of the
720     // intermediate and new merge Funcs.
721     std::sort(preserved.begin(), preserved.end(),
722               [&](const pair<RVar, Var> &lhs, const pair<RVar, Var> &rhs) {
723                   const auto &iter_lhs = std::find_if(rvars.begin(), rvars.end(),
724                                                       [&lhs](const ReductionVariable &rv) { return var_name_match(rv.var, lhs.first.name()); });
725                   const auto &iter_rhs = std::find_if(rvars.begin(), rvars.end(),
726                                                       [&rhs](const ReductionVariable &rv) { return var_name_match(rv.var, rhs.first.name()); });
727                   return iter_lhs < iter_rhs;
728               });
729     // The list of RVars to keep in the new update definition
730     vector<RVar> rvars_kept(preserved.size());
731     // List of pure Vars to replace the RVars in the intermediate's update definition
732     vector<Var> vars_rename(preserved.size());
733     for (size_t i = 0; i < preserved.size(); ++i) {
734         const auto &val = preserved[i];
735         rvars_kept[i] = val.first;
736         vars_rename[i] = val.second;
737     }
738 
739     // List of RVars for the new reduction domain. Any RVars not in 'rvars_kept'
740     // are removed from the RDom
741     {
742         vector<ReductionVariable> temp;
743         for (const auto &rv : rvars) {
744             const auto &iter = std::find_if(rvars_kept.begin(), rvars_kept.end(),
745                                             [&rv](const RVar &rvar) { return var_name_match(rv.var, rvar.name()); });
746             if (iter != rvars_kept.end()) {
747                 temp.push_back(rv);
748             } else {
749                 rvars_removed.push_back(rv.var);
750             }
751         }
752         rvars.swap(temp);
753     }
754     RDom f_rdom(rvars);
755 
756     // Init definition of the intermediate Func
757 
758     // Compute args of the init definition of the intermediate Func.
759     // Replace the RVars, which are in 'rvars_kept', with the specified new pure
760     // Vars. Also, add the pure Vars of the original init definition as part of
761     // the args.
762     // For example, if we have the following Func f:
763     //   f(x, y) = 10
764     //   f(r.x, r.y) += h(r.x, r.y)
765     // Calling f.update(0).rfactor({{r.y, u}}) will generate the following
766     // intermediate Func:
767     //   f_intm(x, y, u) = 0
768     //   f_intm(r.x, u, u) += h(r.x, u)
769 
770     vector<Var> init_args;
771     init_args.insert(init_args.end(), dim_vars.begin(), dim_vars.end());
772     init_args.insert(init_args.end(), vars_rename.begin(), vars_rename.end());
773 
774     vector<Expr> init_vals(values.size());
775     for (size_t i = 0; i < init_vals.size(); ++i) {
776         init_vals[i] = prover_result.pattern.identities[i];
777     }
778 
779     Func intm(func_name + "_intm");
780     intm(init_args) = Tuple(init_vals);
781 
782     // Args of the update definition of the intermediate Func
783     vector<Expr> update_args(args.size() + vars_rename.size());
784 
785     // We need to substitute the reference to the old RDom's RVars with
786     // the new RDom's RVars. Also, substitute the reference to RVars which
787     // are in 'rvars_kept' with their corresponding new pure Vars
788     map<string, Expr> substitution_map;
789     for (size_t i = 0; i < intm_rvars.size(); ++i) {
790         substitution_map[intm_rvars[i].var] = intm_rdom[i];
791     }
792     for (size_t i = 0; i < vars_rename.size(); i++) {
793         update_args[i + args.size()] = vars_rename[i];
794         RVar rvar_kept = rvars_kept[i];
795         // Find the full name of rvar_kept in rvars
796         const auto &iter = std::find_if(rvars.begin(), rvars.end(),
797                                         [&rvar_kept](const ReductionVariable &rv) { return var_name_match(rv.var, rvar_kept.name()); });
798         substitution_map[iter->var] = vars_rename[i];
799     }
800     for (size_t i = 0; i < args.size(); i++) {
801         Expr arg = substitute(substitution_map, args[i]);
802         update_args[i] = arg;
803     }
804 
805     // Compute the predicates for the intermediate Func and the new update definition
806     for (const Expr &pred : predicates) {
807         Expr subs_pred = substitute(substitution_map, pred);
808         intm_rdom.where(subs_pred);
809         if (!expr_uses_vars(pred, scope)) {
810             // Only keep the predicate that does not depend on the lifted RVars
811             // (either explicitly or implicitly). For example, if 'rx' is split
812             // into 'rxo' and 'rxi' and 'rxo' is part of the lifted RVars, we'll
813             // ignore every predicate that depends on 'rx'
814             f_rdom.where(pred);
815         }
816     }
817     definition.predicate() = f_rdom.domain().predicate();
818 
819     // The update values the intermediate Func should compute
820     vector<Expr> update_vals(values.size());
821     for (size_t i = 0; i < update_vals.size(); i++) {
822         Expr val = substitute(substitution_map, values[i]);
823         // Need to update the self-reference in the update definition to point
824         // to the new intermediate Func
825         val = substitute_self_reference(val, func_name, intm.function(), vars_rename);
826         update_vals[i] = val;
827     }
828     intm(update_args) = Tuple(update_vals);
829 
830     // Determine the dims and schedule of the update definition of the
831     // intermediate Func. We copy over the schedule from the original
832     // update definition (e.g. split, parallelize, vectorize, etc.)
833     intm.function().update(0).schedule().dims() = dims;
834     intm.function().update(0).schedule().splits() = splits;
835 
836     // Copy over the storage order of the original pure dims
837     vector<StorageDim> &intm_storage_dims = intm.function().schedule().storage_dims();
838     internal_assert(intm_storage_dims.size() ==
839                     function.schedule().storage_dims().size() + vars_rename.size());
840     for (size_t i = 0; i < function.schedule().storage_dims().size(); ++i) {
841         intm_storage_dims[i] = function.schedule().storage_dims()[i];
842     }
843 
844     for (size_t i = 0; i < rvars_kept.size(); ++i) {
845         // Apply the purify directive that replaces the RVar in rvars_kept
846         // with a pure Var
847         intm.update(0).purify(rvars_kept[i], vars_rename[i]);
848     }
849 
850     // Determine the dims of the new update definition
851 
852     // Add pure Vars from the original init definition to the dims list
853     // if they are not already in the list
854     for (const Var &v : dim_vars) {
855         const auto &iter = std::find_if(dims.begin(), dims.end(),
856                                         [&v](const Dim &dim) { return var_name_match(dim.var, v.name()); });
857         if (iter == dims.end()) {
858             Dim d = {v.name(), ForType::Serial, DeviceAPI::None, DimType::PureVar};
859             dims.insert(dims.end() - 1, d);
860         }
861     }
862     // Then, we need to remove lifted RVars from the dims list
863     for (const string &rv : rvars_removed) {
864         remove(rv);
865     }
866 
867     // Define the new update definition which refers to the intermediate Func.
868     // Using the same example as above, the new update definition is:
869     //   f(x, y) += f_intm(x, y, r.y)
870 
871     // Args for store in the new update definition
872     vector<Expr> f_store_args(dim_vars.size());
873     for (size_t i = 0; i < f_store_args.size(); ++i) {
874         f_store_args[i] = dim_vars[i];
875     }
876 
877     // Call's args to the intermediate Func in the new update definition
878     vector<Expr> f_load_args;
879     f_load_args.insert(f_load_args.end(), dim_vars.begin(), dim_vars.end());
880     for (int i = 0; i < f_rdom.dimensions(); ++i) {
881         f_load_args.push_back(f_rdom[i]);
882     }
883     internal_assert(f_load_args.size() == init_args.size());
884 
885     // Update value of the new update definition. It loads values from
886     // the intermediate Func.
887     vector<Expr> f_values(values.size());
888 
889     // There might be cross-dependencies between tuple elements, so we need
890     // to collect all substitutions first.
891     map<string, Expr> replacements;
892     for (size_t i = 0; i < f_values.size(); ++i) {
893         if (!prover_result.ys[i].var.empty()) {
894             Expr r = (values.size() == 1) ? Expr(intm(f_load_args)) : Expr(intm(f_load_args)[i]);
895             replacements.emplace(prover_result.ys[i].var, r);
896         }
897 
898         if (!prover_result.xs[i].var.empty()) {
899             Expr prev_val = Call::make(intm.output_types()[i], func_name,
900                                        f_store_args, Call::CallType::Halide,
901                                        FunctionPtr(), i);
902             replacements.emplace(prover_result.xs[i].var, prev_val);
903         } else {
904             user_warning << "Update definition of " << name() << " at index " << i
905                          << " doesn't depend on the previous value. This isn't a"
906                          << " reduction operation\n";
907         }
908     }
909     for (size_t i = 0; i < f_values.size(); ++i) {
910         f_values[i] = substitute(replacements, prover_result.pattern.ops[i]);
911     }
912 
913     // Update the definition
914     args.swap(f_store_args);
915     values.swap(f_values);
916 
917     return intm;
918 }
919 
split(const string & old,const string & outer,const string & inner,const Expr & factor,bool exact,TailStrategy tail)920 void Stage::split(const string &old, const string &outer, const string &inner, const Expr &factor, bool exact, TailStrategy tail) {
921     debug(4) << "In schedule for " << name() << ", split " << old << " into "
922              << outer << " and " << inner << " with factor of " << factor << "\n";
923     vector<Dim> &dims = definition.schedule().dims();
924 
925     // Check that the new names aren't already in the dims list.
926     for (size_t i = 0; i < dims.size(); i++) {
927         string new_names[2] = {inner, outer};
928         for (int j = 0; j < 2; j++) {
929             if (var_name_match(dims[i].var, new_names[j]) && new_names[j] != old) {
930                 user_error << "In schedule for " << name()
931                            << ", can't create var " << new_names[j]
932                            << " using a split or tile, because " << new_names[j]
933                            << " is already used in this Func's schedule elsewhere.\n"
934                            << dump_argument_list();
935             }
936         }
937     }
938 
939     // Replace the old dimension with the new dimensions in the dims list
940     bool found = false;
941     string inner_name, outer_name, old_name;
942 
943     for (size_t i = 0; (!found) && i < dims.size(); i++) {
944         if (var_name_match(dims[i].var, old)) {
945             found = true;
946             old_name = dims[i].var;
947             inner_name = old_name + "." + inner;
948             outer_name = old_name + "." + outer;
949             dims.insert(dims.begin() + i, dims[i]);
950             dims[i].var = inner_name;
951             dims[i + 1].var = outer_name;
952             if (dims[i].for_type == ForType::Extern) {
953                 // If we split an extern loop, mark the outer loop serial.
954                 dims[i + 1].for_type = ForType::Serial;
955             }
956         }
957     }
958 
959     if (!found) {
960         user_error << "In schedule for " << name()
961                    << ", could not find split dimension: "
962                    << old
963                    << "\n"
964                    << dump_argument_list();
965     }
966 
967     bool round_up_ok = !exact;
968     if (round_up_ok && !definition.is_init()) {
969         // If it's the outermost split in this dimension, RoundUp
970         // is OK. Otherwise we need GuardWithIf to avoid
971         // recomputing values in the case where the inner split
972         // factor does not divide the outer split factor.
973         std::set<string> inner_vars;
974         for (const Split &s : definition.schedule().splits()) {
975             if (s.is_split()) {
976                 inner_vars.insert(s.inner);
977                 if (inner_vars.count(s.old_var)) {
978                     inner_vars.insert(s.outer);
979                 }
980             } else if (s.is_rename() || s.is_purify()) {
981                 if (inner_vars.count(s.old_var)) {
982                     inner_vars.insert(s.outer);
983                 }
984             } else if (s.is_fuse()) {
985                 if (inner_vars.count(s.inner) || inner_vars.count(s.outer)) {
986                     inner_vars.insert(s.old_var);
987                 }
988             }
989         }
990         round_up_ok = !inner_vars.count(old_name);
991         user_assert(round_up_ok || tail != TailStrategy::RoundUp)
992             << "Can't use TailStrategy::RoundUp for splitting " << old_name
993             << " in update definition of " << name() << ". "
994             << "It may redundantly recompute some values, which "
995             << "could change the meaning of the algorithm. "
996             << "Use TailStrategy::GuardWithIf instead.";
997     }
998 
999     if (tail == TailStrategy::Auto) {
1000         // Select a tail strategy
1001         if (exact) {
1002             tail = TailStrategy::GuardWithIf;
1003         } else if (!definition.is_init()) {
1004             tail = round_up_ok ? TailStrategy::RoundUp : TailStrategy::GuardWithIf;
1005         } else {
1006             // We should employ ShiftInwards when we can to prevent
1007             // overcompute and adding constraints to the bounds of
1008             // inputs and outputs. However, if we're already covered
1009             // by an earlier larger ShiftInwards split, there's no
1010             // point - it just complicates the IR and confuses bounds
1011             // inference. An example of this is:
1012             //
1013             // f.vectorize(x, 8).unroll(x, 4);
1014             //
1015             // The vectorize-induced split is ShiftInwards. There's no
1016             // point also applying ShiftInwards to the unroll-induced
1017             // split.
1018             //
1019             // Note that we'll still partition the outermost loop to
1020             // avoid the overhead of the min we placed in the inner
1021             // loop with the vectorize, because that's how loop
1022             // partitioning works. The steady-state will be just as
1023             // efficient as:
1024             //
1025             // f.split(x, x, xi, 32).vectorize(xi, 8).unroll(xi);
1026             //
1027             // It's only the tail/epilogue that changes.
1028 
1029             std::map<string, Expr> descends_from_shiftinwards_outer;
1030             for (const Split &s : definition.schedule().splits()) {
1031                 auto it = descends_from_shiftinwards_outer.find(s.old_var);
1032                 if (s.is_split() && s.tail == TailStrategy::ShiftInwards) {
1033                     descends_from_shiftinwards_outer[s.outer] = s.factor;
1034                 } else if (s.is_split() && it != descends_from_shiftinwards_outer.end()) {
1035                     descends_from_shiftinwards_outer[s.inner] = it->second;
1036                     descends_from_shiftinwards_outer[s.outer] = it->second;
1037                 } else if ((s.is_rename() || s.is_purify()) &&
1038                            it != descends_from_shiftinwards_outer.end()) {
1039                     descends_from_shiftinwards_outer[s.outer] = it->second;
1040                 }
1041             }
1042             auto it = descends_from_shiftinwards_outer.find(old_name);
1043             if (it != descends_from_shiftinwards_outer.end() &&
1044                 can_prove(it->second >= factor)) {
1045                 tail = TailStrategy::RoundUp;
1046             } else {
1047                 tail = TailStrategy::ShiftInwards;
1048             }
1049         }
1050     }
1051 
1052     if (!definition.is_init()) {
1053         user_assert(tail != TailStrategy::ShiftInwards)
1054             << "When splitting Var " << old_name
1055             << " ShiftInwards is not a legal tail strategy for update definitions, as"
1056             << " it may change the meaning of the algorithm\n";
1057     }
1058 
1059     if (exact) {
1060         user_assert(tail == TailStrategy::GuardWithIf)
1061             << "When splitting Var " << old_name
1062             << " the tail strategy must be GuardWithIf or Auto. "
1063             << "Anything else may change the meaning of the algorithm\n";
1064     }
1065 
1066     // Add the split to the splits list
1067     Split split = {old_name, outer_name, inner_name, factor, exact, tail, Split::SplitVar};
1068     definition.schedule().splits().push_back(split);
1069 }
1070 
split(const VarOrRVar & old,const VarOrRVar & outer,const VarOrRVar & inner,const Expr & factor,TailStrategy tail)1071 Stage &Stage::split(const VarOrRVar &old, const VarOrRVar &outer, const VarOrRVar &inner, const Expr &factor, TailStrategy tail) {
1072     if (old.is_rvar) {
1073         user_assert(outer.is_rvar) << "Can't split RVar " << old.name() << " into Var " << outer.name() << "\n";
1074         user_assert(inner.is_rvar) << "Can't split RVar " << old.name() << " into Var " << inner.name() << "\n";
1075     } else {
1076         user_assert(!outer.is_rvar) << "Can't split Var " << old.name() << " into RVar " << outer.name() << "\n";
1077         user_assert(!inner.is_rvar) << "Can't split Var " << old.name() << " into RVar " << inner.name() << "\n";
1078     }
1079     split(old.name(), outer.name(), inner.name(), factor, old.is_rvar, tail);
1080     return *this;
1081 }
1082 
fuse(const VarOrRVar & inner,const VarOrRVar & outer,const VarOrRVar & fused)1083 Stage &Stage::fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRVar &fused) {
1084     if (!fused.is_rvar) {
1085         user_assert(!outer.is_rvar) << "Can't fuse Var " << fused.name()
1086                                     << " from RVar " << outer.name() << "\n";
1087         user_assert(!inner.is_rvar) << "Can't fuse Var " << inner.name()
1088                                     << " from RVar " << inner.name() << "\n";
1089     }
1090 
1091     debug(4) << "In schedule for " << name() << ", fuse " << outer.name()
1092              << " and " << inner.name() << " into " << fused.name() << "\n";
1093 
1094     // Replace the old dimensions with the new dimension in the dims list
1095     bool found_outer = false, found_inner = false;
1096     string inner_name, outer_name, fused_name;
1097     vector<Dim> &dims = definition.schedule().dims();
1098 
1099     DimType outer_type = DimType::PureRVar;
1100     for (size_t i = 0; (!found_outer) && i < dims.size(); i++) {
1101         if (var_name_match(dims[i].var, outer.name())) {
1102             found_outer = true;
1103             outer_name = dims[i].var;
1104             outer_type = dims[i].dim_type;
1105             dims.erase(dims.begin() + i);
1106         }
1107     }
1108     if (!found_outer) {
1109         user_error << "In schedule for " << name()
1110                    << ", could not find outer fuse dimension: "
1111                    << outer.name()
1112                    << "\n"
1113                    << dump_argument_list();
1114     }
1115 
1116     for (size_t i = 0; (!found_inner) && i < dims.size(); i++) {
1117         if (var_name_match(dims[i].var, inner.name())) {
1118             found_inner = true;
1119             inner_name = dims[i].var;
1120             fused_name = inner_name + "." + fused.name();
1121             dims[i].var = fused_name;
1122 
1123             if (dims[i].dim_type == DimType::ImpureRVar ||
1124                 outer_type == DimType::ImpureRVar) {
1125                 dims[i].dim_type = DimType::ImpureRVar;
1126             } else if (dims[i].dim_type == DimType::PureRVar ||
1127                        outer_type == DimType::PureRVar) {
1128                 dims[i].dim_type = DimType::PureRVar;
1129             } else {
1130                 dims[i].dim_type = DimType::PureVar;
1131             }
1132         }
1133     }
1134 
1135     if (!found_inner) {
1136         user_error << "In schedule for " << name()
1137                    << ", could not find inner fuse dimension: "
1138                    << inner.name()
1139                    << "\n"
1140                    << dump_argument_list();
1141     }
1142 
1143     // Add the fuse to the splits list
1144     Split split = {fused_name, outer_name, inner_name, Expr(), true, TailStrategy::RoundUp, Split::FuseVars};
1145     definition.schedule().splits().push_back(split);
1146     return *this;
1147 }
1148 
1149 namespace Internal {
1150 class CheckForFreeVars : public IRGraphVisitor {
1151 public:
1152     string offending_var;
1153 
1154 protected:
1155     using IRGraphVisitor::visit;
visit(const Variable * var)1156     void visit(const Variable *var) override {
1157         if (!var->param.defined() && !var->image.defined()) {
1158             offending_var = var->name;
1159         }
1160     }
1161 };
1162 }  // namespace Internal
1163 
specialize(const Expr & condition)1164 Stage Stage::specialize(const Expr &condition) {
1165     user_assert(condition.type().is_bool()) << "Argument passed to specialize must be of type bool\n";
1166 
1167     // The condition may not depend on Vars or RVars
1168     Internal::CheckForFreeVars check;
1169     condition.accept(&check);
1170     if (!check.offending_var.empty()) {
1171         user_error << "Specialization condition " << condition << " for " << name()
1172                    << " depends on Var or RVar " << check.offending_var << ". "
1173                    << "Specialization conditions may not depend on any Vars or RVars.\n";
1174     }
1175 
1176     // The user may be retrieving a reference to an existing
1177     // specialization.
1178     const vector<Specialization> &specializations = definition.specializations();
1179     for (const auto &specialization : specializations) {
1180         if (equal(condition, specialization.condition)) {
1181             return Stage(function, specialization.definition, stage_index);
1182         }
1183     }
1184 
1185     // Can't add any more specializations after specialize_fail().
1186     user_assert(specializations.empty() || specializations.back().failure_message.empty())
1187         << "Cannot add new specializations after specialize_fail().";
1188     const Specialization &s = definition.add_specialization(condition);
1189 
1190     return Stage(function, s.definition, stage_index);
1191 }
1192 
specialize_fail(const std::string & message)1193 void Stage::specialize_fail(const std::string &message) {
1194     user_assert(!message.empty()) << "Argument passed to specialize_fail() must not be empty.\n";
1195     const vector<Specialization> &specializations = definition.specializations();
1196     user_assert(specializations.empty() || specializations.back().failure_message.empty())
1197         << "Only one specialize_fail() may be defined per Stage.";
1198     (void)definition.add_specialization(const_true());
1199     Specialization &s = definition.specializations().back();
1200     s.failure_message = message;
1201 }
1202 
purify(const VarOrRVar & old_var,const VarOrRVar & new_var)1203 Stage &Stage::purify(const VarOrRVar &old_var, const VarOrRVar &new_var) {
1204     user_assert(old_var.is_rvar && !new_var.is_rvar)
1205         << "In schedule for " << name()
1206         << ", can't rename " << (old_var.is_rvar ? "RVar " : "Var ") << old_var.name()
1207         << " to " << (new_var.is_rvar ? "RVar " : "Var ") << new_var.name()
1208         << "; purify must take a RVar as old_Var and a Var as new_var\n";
1209 
1210     debug(4) << "In schedule for " << name() << ", purify RVar "
1211              << old_var.name() << " to Var " << new_var.name() << "\n";
1212 
1213     StageSchedule &schedule = definition.schedule();
1214 
1215     // Replace the old dimension with the new dimensions in the dims list
1216     bool found = false;
1217     string old_name, new_name = new_var.name();
1218     vector<Dim> &dims = schedule.dims();
1219 
1220     for (size_t i = 0; (!found) && i < dims.size(); i++) {
1221         if (var_name_match(dims[i].var, old_var.name())) {
1222             found = true;
1223             old_name = dims[i].var;
1224             dims[i].var = new_name;
1225             dims[i].dim_type = DimType::PureVar;
1226         }
1227     }
1228 
1229     if (!found) {
1230         user_error
1231             << "In schedule for " << name()
1232             << ", could not find rename dimension: "
1233             << old_var.name()
1234             << "\n"
1235             << dump_argument_list();
1236     }
1237 
1238     Split split = {old_name, new_name, "", 1, false, TailStrategy::RoundUp, Split::PurifyRVar};
1239     definition.schedule().splits().push_back(split);
1240     return *this;
1241 }
1242 
remove(const string & var)1243 void Stage::remove(const string &var) {
1244     debug(4) << "In schedule for " << name() << ", remove " << var << "\n";
1245 
1246     StageSchedule &schedule = definition.schedule();
1247 
1248     // Replace the old dimension with the new dimensions in the dims list
1249     bool found = false;
1250     string old_name = var;
1251     vector<Dim> &dims = schedule.dims();
1252     for (size_t i = 0; (!found) && i < dims.size(); i++) {
1253         if (dims[i].var == var) {
1254             found = true;
1255             old_name = dims[i].var;
1256             dims.erase(dims.begin() + i);
1257         }
1258     }
1259 
1260     if (!found) {
1261         user_error
1262             << "In schedule for " << name()
1263             << ", could not find remove dimension: "
1264             << var
1265             << "\n"
1266             << dump_argument_list();
1267     }
1268 
1269     std::set<string> removed_vars;
1270     removed_vars.insert(var);
1271 
1272     auto should_remove = [&removed_vars](const string &var) {
1273         const auto &iter = std::find_if(
1274             removed_vars.begin(), removed_vars.end(), [&var](const string &rv) { return rv == var; });
1275         return iter != removed_vars.end();
1276     };
1277 
1278     vector<Split> &splits = schedule.splits();
1279     vector<Split> temp;
1280     for (size_t i = splits.size(); i > 0; i--) {
1281         bool is_removed = false;
1282         if (splits[i - 1].is_fuse()) {
1283             debug(4) << "    checking fuse " << splits[i - 1].inner << " and "
1284                      << splits[i - 1].inner << " into " << splits[i - 1].old_var << "\n";
1285             if (splits[i - 1].inner == old_name ||
1286                 splits[i - 1].outer == old_name) {
1287                 user_error
1288                     << "In schedule for " << name()
1289                     << ", can't remove variable " << old_name
1290                     << " because it has already been fused into "
1291                     << splits[i - 1].old_var << "\n"
1292                     << dump_argument_list();
1293             }
1294             if (should_remove(splits[i - 1].old_var)) {
1295                 is_removed = true;
1296                 removed_vars.insert(splits[i - 1].outer);
1297                 removed_vars.insert(splits[i - 1].inner);
1298             }
1299         } else if (splits[i - 1].is_split()) {
1300             debug(4) << "    splitting " << splits[i - 1].old_var << " into "
1301                      << splits[i - 1].outer << " and " << splits[i - 1].inner << "\n";
1302             if (should_remove(splits[i - 1].inner)) {
1303                 is_removed = true;
1304                 removed_vars.insert(splits[i - 1].old_var);
1305             } else if (should_remove(splits[i - 1].outer)) {
1306                 is_removed = true;
1307                 removed_vars.insert(splits[i - 1].old_var);
1308             }
1309             if (splits[i - 1].old_var == old_name) {
1310                 user_error
1311                     << "In schedule for " << name()
1312                     << ", can't remove a variable " << old_name
1313                     << " because it has already been renamed or split.\n"
1314                     << dump_argument_list();
1315             }
1316         } else {
1317             debug(4) << "    replace/rename " << splits[i - 1].old_var
1318                      << " into " << splits[i - 1].outer << "\n";
1319             if (should_remove(splits[i - 1].outer)) {
1320                 is_removed = true;
1321                 removed_vars.insert(splits[i - 1].old_var);
1322             }
1323             if (splits[i - 1].old_var == old_name) {
1324                 user_error
1325                     << "In schedule for " << name()
1326                     << ", can't remove a variable " << old_name
1327                     << " because it has already been renamed or split.\n"
1328                     << dump_argument_list();
1329             }
1330         }
1331         if (!is_removed) {
1332             temp.insert(temp.begin(), splits[i - 1]);
1333         }
1334     }
1335     splits.swap(temp);
1336 }
1337 
rename(const VarOrRVar & old_var,const VarOrRVar & new_var)1338 Stage &Stage::rename(const VarOrRVar &old_var, const VarOrRVar &new_var) {
1339     if (old_var.is_rvar) {
1340         user_assert(new_var.is_rvar)
1341             << "In schedule for " << name()
1342             << ", can't rename RVar " << old_var.name()
1343             << " to Var " << new_var.name() << "\n";
1344     } else {
1345         user_assert(!new_var.is_rvar)
1346             << "In schedule for " << name()
1347             << ", can't rename Var " << old_var.name()
1348             << " to RVar " << new_var.name() << "\n";
1349     }
1350 
1351     debug(4) << "In schedule for " << name() << ", rename " << old_var.name()
1352              << " to " << new_var.name() << "\n";
1353 
1354     StageSchedule &schedule = definition.schedule();
1355 
1356     // Replace the old dimension with the new dimensions in the dims list
1357     bool found = false;
1358     string old_name;
1359     vector<Dim> &dims = schedule.dims();
1360     for (size_t i = 0; (!found) && i < dims.size(); i++) {
1361         if (var_name_match(dims[i].var, old_var.name())) {
1362             found = true;
1363             old_name = dims[i].var;
1364             dims[i].var += "." + new_var.name();
1365         }
1366     }
1367 
1368     string new_name = old_name + "." + new_var.name();
1369 
1370     if (!found) {
1371         user_error
1372             << "In schedule for " << name()
1373             << ", could not find rename dimension: "
1374             << old_var.name()
1375             << "\n"
1376             << dump_argument_list();
1377     }
1378 
1379     // If possible, rewrite the split or rename that defines it.
1380     found = false;
1381     vector<Split> &splits = schedule.splits();
1382     for (size_t i = splits.size(); i > 0; i--) {
1383         if (splits[i - 1].is_fuse()) {
1384             if (splits[i - 1].inner == old_name ||
1385                 splits[i - 1].outer == old_name) {
1386                 user_error
1387                     << "In schedule for " << name()
1388                     << ", can't rename variable " << old_name
1389                     << " because it has already been fused into "
1390                     << splits[i - 1].old_var << "\n"
1391                     << dump_argument_list();
1392             }
1393             if (splits[i - 1].old_var == old_name) {
1394                 splits[i - 1].old_var = new_name;
1395                 found = true;
1396                 break;
1397             }
1398         } else {
1399             if (splits[i - 1].inner == old_name) {
1400                 splits[i - 1].inner = new_name;
1401                 found = true;
1402                 break;
1403             }
1404             if (splits[i - 1].outer == old_name) {
1405                 splits[i - 1].outer = new_name;
1406                 found = true;
1407                 break;
1408             }
1409             if (splits[i - 1].old_var == old_name) {
1410                 user_error
1411                     << "In schedule for " << name()
1412                     << ", can't rename a variable " << old_name
1413                     << " because it has already been renamed or split.\n"
1414                     << dump_argument_list();
1415             }
1416         }
1417     }
1418 
1419     if (!found) {
1420         Split split = {old_name, new_name, "", 1, old_var.is_rvar, TailStrategy::RoundUp, Split::RenameVar};
1421         definition.schedule().splits().push_back(split);
1422     }
1423 
1424     return *this;
1425 }
1426 
allow_race_conditions()1427 Stage &Stage::allow_race_conditions() {
1428     definition.schedule().allow_race_conditions() = true;
1429     return *this;
1430 }
1431 
atomic(bool override_associativity_test)1432 Stage &Stage::atomic(bool override_associativity_test) {
1433     definition.schedule().atomic() = true;
1434     definition.schedule().override_atomic_associativity_test() = override_associativity_test;
1435     return *this;
1436 }
1437 
serial(const VarOrRVar & var)1438 Stage &Stage::serial(const VarOrRVar &var) {
1439     set_dim_type(var, ForType::Serial);
1440     return *this;
1441 }
1442 
parallel(const VarOrRVar & var)1443 Stage &Stage::parallel(const VarOrRVar &var) {
1444     set_dim_type(var, ForType::Parallel);
1445     return *this;
1446 }
1447 
vectorize(const VarOrRVar & var)1448 Stage &Stage::vectorize(const VarOrRVar &var) {
1449     set_dim_type(var, ForType::Vectorized);
1450     return *this;
1451 }
1452 
unroll(const VarOrRVar & var)1453 Stage &Stage::unroll(const VarOrRVar &var) {
1454     set_dim_type(var, ForType::Unrolled);
1455     return *this;
1456 }
1457 
parallel(const VarOrRVar & var,const Expr & factor,TailStrategy tail)1458 Stage &Stage::parallel(const VarOrRVar &var, const Expr &factor, TailStrategy tail) {
1459     if (var.is_rvar) {
1460         RVar tmp;
1461         split(var.rvar, var.rvar, tmp, factor, tail);
1462     } else {
1463         Var tmp;
1464         split(var.var, var.var, tmp, factor, tail);
1465     }
1466     parallel(var);
1467     return *this;
1468 }
1469 
vectorize(const VarOrRVar & var,const Expr & factor,TailStrategy tail)1470 Stage &Stage::vectorize(const VarOrRVar &var, const Expr &factor, TailStrategy tail) {
1471     if (var.is_rvar) {
1472         RVar tmp;
1473         split(var.rvar, var.rvar, tmp, factor, tail);
1474         vectorize(tmp);
1475     } else {
1476         Var tmp;
1477         split(var.var, var.var, tmp, factor, tail);
1478         vectorize(tmp);
1479     }
1480     return *this;
1481 }
1482 
unroll(const VarOrRVar & var,const Expr & factor,TailStrategy tail)1483 Stage &Stage::unroll(const VarOrRVar &var, const Expr &factor, TailStrategy tail) {
1484     if (var.is_rvar) {
1485         RVar tmp;
1486         split(var.rvar, var.rvar, tmp, factor, tail);
1487         unroll(tmp);
1488     } else {
1489         Var tmp;
1490         split(var.var, var.var, tmp, factor, tail);
1491         unroll(tmp);
1492     }
1493 
1494     return *this;
1495 }
1496 
tile(const VarOrRVar & x,const VarOrRVar & y,const VarOrRVar & xo,const VarOrRVar & yo,const VarOrRVar & xi,const VarOrRVar & yi,const Expr & xfactor,const Expr & yfactor,TailStrategy tail)1497 Stage &Stage::tile(const VarOrRVar &x, const VarOrRVar &y,
1498                    const VarOrRVar &xo, const VarOrRVar &yo,
1499                    const VarOrRVar &xi, const VarOrRVar &yi,
1500                    const Expr &xfactor, const Expr &yfactor,
1501                    TailStrategy tail) {
1502     split(x, xo, xi, xfactor, tail);
1503     split(y, yo, yi, yfactor, tail);
1504     reorder(xi, yi, xo, yo);
1505     return *this;
1506 }
1507 
tile(const VarOrRVar & x,const VarOrRVar & y,const VarOrRVar & xi,const VarOrRVar & yi,const Expr & xfactor,const Expr & yfactor,TailStrategy tail)1508 Stage &Stage::tile(const VarOrRVar &x, const VarOrRVar &y,
1509                    const VarOrRVar &xi, const VarOrRVar &yi,
1510                    const Expr &xfactor, const Expr &yfactor,
1511                    TailStrategy tail) {
1512     split(x, x, xi, xfactor, tail);
1513     split(y, y, yi, yfactor, tail);
1514     reorder(xi, yi, x, y);
1515     return *this;
1516 }
1517 
tile(const std::vector<VarOrRVar> & previous,const std::vector<VarOrRVar> & outers,const std::vector<VarOrRVar> & inners,const std::vector<Expr> & factors,const std::vector<TailStrategy> & tails)1518 Stage &Stage::tile(const std::vector<VarOrRVar> &previous,
1519                    const std::vector<VarOrRVar> &outers,
1520                    const std::vector<VarOrRVar> &inners,
1521                    const std::vector<Expr> &factors,
1522                    const std::vector<TailStrategy> &tails) {
1523     if (previous.size() != outers.size() || previous.size() != inners.size() || previous.size() != factors.size() || previous.size() != tails.size())
1524         user_error << "Vectors passed to Stage::tile must all be the same length.\n";
1525     for (unsigned int i = 0; i < previous.size(); i++) {
1526         split(previous[i], outers[i], inners[i], factors[i], tails[i]);
1527     }
1528     std::vector<VarOrRVar> new_order;
1529     new_order.insert(new_order.end(), inners.begin(), inners.end());
1530     new_order.insert(new_order.end(), outers.begin(), outers.end());
1531     reorder(new_order);
1532     return *this;
1533 }
1534 
tile(const std::vector<VarOrRVar> & previous,const std::vector<VarOrRVar> & outers,const std::vector<VarOrRVar> & inners,const std::vector<Expr> & factors,TailStrategy tail)1535 Stage &Stage::tile(const std::vector<VarOrRVar> &previous,
1536                    const std::vector<VarOrRVar> &outers,
1537                    const std::vector<VarOrRVar> &inners,
1538                    const std::vector<Expr> &factors,
1539                    TailStrategy tail) {
1540     std::vector<TailStrategy> tails;
1541     for (unsigned int i = 0; i < previous.size(); i++) {
1542         tails.push_back(tail);
1543     }
1544     return tile(previous, outers, inners, factors, tails);
1545 }
1546 
tile(const std::vector<VarOrRVar> & previous,const std::vector<VarOrRVar> & inners,const std::vector<Expr> & factors,TailStrategy tail)1547 Stage &Stage::tile(const std::vector<VarOrRVar> &previous,
1548                    const std::vector<VarOrRVar> &inners,
1549                    const std::vector<Expr> &factors,
1550                    TailStrategy tail) {
1551     return tile(previous, previous, inners, factors, tail);
1552 }
1553 
reorder(const std::vector<VarOrRVar> & vars)1554 Stage &Stage::reorder(const std::vector<VarOrRVar> &vars) {
1555     const string &func_name = function.name();
1556     vector<Expr> &args = definition.args();
1557     vector<Expr> &values = definition.values();
1558     vector<Dim> &dims_old = definition.schedule().dims();
1559     vector<Dim> dims = dims_old;
1560 
1561     // Tag all the vars with their locations in the dims list.
1562     vector<size_t> idx(vars.size());
1563     for (size_t i = 0; i < vars.size(); i++) {
1564         bool found = false;
1565         for (size_t j = 0; j < dims.size(); j++) {
1566             if (var_name_match(dims[j].var, vars[i].name())) {
1567                 idx[i] = j;
1568                 found = true;
1569             }
1570         }
1571         user_assert(found)
1572             << "In schedule for " << name()
1573             << ", could not find var " << vars[i].name()
1574             << " to reorder in the argument list.\n"
1575             << dump_argument_list();
1576         // Check for duplicates
1577         for (size_t j = 0; j < i; j++) {
1578             user_assert(idx[i] != idx[j])
1579                 << "In schedule for " << name()
1580                 << ", call to reorder references " << vars[i].name()
1581                 << " twice.\n";
1582         }
1583     }
1584 
1585     // It is illegal to reorder RVars if the stage is not associative
1586     // or not commutative. Look for RVar reorderings and try to do the
1587     // necessary proof if any are found.
1588     bool associativity_proven = false;
1589     for (size_t i = 0; !associativity_proven && i < idx.size(); i++) {
1590         if (!dims[idx[i]].is_pure()) {
1591             for (size_t j = i + 1; !associativity_proven && j < idx.size(); j++) {
1592                 if (!dims[idx[j]].is_pure() && (idx[i] > idx[j])) {
1593                     // Generate an error if the operator is not both associative and commutative.
1594                     const auto &prover_result = prove_associativity(func_name, args, values);
1595                     associativity_proven = prover_result.associative() &&
1596                                            prover_result.commutative();
1597                     if (!associativity_proven) {
1598                         user_error
1599                             << "In schedule for " << name()
1600                             << ", can't reorder RVars " << vars[i].name()
1601                             << " and " << vars[j].name()
1602                             << " because it may change the meaning of the "
1603                             << "algorithm.\n";
1604                     }
1605                 }
1606             }
1607         }
1608     }
1609 
1610     // Sort idx to get the new locations
1611     vector<size_t> sorted = idx;
1612     std::sort(sorted.begin(), sorted.end());
1613 
1614     for (size_t i = 0; i < vars.size(); i++) {
1615         dims[sorted[i]] = dims_old[idx[i]];
1616     }
1617 
1618     dims_old.swap(dims);
1619 
1620     return *this;
1621 }
1622 
gpu_threads(const VarOrRVar & tx,DeviceAPI device_api)1623 Stage &Stage::gpu_threads(const VarOrRVar &tx, DeviceAPI device_api) {
1624     set_dim_device_api(tx, device_api);
1625     set_dim_type(tx, ForType::GPUThread);
1626     return *this;
1627 }
1628 
gpu_threads(const VarOrRVar & tx,const VarOrRVar & ty,DeviceAPI device_api)1629 Stage &Stage::gpu_threads(const VarOrRVar &tx, const VarOrRVar &ty, DeviceAPI device_api) {
1630     set_dim_device_api(tx, device_api);
1631     set_dim_device_api(ty, device_api);
1632     set_dim_type(tx, ForType::GPUThread);
1633     set_dim_type(ty, ForType::GPUThread);
1634     return *this;
1635 }
1636 
gpu_threads(const VarOrRVar & tx,const VarOrRVar & ty,const VarOrRVar & tz,DeviceAPI device_api)1637 Stage &Stage::gpu_threads(const VarOrRVar &tx, const VarOrRVar &ty, const VarOrRVar &tz, DeviceAPI device_api) {
1638     set_dim_device_api(tx, device_api);
1639     set_dim_device_api(ty, device_api);
1640     set_dim_device_api(tz, device_api);
1641     set_dim_type(tx, ForType::GPUThread);
1642     set_dim_type(ty, ForType::GPUThread);
1643     set_dim_type(tz, ForType::GPUThread);
1644     return *this;
1645 }
1646 
gpu_lanes(const VarOrRVar & tx,DeviceAPI device_api)1647 Stage &Stage::gpu_lanes(const VarOrRVar &tx, DeviceAPI device_api) {
1648     set_dim_device_api(tx, device_api);
1649     set_dim_type(tx, ForType::GPULane);
1650     return *this;
1651 }
1652 
gpu_blocks(const VarOrRVar & bx,DeviceAPI device_api)1653 Stage &Stage::gpu_blocks(const VarOrRVar &bx, DeviceAPI device_api) {
1654     set_dim_device_api(bx, device_api);
1655     set_dim_type(bx, ForType::GPUBlock);
1656     return *this;
1657 }
1658 
gpu_blocks(const VarOrRVar & bx,const VarOrRVar & by,DeviceAPI device_api)1659 Stage &Stage::gpu_blocks(const VarOrRVar &bx, const VarOrRVar &by, DeviceAPI device_api) {
1660     set_dim_device_api(bx, device_api);
1661     set_dim_device_api(by, device_api);
1662     set_dim_type(bx, ForType::GPUBlock);
1663     set_dim_type(by, ForType::GPUBlock);
1664     return *this;
1665 }
1666 
gpu_blocks(const VarOrRVar & bx,const VarOrRVar & by,const VarOrRVar & bz,DeviceAPI device_api)1667 Stage &Stage::gpu_blocks(const VarOrRVar &bx, const VarOrRVar &by, const VarOrRVar &bz, DeviceAPI device_api) {
1668     set_dim_device_api(bx, device_api);
1669     set_dim_device_api(by, device_api);
1670     set_dim_device_api(bz, device_api);
1671     set_dim_type(bx, ForType::GPUBlock);
1672     set_dim_type(by, ForType::GPUBlock);
1673     set_dim_type(bz, ForType::GPUBlock);
1674     return *this;
1675 }
1676 
gpu_single_thread(DeviceAPI device_api)1677 Stage &Stage::gpu_single_thread(DeviceAPI device_api) {
1678     Var block, thread;
1679     split(Var::outermost(), Var::outermost(), thread, 1);
1680     split(Var::outermost(), Var::outermost(), block, 1);
1681     gpu_blocks(block, device_api);
1682     gpu_threads(thread, device_api);
1683     return *this;
1684 }
1685 
gpu(const VarOrRVar & bx,const VarOrRVar & tx,DeviceAPI device_api)1686 Stage &Stage::gpu(const VarOrRVar &bx, const VarOrRVar &tx, DeviceAPI device_api) {
1687     return gpu_blocks(bx).gpu_threads(tx);
1688 }
1689 
gpu(const VarOrRVar & bx,const VarOrRVar & by,const VarOrRVar & tx,const VarOrRVar & ty,DeviceAPI device_api)1690 Stage &Stage::gpu(const VarOrRVar &bx, const VarOrRVar &by,
1691                   const VarOrRVar &tx, const VarOrRVar &ty, DeviceAPI device_api) {
1692     return gpu_blocks(bx, by).gpu_threads(tx, ty);
1693 }
1694 
gpu(const VarOrRVar & bx,const VarOrRVar & by,const VarOrRVar & bz,const VarOrRVar & tx,const VarOrRVar & ty,const VarOrRVar & tz,DeviceAPI device_api)1695 Stage &Stage::gpu(const VarOrRVar &bx, const VarOrRVar &by, const VarOrRVar &bz,
1696                   const VarOrRVar &tx, const VarOrRVar &ty, const VarOrRVar &tz,
1697                   DeviceAPI device_api) {
1698     return gpu_blocks(bx, by, bz).gpu_threads(tx, ty, tz);
1699 }
1700 
gpu_tile(const VarOrRVar & x,const VarOrRVar & bx,const VarOrRVar & tx,const Expr & x_size,TailStrategy tail,DeviceAPI device_api)1701 Stage &Stage::gpu_tile(const VarOrRVar &x, const VarOrRVar &bx, const VarOrRVar &tx, const Expr &x_size,
1702                        TailStrategy tail, DeviceAPI device_api) {
1703     split(x, bx, tx, x_size, tail);
1704     set_dim_device_api(bx, device_api);
1705     set_dim_device_api(tx, device_api);
1706     set_dim_type(bx, ForType::GPUBlock);
1707     set_dim_type(tx, ForType::GPUThread);
1708     return *this;
1709 }
1710 
gpu_tile(const VarOrRVar & x,const VarOrRVar & tx,const Expr & x_size,TailStrategy tail,DeviceAPI device_api)1711 Stage &Stage::gpu_tile(const VarOrRVar &x, const VarOrRVar &tx, const Expr &x_size,
1712                        TailStrategy tail, DeviceAPI device_api) {
1713     split(x, x, tx, x_size, tail);
1714     set_dim_device_api(x, device_api);
1715     set_dim_device_api(tx, device_api);
1716     set_dim_type(x, ForType::GPUBlock);
1717     set_dim_type(tx, ForType::GPUThread);
1718     return *this;
1719 }
1720 
gpu_tile(const VarOrRVar & x,const VarOrRVar & y,const VarOrRVar & bx,const VarOrRVar & by,const VarOrRVar & tx,const VarOrRVar & ty,const Expr & x_size,const Expr & y_size,TailStrategy tail,DeviceAPI device_api)1721 Stage &Stage::gpu_tile(const VarOrRVar &x, const VarOrRVar &y,
1722                        const VarOrRVar &bx, const VarOrRVar &by,
1723                        const VarOrRVar &tx, const VarOrRVar &ty,
1724                        const Expr &x_size, const Expr &y_size,
1725                        TailStrategy tail,
1726                        DeviceAPI device_api) {
1727     tile(x, y, bx, by, tx, ty, x_size, y_size, tail);
1728     set_dim_device_api(bx, device_api);
1729     set_dim_device_api(by, device_api);
1730     set_dim_device_api(tx, device_api);
1731     set_dim_device_api(ty, device_api);
1732     set_dim_type(bx, ForType::GPUBlock);
1733     set_dim_type(by, ForType::GPUBlock);
1734     set_dim_type(tx, ForType::GPUThread);
1735     set_dim_type(ty, ForType::GPUThread);
1736     return *this;
1737 }
1738 
gpu_tile(const VarOrRVar & x,const VarOrRVar & y,const VarOrRVar & tx,const VarOrRVar & ty,const Expr & x_size,const Expr & y_size,TailStrategy tail,DeviceAPI device_api)1739 Stage &Stage::gpu_tile(const VarOrRVar &x, const VarOrRVar &y,
1740                        const VarOrRVar &tx, const VarOrRVar &ty,
1741                        const Expr &x_size, const Expr &y_size,
1742                        TailStrategy tail,
1743                        DeviceAPI device_api) {
1744     return gpu_tile(x, y, x, y, tx, ty, x_size, y_size, tail, device_api);
1745 }
1746 
gpu_tile(const VarOrRVar & x,const VarOrRVar & y,const VarOrRVar & z,const VarOrRVar & bx,const VarOrRVar & by,const VarOrRVar & bz,const VarOrRVar & tx,const VarOrRVar & ty,const VarOrRVar & tz,const Expr & x_size,const Expr & y_size,const Expr & z_size,TailStrategy tail,DeviceAPI device_api)1747 Stage &Stage::gpu_tile(const VarOrRVar &x, const VarOrRVar &y, const VarOrRVar &z,
1748                        const VarOrRVar &bx, const VarOrRVar &by, const VarOrRVar &bz,
1749                        const VarOrRVar &tx, const VarOrRVar &ty, const VarOrRVar &tz,
1750                        const Expr &x_size, const Expr &y_size, const Expr &z_size,
1751                        TailStrategy tail,
1752                        DeviceAPI device_api) {
1753     split(x, bx, tx, x_size, tail);
1754     split(y, by, ty, y_size, tail);
1755     split(z, bz, tz, z_size, tail);
1756     // current order is:
1757     // tx bx ty by tz bz
1758     reorder(ty, bx);
1759     // tx ty bx by tz bz
1760     reorder(tz, bx);
1761     // tx ty tz by bx bz
1762     reorder(bx, by);
1763     // tx ty tz bx by bz
1764     set_dim_device_api(bx, device_api);
1765     set_dim_device_api(by, device_api);
1766     set_dim_device_api(bz, device_api);
1767     set_dim_device_api(tx, device_api);
1768     set_dim_device_api(ty, device_api);
1769     set_dim_device_api(tz, device_api);
1770 
1771     set_dim_type(bx, ForType::GPUBlock);
1772     set_dim_type(by, ForType::GPUBlock);
1773     set_dim_type(bz, ForType::GPUBlock);
1774     set_dim_type(tx, ForType::GPUThread);
1775     set_dim_type(ty, ForType::GPUThread);
1776     set_dim_type(tz, ForType::GPUThread);
1777     return *this;
1778 }
1779 
gpu_tile(const VarOrRVar & x,const VarOrRVar & y,const VarOrRVar & z,const VarOrRVar & tx,const VarOrRVar & ty,const VarOrRVar & tz,const Expr & x_size,const Expr & y_size,const Expr & z_size,TailStrategy tail,DeviceAPI device_api)1780 Stage &Stage::gpu_tile(const VarOrRVar &x, const VarOrRVar &y, const VarOrRVar &z,
1781                        const VarOrRVar &tx, const VarOrRVar &ty, const VarOrRVar &tz,
1782                        const Expr &x_size, const Expr &y_size, const Expr &z_size,
1783                        TailStrategy tail,
1784                        DeviceAPI device_api) {
1785     return gpu_tile(x, y, z, x, y, z, tx, ty, tz, x_size, y_size, z_size, tail, device_api);
1786 }
1787 
hexagon(const VarOrRVar & x)1788 Stage &Stage::hexagon(const VarOrRVar &x) {
1789     set_dim_device_api(x, DeviceAPI::Hexagon);
1790     return *this;
1791 }
1792 
prefetch(const Func & f,const VarOrRVar & var,Expr offset,PrefetchBoundStrategy strategy)1793 Stage &Stage::prefetch(const Func &f, const VarOrRVar &var, Expr offset, PrefetchBoundStrategy strategy) {
1794     PrefetchDirective prefetch = {f.name(), var.name(), std::move(offset), strategy, Parameter()};
1795     definition.schedule().prefetches().push_back(prefetch);
1796     return *this;
1797 }
1798 
prefetch(const Internal::Parameter & param,const VarOrRVar & var,Expr offset,PrefetchBoundStrategy strategy)1799 Stage &Stage::prefetch(const Internal::Parameter &param, const VarOrRVar &var, Expr offset, PrefetchBoundStrategy strategy) {
1800     PrefetchDirective prefetch = {param.name(), var.name(), std::move(offset), strategy, param};
1801     definition.schedule().prefetches().push_back(prefetch);
1802     return *this;
1803 }
1804 
compute_with(LoopLevel loop_level,const map<string,LoopAlignStrategy> & align)1805 Stage &Stage::compute_with(LoopLevel loop_level, const map<string, LoopAlignStrategy> &align) {
1806     loop_level.lock();
1807     user_assert(!loop_level.is_inlined() && !loop_level.is_root())
1808         << "Undefined loop level to compute with\n";
1809     user_assert(loop_level.func() != function.name())
1810         << "Cannot schedule " << name() << " to be computed with "
1811         << loop_level.to_string() << "\n";
1812     user_assert(!function.has_extern_definition())
1813         << "compute_with() on extern Func " << name() << " is not allowed\n";
1814 
1815     // We have to mark the fuse level on the "original" definition (the one
1816     // without the specialization) to ensure there is no competing compute_with.
1817     Definition &original_def = (stage_index == 0) ? function.definition() : function.update(stage_index - 1);
1818     user_assert(original_def.specializations().empty())
1819         << "Func " << name() << " is scheduled to be computed with "
1820         << loop_level.func() << ", so it must not have any specializations.\n";
1821 
1822     FuseLoopLevel &fuse_level = original_def.schedule().fuse_level();
1823     if (!fuse_level.level.lock().is_inlined()) {
1824         user_warning << name() << " already has a compute_with at " << fuse_level.level.to_string()
1825                      << ". Replacing it with a new compute_with at " << loop_level.to_string() << "\n";
1826     }
1827     fuse_level.level = loop_level;
1828     fuse_level.align = align;
1829     return *this;
1830 }
1831 
compute_with(LoopLevel loop_level,const vector<pair<VarOrRVar,LoopAlignStrategy>> & align)1832 Stage &Stage::compute_with(LoopLevel loop_level, const vector<pair<VarOrRVar, LoopAlignStrategy>> &align) {
1833     map<string, LoopAlignStrategy> align_str;
1834     for (const auto &iter : align) {
1835         align_str.emplace(iter.first.name(), iter.second);
1836     }
1837     return compute_with(std::move(loop_level), align_str);
1838 }
1839 
compute_with(LoopLevel loop_level,LoopAlignStrategy align)1840 Stage &Stage::compute_with(LoopLevel loop_level, LoopAlignStrategy align) {
1841     map<string, LoopAlignStrategy> align_str = {{loop_level.lock().var().name(), align}};
1842     return compute_with(loop_level, align_str);
1843 }
1844 
/** compute_with overload addressed by stage and variable. Builds the loop
 * level from s's function, var and s's stage index, then forwards to the
 * LoopLevel-based overload with the per-variable alignment list. */
Stage &Stage::compute_with(const Stage &s, const VarOrRVar &var, const vector<pair<VarOrRVar, LoopAlignStrategy>> &align) {
    LoopLevel target_level(s.function, var, s.stage_index);
    return compute_with(std::move(target_level), align);
}
1848 
/** compute_with overload addressed by stage and variable. Builds the loop
 * level from s's function, var and s's stage index, then forwards to the
 * LoopLevel-based overload with a single alignment strategy. */
Stage &Stage::compute_with(const Stage &s, const VarOrRVar &var, LoopAlignStrategy align) {
    LoopLevel target_level(s.function, var, s.stage_index);
    return compute_with(std::move(target_level), align);
}
1852 
1853 /** Attempt to get the source file and line where this stage was
1854  * defined by parsing the process's own debug symbols. Returns an
1855  * empty string if no debug symbols were found or the debug
1856  * symbols were not understood. Works on OS X and Linux only. */
source_location() const1857 std::string Stage::source_location() const {
1858     return definition.source_location();
1859 }
1860 
invalidate_cache()1861 void Func::invalidate_cache() {
1862     if (pipeline_.defined()) {
1863         pipeline_.invalidate_cache();
1864     }
1865 }
1866 
1867 namespace {
1868 
validate_wrapper(const string & name,const map<string,FunctionPtr> & wrappers,const vector<Func> & fs,const FunctionPtr & wrapper)1869 void validate_wrapper(const string &name, const map<string, FunctionPtr> &wrappers,
1870                       const vector<Func> &fs, const FunctionPtr &wrapper) {
1871     if (!wrappers.empty() && !fs.empty()) {
1872         internal_assert(wrapper.defined() && !name.empty());
1873         // Make sure all the other Funcs in 'fs' share the same wrapper and no
1874         // other Func not in 'fs' share the same wrapper
1875         for (const auto &it : wrappers) {
1876             if (it.first == fs[0].name()) {
1877                 continue;
1878             }
1879             const auto &fs_iter = std::find_if(
1880                 fs.begin(), fs.end(), [&it](const Func &f) { return f.name() == it.first; });
1881             bool in_fs = fs_iter != fs.end();
1882 
1883             if (in_fs) {
1884                 user_assert(it.second.same_as(wrapper))
1885                     << it.first << " should have shared the same wrapper as " << fs[0].name() << "\n";
1886             } else {
1887                 user_assert(!it.second.same_as(wrapper))
1888                     << "Redefinition of shared wrapper [" << name << " -> "
1889                     << Function(wrapper).name() << "] in " << fs[0].name() << " is illegal since "
1890                     << it.first << " shares the same wrapper but is not part of the redefinition\n";
1891             }
1892         }
1893     }
1894 }
1895 
create_in_wrapper(Function wrapped_fn,const string & wrapper_name)1896 Func create_in_wrapper(Function wrapped_fn, const string &wrapper_name) {
1897     Func wrapper(wrapped_fn.new_function_in_same_group(wrapper_name));
1898     vector<Var> args = Func(wrapped_fn).args();
1899     wrapper(args) = Func(wrapped_fn)(args);
1900     return wrapper;
1901 }
1902 
create_clone_wrapper(Function wrapped_fn,const string & wrapper_name)1903 Func create_clone_wrapper(Function wrapped_fn, const string &wrapper_name) {
1904     Func wrapper(wrapped_fn.new_function_in_same_group(wrapper_name));
1905     std::map<FunctionPtr, FunctionPtr> remapping;
1906     wrapped_fn.deep_copy(wrapper.name(), wrapper.function().get_contents(), remapping);
1907     // Fix up any self-references in the clone.
1908     FunctionPtr self_reference = wrapper.function().get_contents();
1909     self_reference.weaken();
1910     remapping.emplace(wrapped_fn.get_contents(), self_reference);
1911     wrapper.function().substitute_calls(remapping);
1912     return wrapper;
1913 }
1914 
get_wrapper(Function wrapped_fn,string wrapper_name,const vector<Func> & fs,bool clone)1915 Func get_wrapper(Function wrapped_fn, string wrapper_name, const vector<Func> &fs, bool clone) {
1916     // Either all Funcs in 'fs' have the same wrapper or they don't already
1917     // have any wrappers. Otherwise, throw an error. If 'fs' is empty, then
1918     // it is a global wrapper.
1919     const map<string, FunctionPtr> &wrappers = wrapped_fn.wrappers();
1920     wrapper_name += ("$" + std::to_string(wrappers.size()));
1921     const auto &iter = fs.empty() ? wrappers.find("") : wrappers.find(fs[0].name());
1922     if (iter == wrappers.end()) {
1923         // Make sure the other Funcs also don't have any wrappers
1924         for (size_t i = 1; i < fs.size(); ++i) {
1925             user_assert(wrappers.count(fs[i].name()) == 0)
1926                 << "Cannot define the wrapper since " << fs[i].name()
1927                 << " already has a wrapper while " << fs[0].name() << " doesn't \n";
1928         }
1929         Func wrapper = clone ? create_clone_wrapper(wrapped_fn, wrapper_name) : create_in_wrapper(wrapped_fn, wrapper_name);
1930         Function wrapper_fn = wrapper.function();
1931         if (fs.empty()) {
1932             // Add global wrapper
1933             wrapped_fn.add_wrapper("", wrapper_fn);
1934         } else {
1935             for (const Func &f : fs) {
1936                 user_assert(wrapped_fn.name() != f.name())
1937                     << "Cannot create wrapper of itself (\"" << wrapped_fn.name() << "\")\n";
1938                 wrapped_fn.add_wrapper(f.name(), wrapper_fn);
1939             }
1940         }
1941         return wrapper;
1942     }
1943     internal_assert(iter->second.defined());
1944     validate_wrapper(wrapped_fn.name(), wrappers, fs, iter->second);
1945 
1946     Function wrapper(iter->second);
1947     internal_assert(wrapper.frozen());
1948     return Func(wrapper);
1949 }
1950 
1951 }  // anonymous namespace
1952 
/** Get (or create) the wrapper of this Func used by calls from f.
 * A new wrapper is named "<this>_in_<f>" plus a "$<n>" suffix added by
 * get_wrapper. Invalidates the pipeline cache first. */
Func Func::in(const Func &f) {
    invalidate_cache();
    const vector<Func> consumers(1, f);
    const string wrapper_name = name() + "_in_" + f.name();
    return get_wrapper(func, wrapper_name, consumers, false);
}
1958 
in(const vector<Func> & fs)1959 Func Func::in(const vector<Func> &fs) {
1960     if (fs.empty()) {
1961         user_error << "Could not create a in wrapper for an empty list of Funcs\n";
1962     }
1963     invalidate_cache();
1964     return get_wrapper(func, name() + "_wrapper", fs, false);
1965 }
1966 
in()1967 Func Func::in() {
1968     invalidate_cache();
1969     return get_wrapper(func, name() + "_global_wrapper", {}, false);
1970 }
1971 
/** Get (or create) the clone of this Func used by calls from f.
 * A new clone is named "<this>_clone_in_<f>" plus a "$<n>" suffix added
 * by get_wrapper. Invalidates the pipeline cache first. */
Func Func::clone_in(const Func &f) {
    invalidate_cache();
    const vector<Func> consumers(1, f);
    const string wrapper_name = name() + "_clone_in_" + f.name();
    return get_wrapper(func, wrapper_name, consumers, true);
}
1977 
clone_in(const vector<Func> & fs)1978 Func Func::clone_in(const vector<Func> &fs) {
1979     if (fs.empty()) {
1980         user_error << "Could not create a clone wrapper for an empty list of Funcs\n";
1981     }
1982     invalidate_cache();
1983     return get_wrapper(func, name() + "_clone", fs, true);
1984 }
1985 
/** Replace this Func's pure definition with an extern definition that
 * calls "halide_buffer_copy" on the single Func/image/param it wraps,
 * targeting device API d.
 *
 * The Func must be defined, single-valued, have no update definition,
 * not already be extern, and func.is_wrapper() must return a Call;
 * each violation is a user error.
 *
 * Returns a copy of this Func (the return type is Func, not Func &). */
Func Func::copy_to_device(DeviceAPI d) {
    user_assert(defined())
        << "copy_to_device on Func " << name() << " with no definition\n";
    user_assert(outputs() == 1)
        << "copy_to_device on a Tuple-valued Func " << name() << " not yet supported\n";
    user_assert(!has_update_definition())
        << "copy_to_device on Func " << name() << " with update definition\n";
    user_assert(!is_extern())
        << "copy_to_device on Func " << name() << " with extern definition\n";

    // The Call node on the RHS identifies the source buffer of the copy.
    const Call *call = func.is_wrapper();
    user_assert(call)
        << "Func " << name() << " is scheduled as copy_to_host/device, "
        << "but has value: " << value() << "\n"
        << "Expected a single call to another Func with matching "
        << "dimensionality and argument order.\n";

    // Move the RHS value to the proxy slot
    // (this must happen before the definition is reset below, since
    // value() reads from it).
    func.extern_definition_proxy_expr() = value();

    // ... and delete the pure definition
    // NOTE(review): `call` is still dereferenced after this point;
    // presumably the proxy expr stored above keeps the Call node alive
    // -- confirm.
    func.definition() = Definition();

    // Pick the source of the copy: a Halide Func, a concrete image, or
    // (asserted) a parameter.
    ExternFuncArgument buffer;
    if (call->call_type == Call::Halide) {
        buffer = call->func;
    } else if (call->image.defined()) {
        buffer = call->image;
    } else {
        internal_assert(call->param.defined());
        buffer = call->param;
    }

    ExternFuncArgument device_interface = make_device_interface_call(d);
    func.define_extern("halide_buffer_copy", {buffer, device_interface},
                       {call->type}, args(),  // Reuse the existing dimension names
                       NameMangling::C, d);
    return *this;
}
2025 
copy_to_host()2026 Func Func::copy_to_host() {
2027     user_assert(defined())
2028         << "copy_to_host on Func " << name() << " with no definition\n";
2029     user_assert(outputs() == 1)
2030         << "copy_to_host on a Tuple-valued Func " << name() << " not yet supported\n";
2031     user_assert(!has_update_definition())
2032         << "copy_to_host on Func " << name() << " with update definition\n";
2033     user_assert(!is_extern())
2034         << "copy_to_host on Func " << name() << " with extern definition\n";
2035     return copy_to_device(DeviceAPI::Host);
2036 }
2037 
split(const VarOrRVar & old,const VarOrRVar & outer,const VarOrRVar & inner,const Expr & factor,TailStrategy tail)2038 Func &Func::split(const VarOrRVar &old, const VarOrRVar &outer, const VarOrRVar &inner, const Expr &factor, TailStrategy tail) {
2039     invalidate_cache();
2040     Stage(func, func.definition(), 0).split(old, outer, inner, factor, tail);
2041     return *this;
2042 }
2043 
fuse(const VarOrRVar & inner,const VarOrRVar & outer,const VarOrRVar & fused)2044 Func &Func::fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRVar &fused) {
2045     invalidate_cache();
2046     Stage(func, func.definition(), 0).fuse(inner, outer, fused);
2047     return *this;
2048 }
2049 
rename(const VarOrRVar & old_name,const VarOrRVar & new_name)2050 Func &Func::rename(const VarOrRVar &old_name, const VarOrRVar &new_name) {
2051     invalidate_cache();
2052     Stage(func, func.definition(), 0).rename(old_name, new_name);
2053     return *this;
2054 }
2055 
allow_race_conditions()2056 Func &Func::allow_race_conditions() {
2057     invalidate_cache();
2058     Stage(func, func.definition(), 0).allow_race_conditions();
2059     return *this;
2060 }
2061 
atomic(bool override_associativity_test)2062 Func &Func::atomic(bool override_associativity_test) {
2063     invalidate_cache();
2064     Stage(func, func.definition(), 0).atomic(override_associativity_test);
2065     return *this;
2066 }
2067 
memoize()2068 Func &Func::memoize() {
2069     invalidate_cache();
2070     func.schedule().memoized() = true;
2071     return *this;
2072 }
2073 
store_in(MemoryType t)2074 Func &Func::store_in(MemoryType t) {
2075     invalidate_cache();
2076     func.schedule().memory_type() = t;
2077     return *this;
2078 }
2079 
async()2080 Func &Func::async() {
2081     invalidate_cache();
2082     func.schedule().async() = true;
2083     return *this;
2084 }
2085 
specialize(const Expr & c)2086 Stage Func::specialize(const Expr &c) {
2087     invalidate_cache();
2088     return Stage(func, func.definition(), 0).specialize(c);
2089 }
2090 
specialize_fail(const std::string & message)2091 void Func::specialize_fail(const std::string &message) {
2092     invalidate_cache();
2093     (void)Stage(func, func.definition(), 0).specialize_fail(message);
2094 }
2095 
serial(const VarOrRVar & var)2096 Func &Func::serial(const VarOrRVar &var) {
2097     invalidate_cache();
2098     Stage(func, func.definition(), 0).serial(var);
2099     return *this;
2100 }
2101 
parallel(const VarOrRVar & var)2102 Func &Func::parallel(const VarOrRVar &var) {
2103     invalidate_cache();
2104     Stage(func, func.definition(), 0).parallel(var);
2105     return *this;
2106 }
2107 
vectorize(const VarOrRVar & var)2108 Func &Func::vectorize(const VarOrRVar &var) {
2109     invalidate_cache();
2110     Stage(func, func.definition(), 0).vectorize(var);
2111     return *this;
2112 }
2113 
unroll(const VarOrRVar & var)2114 Func &Func::unroll(const VarOrRVar &var) {
2115     invalidate_cache();
2116     Stage(func, func.definition(), 0).unroll(var);
2117     return *this;
2118 }
2119 
parallel(const VarOrRVar & var,const Expr & factor,TailStrategy tail)2120 Func &Func::parallel(const VarOrRVar &var, const Expr &factor, TailStrategy tail) {
2121     invalidate_cache();
2122     Stage(func, func.definition(), 0).parallel(var, factor, tail);
2123     return *this;
2124 }
2125 
vectorize(const VarOrRVar & var,const Expr & factor,TailStrategy tail)2126 Func &Func::vectorize(const VarOrRVar &var, const Expr &factor, TailStrategy tail) {
2127     invalidate_cache();
2128     Stage(func, func.definition(), 0).vectorize(var, factor, tail);
2129     return *this;
2130 }
2131 
unroll(const VarOrRVar & var,const Expr & factor,TailStrategy tail)2132 Func &Func::unroll(const VarOrRVar &var, const Expr &factor, TailStrategy tail) {
2133     invalidate_cache();
2134     Stage(func, func.definition(), 0).unroll(var, factor, tail);
2135     return *this;
2136 }
2137 
bound(const Var & var,Expr min,Expr extent)2138 Func &Func::bound(const Var &var, Expr min, Expr extent) {
2139     user_assert(!min.defined() || Int(32).can_represent(min.type())) << "Can't represent min bound in int32\n";
2140     user_assert(extent.defined()) << "Extent bound of a Func can't be undefined\n";
2141     user_assert(Int(32).can_represent(extent.type())) << "Can't represent extent bound in int32\n";
2142 
2143     if (min.defined()) {
2144         min = cast<int32_t>(min);
2145     }
2146     extent = cast<int32_t>(extent);
2147 
2148     invalidate_cache();
2149     bool found = func.is_pure_arg(var.name());
2150     user_assert(found)
2151         << "Can't bound variable " << var.name()
2152         << " of function " << name()
2153         << " because " << var.name()
2154         << " is not one of the pure variables of " << name() << ".\n";
2155 
2156     Bound b = {var.name(), min, extent, Expr(), Expr()};
2157     func.schedule().bounds().push_back(b);
2158 
2159     // Propagate constant bounds into estimates as well.
2160     if (!is_const(min)) min = Expr();
2161     if (!is_const(extent)) extent = Expr();
2162     set_estimate(var, min, extent);
2163 
2164     return *this;
2165 }
2166 
set_estimate(const Var & var,const Expr & min,const Expr & extent)2167 Func &Func::set_estimate(const Var &var, const Expr &min, const Expr &extent) {
2168     invalidate_cache();
2169     bool found = func.is_pure_arg(var.name());
2170     user_assert(found)
2171         << "Can't provide an estimate on variable " << var.name()
2172         << " of function " << name()
2173         << " because " << var.name()
2174         << " is not one of the pure variables of " << name() << ".\n";
2175 
2176     Bound b = {var.name(), min, extent, Expr(), Expr()};
2177     func.schedule().estimates().push_back(b);
2178 
2179     // Propagate the estimate into the Parameter as well, so that
2180     // the values in the metadata will be correct.
2181     const auto &arg_names = func.args();
2182     int dim = -1;
2183     for (size_t i = 0; i < arg_names.size(); ++i) {
2184         if (arg_names[i] == var.name()) {
2185             dim = i;
2186             break;
2187         }
2188     }
2189     internal_assert(dim >= 0);
2190     for (auto param : func.output_buffers()) {
2191         if (min.defined()) {
2192             param.set_min_constraint_estimate(dim, min);
2193         }
2194         if (extent.defined()) {
2195             param.set_extent_constraint_estimate(dim, extent);
2196         }
2197     }
2198     return *this;
2199 }
2200 
set_estimates(const Region & estimates)2201 Func &Func::set_estimates(const Region &estimates) {
2202     const std::vector<Var> a = args();
2203     user_assert(estimates.size() == a.size())
2204         << "Func " << name() << " has " << a.size() << " dimensions, "
2205         << "but the estimates passed to set_estimates contains " << estimates.size() << " pairs.\n";
2206     for (size_t i = 0; i < a.size(); i++) {
2207         const Range &r = estimates[i];
2208         set_estimate(a[i], r.min, r.extent);
2209     }
2210     return *this;
2211 }
2212 
bound_extent(const Var & var,Expr extent)2213 Func &Func::bound_extent(const Var &var, Expr extent) {
2214     return bound(var, Expr(), std::move(extent));
2215 }
2216 
align_bounds(const Var & var,Expr modulus,Expr remainder)2217 Func &Func::align_bounds(const Var &var, Expr modulus, Expr remainder) {
2218     user_assert(modulus.defined()) << "modulus is undefined\n";
2219     user_assert(remainder.defined()) << "remainder is undefined\n";
2220     user_assert(Int(32).can_represent(modulus.type())) << "Can't represent modulus as int32\n";
2221     user_assert(Int(32).can_represent(remainder.type())) << "Can't represent remainder as int32\n";
2222 
2223     modulus = cast<int32_t>(modulus);
2224     remainder = cast<int32_t>(remainder);
2225 
2226     // Reduce the remainder
2227     remainder = remainder % modulus;
2228 
2229     invalidate_cache();
2230 
2231     bool found = func.is_pure_arg(var.name());
2232     user_assert(found)
2233         << "Can't align bounds of variable " << var.name()
2234         << " of function " << name()
2235         << " because " << var.name()
2236         << " is not one of the pure variables of " << name() << ".\n";
2237 
2238     Bound b = {var.name(), Expr(), Expr(), modulus, remainder};
2239     func.schedule().bounds().push_back(b);
2240     return *this;
2241 }
2242 
tile(const VarOrRVar & x,const VarOrRVar & y,const VarOrRVar & xo,const VarOrRVar & yo,const VarOrRVar & xi,const VarOrRVar & yi,const Expr & xfactor,const Expr & yfactor,TailStrategy tail)2243 Func &Func::tile(const VarOrRVar &x, const VarOrRVar &y,
2244                  const VarOrRVar &xo, const VarOrRVar &yo,
2245                  const VarOrRVar &xi, const VarOrRVar &yi,
2246                  const Expr &xfactor, const Expr &yfactor,
2247                  TailStrategy tail) {
2248     invalidate_cache();
2249     Stage(func, func.definition(), 0).tile(x, y, xo, yo, xi, yi, xfactor, yfactor, tail);
2250     return *this;
2251 }
2252 
tile(const VarOrRVar & x,const VarOrRVar & y,const VarOrRVar & xi,const VarOrRVar & yi,const Expr & xfactor,const Expr & yfactor,TailStrategy tail)2253 Func &Func::tile(const VarOrRVar &x, const VarOrRVar &y,
2254                  const VarOrRVar &xi, const VarOrRVar &yi,
2255                  const Expr &xfactor, const Expr &yfactor,
2256                  TailStrategy tail) {
2257     invalidate_cache();
2258     Stage(func, func.definition(), 0).tile(x, y, xi, yi, xfactor, yfactor, tail);
2259     return *this;
2260 }
2261 
tile(const std::vector<VarOrRVar> & previous,const std::vector<VarOrRVar> & outers,const std::vector<VarOrRVar> & inners,const std::vector<Expr> & factors,TailStrategy tail)2262 Func &Func::tile(const std::vector<VarOrRVar> &previous,
2263                  const std::vector<VarOrRVar> &outers,
2264                  const std::vector<VarOrRVar> &inners,
2265                  const std::vector<Expr> &factors,
2266                  TailStrategy tail) {
2267     Stage(func, func.definition(), 0).tile(previous, outers, inners, factors, tail);
2268     return *this;
2269 }
2270 
tile(const std::vector<VarOrRVar> & previous,const std::vector<VarOrRVar> & inners,const std::vector<Expr> & factors,TailStrategy tail)2271 Func &Func::tile(const std::vector<VarOrRVar> &previous,
2272                  const std::vector<VarOrRVar> &inners,
2273                  const std::vector<Expr> &factors,
2274                  TailStrategy tail) {
2275     Stage(func, func.definition(), 0).tile(previous, inners, factors, tail);
2276     return *this;
2277 }
2278 
tile(const std::vector<VarOrRVar> & previous,const std::vector<VarOrRVar> & outers,const std::vector<VarOrRVar> & inners,const std::vector<Expr> & factors,const std::vector<TailStrategy> & tails)2279 Func &Func::tile(const std::vector<VarOrRVar> &previous,
2280                  const std::vector<VarOrRVar> &outers,
2281                  const std::vector<VarOrRVar> &inners,
2282                  const std::vector<Expr> &factors,
2283                  const std::vector<TailStrategy> &tails) {
2284     Stage(func, func.definition(), 0).tile(previous, outers, inners, factors, tails);
2285     return *this;
2286 }
2287 
reorder(const std::vector<VarOrRVar> & vars)2288 Func &Func::reorder(const std::vector<VarOrRVar> &vars) {
2289     invalidate_cache();
2290     Stage(func, func.definition(), 0).reorder(vars);
2291     return *this;
2292 }
2293 
gpu_threads(const VarOrRVar & tx,DeviceAPI device_api)2294 Func &Func::gpu_threads(const VarOrRVar &tx, DeviceAPI device_api) {
2295     invalidate_cache();
2296     Stage(func, func.definition(), 0).gpu_threads(tx, device_api);
2297     return *this;
2298 }
2299 
gpu_threads(const VarOrRVar & tx,const VarOrRVar & ty,DeviceAPI device_api)2300 Func &Func::gpu_threads(const VarOrRVar &tx, const VarOrRVar &ty, DeviceAPI device_api) {
2301     invalidate_cache();
2302     Stage(func, func.definition(), 0).gpu_threads(tx, ty, device_api);
2303     return *this;
2304 }
2305 
gpu_threads(const VarOrRVar & tx,const VarOrRVar & ty,const VarOrRVar & tz,DeviceAPI device_api)2306 Func &Func::gpu_threads(const VarOrRVar &tx, const VarOrRVar &ty, const VarOrRVar &tz, DeviceAPI device_api) {
2307     invalidate_cache();
2308     Stage(func, func.definition(), 0).gpu_threads(tx, ty, tz, device_api);
2309     return *this;
2310 }
2311 
gpu_lanes(const VarOrRVar & tx,DeviceAPI device_api)2312 Func &Func::gpu_lanes(const VarOrRVar &tx, DeviceAPI device_api) {
2313     invalidate_cache();
2314     Stage(func, func.definition(), 0).gpu_lanes(tx, device_api);
2315     return *this;
2316 }
2317 
gpu_blocks(const VarOrRVar & bx,DeviceAPI device_api)2318 Func &Func::gpu_blocks(const VarOrRVar &bx, DeviceAPI device_api) {
2319     invalidate_cache();
2320     Stage(func, func.definition(), 0).gpu_blocks(bx, device_api);
2321     return *this;
2322 }
2323 
gpu_blocks(const VarOrRVar & bx,const VarOrRVar & by,DeviceAPI device_api)2324 Func &Func::gpu_blocks(const VarOrRVar &bx, const VarOrRVar &by, DeviceAPI device_api) {
2325     invalidate_cache();
2326     Stage(func, func.definition(), 0).gpu_blocks(bx, by, device_api);
2327     return *this;
2328 }
2329 
gpu_blocks(const VarOrRVar & bx,const VarOrRVar & by,const VarOrRVar & bz,DeviceAPI device_api)2330 Func &Func::gpu_blocks(const VarOrRVar &bx, const VarOrRVar &by, const VarOrRVar &bz, DeviceAPI device_api) {
2331     invalidate_cache();
2332     Stage(func, func.definition(), 0).gpu_blocks(bx, by, bz, device_api);
2333     return *this;
2334 }
2335 
gpu_single_thread(DeviceAPI device_api)2336 Func &Func::gpu_single_thread(DeviceAPI device_api) {
2337     invalidate_cache();
2338     Stage(func, func.definition(), 0).gpu_single_thread(device_api);
2339     return *this;
2340 }
2341 
gpu(const VarOrRVar & bx,const VarOrRVar & tx,DeviceAPI device_api)2342 Func &Func::gpu(const VarOrRVar &bx, const VarOrRVar &tx, DeviceAPI device_api) {
2343     invalidate_cache();
2344     Stage(func, func.definition(), 0).gpu(bx, tx, device_api);
2345     return *this;
2346 }
2347 
gpu(const VarOrRVar & bx,const VarOrRVar & by,const VarOrRVar & tx,const VarOrRVar & ty,DeviceAPI device_api)2348 Func &Func::gpu(const VarOrRVar &bx, const VarOrRVar &by, const VarOrRVar &tx, const VarOrRVar &ty, DeviceAPI device_api) {
2349     invalidate_cache();
2350     Stage(func, func.definition(), 0).gpu(bx, by, tx, ty, device_api);
2351     return *this;
2352 }
2353 
gpu(const VarOrRVar & bx,const VarOrRVar & by,const VarOrRVar & bz,const VarOrRVar & tx,const VarOrRVar & ty,const VarOrRVar & tz,DeviceAPI device_api)2354 Func &Func::gpu(const VarOrRVar &bx, const VarOrRVar &by, const VarOrRVar &bz, const VarOrRVar &tx, const VarOrRVar &ty, const VarOrRVar &tz, DeviceAPI device_api) {
2355     invalidate_cache();
2356     Stage(func, func.definition(), 0).gpu(bx, by, bz, tx, ty, tz, device_api);
2357     return *this;
2358 }
2359 
gpu_tile(const VarOrRVar & x,const VarOrRVar & bx,const VarOrRVar & tx,const Expr & x_size,TailStrategy tail,DeviceAPI device_api)2360 Func &Func::gpu_tile(const VarOrRVar &x, const VarOrRVar &bx, const VarOrRVar &tx, const Expr &x_size, TailStrategy tail, DeviceAPI device_api) {
2361     invalidate_cache();
2362     Stage(func, func.definition(), 0).gpu_tile(x, bx, tx, x_size, tail, device_api);
2363     return *this;
2364 }
2365 
gpu_tile(const VarOrRVar & x,const VarOrRVar & tx,const Expr & x_size,TailStrategy tail,DeviceAPI device_api)2366 Func &Func::gpu_tile(const VarOrRVar &x, const VarOrRVar &tx, const Expr &x_size, TailStrategy tail, DeviceAPI device_api) {
2367     invalidate_cache();
2368     Stage(func, func.definition(), 0).gpu_tile(x, tx, x_size, tail, device_api);
2369     return *this;
2370 }
2371 
gpu_tile(const VarOrRVar & x,const VarOrRVar & y,const VarOrRVar & bx,const VarOrRVar & by,const VarOrRVar & tx,const VarOrRVar & ty,const Expr & x_size,const Expr & y_size,TailStrategy tail,DeviceAPI device_api)2372 Func &Func::gpu_tile(const VarOrRVar &x, const VarOrRVar &y,
2373                      const VarOrRVar &bx, const VarOrRVar &by,
2374                      const VarOrRVar &tx, const VarOrRVar &ty,
2375                      const Expr &x_size, const Expr &y_size,
2376                      TailStrategy tail,
2377                      DeviceAPI device_api) {
2378     invalidate_cache();
2379     Stage(func, func.definition(), 0)
2380         .gpu_tile(x, y, bx, by, tx, ty, x_size, y_size, tail, device_api);
2381     return *this;
2382 }
2383 
gpu_tile(const VarOrRVar & x,const VarOrRVar & y,const VarOrRVar & tx,const VarOrRVar & ty,const Expr & x_size,const Expr & y_size,TailStrategy tail,DeviceAPI device_api)2384 Func &Func::gpu_tile(const VarOrRVar &x, const VarOrRVar &y,
2385                      const VarOrRVar &tx, const VarOrRVar &ty,
2386                      const Expr &x_size, const Expr &y_size,
2387                      TailStrategy tail,
2388                      DeviceAPI device_api) {
2389     invalidate_cache();
2390     Stage(func, func.definition(), 0)
2391         .gpu_tile(x, y, tx, ty, x_size, y_size, tail, device_api);
2392     return *this;
2393 }
2394 
gpu_tile(const VarOrRVar & x,const VarOrRVar & y,const VarOrRVar & z,const VarOrRVar & bx,const VarOrRVar & by,const VarOrRVar & bz,const VarOrRVar & tx,const VarOrRVar & ty,const VarOrRVar & tz,const Expr & x_size,const Expr & y_size,const Expr & z_size,TailStrategy tail,DeviceAPI device_api)2395 Func &Func::gpu_tile(const VarOrRVar &x, const VarOrRVar &y, const VarOrRVar &z,
2396                      const VarOrRVar &bx, const VarOrRVar &by, const VarOrRVar &bz,
2397                      const VarOrRVar &tx, const VarOrRVar &ty, const VarOrRVar &tz,
2398                      const Expr &x_size, const Expr &y_size, const Expr &z_size,
2399                      TailStrategy tail,
2400                      DeviceAPI device_api) {
2401     invalidate_cache();
2402     Stage(func, func.definition(), 0)
2403         .gpu_tile(x, y, z, bx, by, bz, tx, ty, tz, x_size, y_size, z_size, tail, device_api);
2404     return *this;
2405 }
2406 
gpu_tile(const VarOrRVar & x,const VarOrRVar & y,const VarOrRVar & z,const VarOrRVar & tx,const VarOrRVar & ty,const VarOrRVar & tz,const Expr & x_size,const Expr & y_size,const Expr & z_size,TailStrategy tail,DeviceAPI device_api)2407 Func &Func::gpu_tile(const VarOrRVar &x, const VarOrRVar &y, const VarOrRVar &z,
2408                      const VarOrRVar &tx, const VarOrRVar &ty, const VarOrRVar &tz,
2409                      const Expr &x_size, const Expr &y_size, const Expr &z_size,
2410                      TailStrategy tail,
2411                      DeviceAPI device_api) {
2412     invalidate_cache();
2413     Stage(func, func.definition(), 0)
2414         .gpu_tile(x, y, z, tx, ty, tz, x_size, y_size, z_size, tail, device_api);
2415     return *this;
2416 }
2417 
shader(const Var & x,const Var & y,const Var & c,DeviceAPI device_api)2418 Func &Func::shader(const Var &x, const Var &y, const Var &c, DeviceAPI device_api) {
2419     invalidate_cache();
2420 
2421     reorder(c, x, y);
2422     // GLSL outputs must be stored interleaved
2423     reorder_storage(c, x, y);
2424 
2425     // TODO: Set appropriate constraints if this is the output buffer?
2426 
2427     Stage(func, func.definition(), 0).gpu_blocks(x, y, device_api);
2428 
2429     bool constant_bounds = false;
2430     FuncSchedule &sched = func.schedule();
2431     for (size_t i = 0; i < sched.bounds().size(); i++) {
2432         if (c.name() == sched.bounds()[i].var) {
2433             constant_bounds = is_const(sched.bounds()[i].min) &&
2434                               is_const(sched.bounds()[i].extent);
2435             break;
2436         }
2437     }
2438     user_assert(constant_bounds)
2439         << "The color channel for image loops must have constant bounds, e.g., .bound(c, 0, 3).\n";
2440     return *this;
2441 }
2442 
glsl(const Var & x,const Var & y,const Var & c)2443 Func &Func::glsl(const Var &x, const Var &y, const Var &c) {
2444     return shader(x, y, c, DeviceAPI::GLSL).vectorize(c);
2445 }
2446 
hexagon(const VarOrRVar & x)2447 Func &Func::hexagon(const VarOrRVar &x) {
2448     invalidate_cache();
2449     Stage(func, func.definition(), 0).hexagon(x);
2450     return *this;
2451 }
2452 
/** Invalidate the cache, then forward prefetch() of the Func f to the
 * Stage for the pure definition (stage index 0). offset is moved into
 * the call. */
Func &Func::prefetch(const Func &f, const VarOrRVar &var, Expr offset, PrefetchBoundStrategy strategy) {
    invalidate_cache();
    Stage pure_stage(func, func.definition(), 0);
    pure_stage.prefetch(f, var, std::move(offset), strategy);
    return *this;
}
2458 
prefetch(const Internal::Parameter & param,const VarOrRVar & var,Expr offset,PrefetchBoundStrategy strategy)2459 Func &Func::prefetch(const Internal::Parameter &param, const VarOrRVar &var, Expr offset, PrefetchBoundStrategy strategy) {
2460     invalidate_cache();
2461     Stage(func, func.definition(), 0).prefetch(param, var, std::move(offset), strategy);
2462     return *this;
2463 }
2464 
reorder_storage(const Var & x,const Var & y)2465 Func &Func::reorder_storage(const Var &x, const Var &y) {
2466     invalidate_cache();
2467 
2468     user_assert(x.name() != y.name())
2469         << "In schedule for " << name()
2470         << ", call to reorder_storage references "
2471         << x.name() << " twice\n";
2472 
2473     vector<StorageDim> &dims = func.schedule().storage_dims();
2474     bool found_y = false;
2475     size_t y_loc = 0;
2476     for (size_t i = 0; i < dims.size(); i++) {
2477         if (var_name_match(dims[i].var, y.name())) {
2478             found_y = true;
2479             y_loc = i;
2480         } else if (var_name_match(dims[i].var, x.name())) {
2481             if (found_y) std::swap(dims[i], dims[y_loc]);
2482             return *this;
2483         }
2484     }
2485     user_error << "In schedule for " << name()
2486                << ", could not find variables " << x.name()
2487                << " and " << y.name() << " to reorder.\n"
2488                << dump_dim_list(dims);
2489     return *this;
2490 }
2491 
reorder_storage(const std::vector<Var> & dims,size_t start)2492 Func &Func::reorder_storage(const std::vector<Var> &dims, size_t start) {
2493     // Reorder the first dimension with respect to all others, then
2494     // recursively reorder all remaining dimensions.
2495     for (size_t i = start + 1; i < dims.size(); i++) {
2496         reorder_storage(dims[start], dims[i]);
2497     }
2498     if ((dims.size() - start) > 2) {
2499         reorder_storage(dims, start + 1);
2500     }
2501     return *this;
2502 }
2503 
reorder_storage(const std::vector<Var> & dims)2504 Func &Func::reorder_storage(const std::vector<Var> &dims) {
2505     user_assert(dims.size() > 1) << "reorder_storage must have at least two dimensions in reorder list.\n";
2506 
2507     return reorder_storage(dims, 0);
2508 }
2509 
align_storage(const Var & dim,const Expr & alignment)2510 Func &Func::align_storage(const Var &dim, const Expr &alignment) {
2511     invalidate_cache();
2512 
2513     vector<StorageDim> &dims = func.schedule().storage_dims();
2514     for (size_t i = 0; i < dims.size(); i++) {
2515         if (var_name_match(dims[i].var, dim.name())) {
2516             dims[i].alignment = alignment;
2517             return *this;
2518         }
2519     }
2520     user_error << "In schedule for " << name()
2521                << ", could not find var " << dim.name()
2522                << " to align the storage of.\n"
2523                << dump_dim_list(func.schedule().storage_dims());
2524     return *this;
2525 }
2526 
fold_storage(const Var & dim,const Expr & factor,bool fold_forward)2527 Func &Func::fold_storage(const Var &dim, const Expr &factor, bool fold_forward) {
2528     invalidate_cache();
2529 
2530     vector<StorageDim> &dims = func.schedule().storage_dims();
2531     for (size_t i = 0; i < dims.size(); i++) {
2532         if (var_name_match(dims[i].var, dim.name())) {
2533             dims[i].fold_factor = factor;
2534             dims[i].fold_forward = fold_forward;
2535             return *this;
2536         }
2537     }
2538     user_error << "In schedule for " << name()
2539                << ", could not find var " << dim.name()
2540                << " to fold the storage of.\n"
2541                << dump_dim_list(func.schedule().storage_dims());
2542     return *this;
2543 }
2544 
compute_at(LoopLevel loop_level)2545 Func &Func::compute_at(LoopLevel loop_level) {
2546     invalidate_cache();
2547     func.schedule().compute_level() = std::move(loop_level);
2548     // We want to set store_level = compute_level iff store_level is inlined,
2549     // but we can't do that here, since the value in store_level could
2550     // be mutated at any time prior to lowering. Instead, we check at
2551     // the start of lowering (via Function::lock_loop_levels() method) and
2552     // do the compute_level -> store_level propagation then.
2553     return *this;
2554 }
2555 
/** compute_at() the loop level built from Func f and reduction
 * variable var. */
Func &Func::compute_at(const Func &f, const RVar &var) {
    compute_at(LoopLevel(f, var));
    return *this;
}
2559 
/** compute_at() the loop level built from Func f and pure variable
 * var. */
Func &Func::compute_at(const Func &f, const Var &var) {
    compute_at(LoopLevel(f, var));
    return *this;
}
2563 
compute_with(const Stage & s,const VarOrRVar & var,const vector<pair<VarOrRVar,LoopAlignStrategy>> & align)2564 Func &Func::compute_with(const Stage &s, const VarOrRVar &var, const vector<pair<VarOrRVar, LoopAlignStrategy>> &align) {
2565     invalidate_cache();
2566     Stage(func, func.definition(), 0).compute_with(s, var, align);
2567     return *this;
2568 }
2569 
compute_with(const Stage & s,const VarOrRVar & var,LoopAlignStrategy align)2570 Func &Func::compute_with(const Stage &s, const VarOrRVar &var, LoopAlignStrategy align) {
2571     invalidate_cache();
2572     Stage(func, func.definition(), 0).compute_with(s, var, align);
2573     return *this;
2574 }
2575 
compute_with(LoopLevel loop_level,const std::vector<std::pair<VarOrRVar,LoopAlignStrategy>> & align)2576 Func &Func::compute_with(LoopLevel loop_level, const std::vector<std::pair<VarOrRVar, LoopAlignStrategy>> &align) {
2577     invalidate_cache();
2578     Stage(func, func.definition(), 0).compute_with(std::move(loop_level), align);
2579     return *this;
2580 }
2581 
compute_with(LoopLevel loop_level,LoopAlignStrategy align)2582 Func &Func::compute_with(LoopLevel loop_level, LoopAlignStrategy align) {
2583     invalidate_cache();
2584     Stage(func, func.definition(), 0).compute_with(std::move(loop_level), align);
2585     return *this;
2586 }
2587 
compute_root()2588 Func &Func::compute_root() {
2589     return compute_at(LoopLevel::root());
2590 }
2591 
store_at(LoopLevel loop_level)2592 Func &Func::store_at(LoopLevel loop_level) {
2593     invalidate_cache();
2594     func.schedule().store_level() = std::move(loop_level);
2595     return *this;
2596 }
2597 
/** store_at() the loop level built from Func f and reduction variable
 * var. */
Func &Func::store_at(const Func &f, const RVar &var) {
    store_at(LoopLevel(f, var));
    return *this;
}
2601 
/** store_at() the loop level built from Func f and pure variable
 * var. */
Func &Func::store_at(const Func &f, const Var &var) {
    store_at(LoopLevel(f, var));
    return *this;
}
2605 
store_root()2606 Func &Func::store_root() {
2607     return store_at(LoopLevel::root());
2608 }
2609 
compute_inline()2610 Func &Func::compute_inline() {
2611     return compute_at(LoopLevel::inlined());
2612 }
2613 
trace_loads()2614 Func &Func::trace_loads() {
2615     invalidate_cache();
2616     func.trace_loads();
2617     return *this;
2618 }
2619 
trace_stores()2620 Func &Func::trace_stores() {
2621     invalidate_cache();
2622     func.trace_stores();
2623     return *this;
2624 }
2625 
trace_realizations()2626 Func &Func::trace_realizations() {
2627     invalidate_cache();
2628     func.trace_realizations();
2629     return *this;
2630 }
2631 
add_trace_tag(const std::string & trace_tag)2632 Func &Func::add_trace_tag(const std::string &trace_tag) {
2633     invalidate_cache();
2634     func.add_trace_tag(trace_tag);
2635     return *this;
2636 }
2637 
debug_to_file(const string & filename)2638 void Func::debug_to_file(const string &filename) {
2639     invalidate_cache();
2640     func.debug_file() = filename;
2641 }
2642 
update(int idx)2643 Stage Func::update(int idx) {
2644     user_assert(idx < num_update_definitions()) << "Call to update with index larger than last defined update stage for Func \"" << name() << "\".\n";
2645     invalidate_cache();
2646     return Stage(func, func.update(idx), idx + 1);
2647 }
2648 
operator Stage() const2649 Func::operator Stage() const {
2650     user_assert(!func.has_extern_definition())
2651         << "Extern func \"" << name() << "\" cannot be converted into Stage\n";
2652     return Stage(func, func.definition(), 0);
2653 }
2654 
2655 namespace {
2656 class CountImplicitVars : public Internal::IRGraphVisitor {
2657 public:
2658     int count;
2659 
CountImplicitVars(const vector<Expr> & e)2660     CountImplicitVars(const vector<Expr> &e)
2661         : count(0) {
2662         for (size_t i = 0; i < e.size(); i++) {
2663             e[i].accept(this);
2664         }
2665     }
2666 
2667     using IRGraphVisitor::visit;
2668 
visit(const Variable * v)2669     void visit(const Variable *v) override {
2670         int index = Var::implicit_index(v->name);
2671         if (index != -1) {
2672             if (index >= count) count = index + 1;
2673         }
2674     }
2675 };
2676 }  // namespace
2677 
FuncRef(const Internal::Function & f,const vector<Expr> & a,int placeholder_pos,int count)2678 FuncRef::FuncRef(const Internal::Function &f, const vector<Expr> &a, int placeholder_pos,
2679                  int count)
2680     : func(f), implicit_count(count), args(a) {
2681     implicit_placeholder_pos = placeholder_pos;
2682     Internal::check_call_arg_types(f.name(), &args, args.size());
2683 }
2684 
FuncRef(Internal::Function f,const vector<Var> & a,int placeholder_pos,int count)2685 FuncRef::FuncRef(Internal::Function f, const vector<Var> &a, int placeholder_pos,
2686                  int count)
2687     : func(std::move(f)), implicit_count(count) {
2688     implicit_placeholder_pos = placeholder_pos;
2689     args.resize(a.size());
2690     for (size_t i = 0; i < a.size(); i++) {
2691         args[i] = a[i];
2692     }
2693 }
2694 
args_with_implicit_vars(const vector<Expr> & e) const2695 vector<Expr> FuncRef::args_with_implicit_vars(const vector<Expr> &e) const {
2696     vector<Expr> a = args;
2697 
2698     for (size_t i = 0; i < a.size(); i++) {
2699         user_assert(a[i].defined())
2700             << "Argument " << (i + 1) << " in call to \"" << func.name() << "\" is undefined.\n";
2701     }
2702     for (size_t i = 0; i < e.size(); i++) {
2703         user_assert(e[i].defined())
2704             << "Value " << (i + 1) << " in definition of \"" << func.name() << "\" is undefined.\n";
2705     }
2706 
2707     CountImplicitVars count(e);
2708     for (size_t i = 0; i < a.size(); i++) {
2709         a[i].accept(&count);
2710     }
2711 
2712     if (count.count > 0) {
2713         if (func.has_pure_definition()) {
2714             // If the func already has pure definition, the number of implicit
2715             // vars in the RHS can only be at most the number of implicit vars
2716             // in the LHS.
2717             user_assert(implicit_count >= count.count)
2718                 << "The update definition of " << func.name() << " uses " << count.count
2719                 << " implicit variables, but the initial definition uses only "
2720                 << implicit_count << " implicit variables.\n";
2721         } else if (implicit_placeholder_pos != -1) {
2722             internal_assert(implicit_count == 0)
2723                 << "Pure definition can't possibly already have implicit variables defined\n";
2724 
2725             Internal::debug(2) << "Adding " << count.count << " implicit vars to LHS of " << func.name() << "\n";
2726 
2727             vector<Expr>::iterator iter = a.begin() + implicit_placeholder_pos;
2728             for (int i = 0; i < count.count; i++) {
2729                 iter = a.insert(iter, Var::implicit(i));
2730                 iter++;
2731             }
2732         }
2733     }
2734 
2735     // Check the implicit vars in the RHS also exist in the LHS
2736     for (int i = 0; i < count.count; i++) {
2737         Var v = Var::implicit(i);
2738         bool found = false;
2739         for (size_t j = 0; j < a.size(); j++) {
2740             if (const Variable *arg = a[j].as<Variable>()) {
2741                 if (arg->name == v.name()) {
2742                     found = true;
2743                 }
2744             }
2745         }
2746         user_assert(found)
2747             << "Right-hand-side of update definition of " << func.name()
2748             << " uses implicit variables, but the left-hand-side does not"
2749             << " contain the placeholder symbol '_'.\n";
2750     }
2751 
2752     return a;
2753 }
2754 
/** Define the Func at this reference with a single value, by wrapping `e`
 * in a Tuple and forwarding to the Tuple overload. */
Stage FuncRef::operator=(const Expr &e) {
    return (*this) = Tuple(e);
}
2758 
operator =(const Tuple & e)2759 Stage FuncRef::operator=(const Tuple &e) {
2760     if (!func.has_pure_definition()) {
2761         for (size_t i = 0; i < args.size(); ++i) {
2762             const Variable *var = args[i].as<Variable>();
2763             user_assert((var != nullptr) && (!var->reduction_domain.defined()))
2764                 << "Argument " << (i + 1) << " in initial definition of \""
2765                 << func.name() << "\" is not a Var.\n";
2766         }
2767 
2768         // Find implicit args in the expr and add them to the args list before calling define
2769         vector<Expr> expanded_args = args_with_implicit_vars(e.as_vector());
2770         vector<string> expanded_args_str(expanded_args.size());
2771         for (size_t i = 0; i < expanded_args.size(); ++i) {
2772             const Variable *v = expanded_args[i].as<Variable>();
2773             internal_assert(v);
2774             expanded_args_str[i] = v->name;
2775         }
2776         func.define(expanded_args_str, e.as_vector());
2777         return Stage(func, func.definition(), 0);
2778     } else {
2779         func.define_update(args, e.as_vector());
2780 
2781         size_t update_stage = func.updates().size() - 1;
2782         return Stage(func, func.update(update_stage), update_stage);
2783     }
2784 }
2785 
operator =(const FuncRef & e)2786 Stage FuncRef::operator=(const FuncRef &e) {
2787     if (e.size() == 1) {
2788         return (*this) = Expr(e);
2789     } else {
2790         return (*this) = Tuple(e);
2791     }
2792 }
2793 
2794 // Inject a suitable base-case definition given an update
2795 // definition. This is a helper for FuncRef::operator+= and co.
define_base_case(const Internal::Function & func,const vector<Expr> & a,const Tuple & e)2796 Func define_base_case(const Internal::Function &func, const vector<Expr> &a, const Tuple &e) {
2797     Func f(func);
2798 
2799     if (func.has_pure_definition()) return f;
2800     vector<Var> pure_args(a.size());
2801 
2802     // Reuse names of existing pure args
2803     for (size_t i = 0; i < a.size(); i++) {
2804         if (const Variable *v = a[i].as<Variable>()) {
2805             if (!v->param.defined()) {
2806                 pure_args[i] = Var(v->name);
2807             }
2808         } else {
2809             pure_args[i] = Var();
2810         }
2811     }
2812 
2813     f(pure_args) = e;
2814     return f;
2815 }
2816 
/** Single-value convenience overload: wraps `e` in a Tuple and forwards to
 * the Tuple overload of define_base_case. */
Func define_base_case(const Internal::Function &func, const vector<Expr> &a, const Expr &e) {
    return define_base_case(func, a, Tuple(e));
}
2820 
2821 template<typename BinaryOp>
func_ref_update(const Tuple & e,int init_val)2822 Stage FuncRef::func_ref_update(const Tuple &e, int init_val) {
2823     internal_assert(e.size() > 1);
2824 
2825     vector<Expr> init_values(e.size());
2826     for (int i = 0; i < (int)init_values.size(); ++i) {
2827         init_values[i] = cast(e[i].type(), init_val);
2828     }
2829     vector<Expr> expanded_args = args_with_implicit_vars(e.as_vector());
2830     FuncRef self_ref = define_base_case(func, expanded_args, Tuple(init_values))(expanded_args);
2831 
2832     vector<Expr> values(e.size());
2833     for (int i = 0; i < (int)values.size(); ++i) {
2834         values[i] = BinaryOp()(self_ref[i], e[i]);
2835     }
2836     return self_ref = Tuple(values);
2837 }
2838 
2839 template<typename BinaryOp>
func_ref_update(Expr e,int init_val)2840 Stage FuncRef::func_ref_update(Expr e, int init_val) {
2841     vector<Expr> expanded_args = args_with_implicit_vars({e});
2842     FuncRef self_ref = define_base_case(func, expanded_args, cast(e.type(), init_val))(expanded_args);
2843     return self_ref = BinaryOp()(Expr(self_ref), e);
2844 }
2845 
/** Add `e` to the Func at this reference. Implemented through
 * func_ref_update with std::plus and an initial value of 0. */
Stage FuncRef::operator+=(Expr e) {
    return func_ref_update<std::plus<Expr>>(std::move(e), 0);
}
2849 
operator +=(const Tuple & e)2850 Stage FuncRef::operator+=(const Tuple &e) {
2851     if (e.size() == 1) {
2852         return (*this) += e[0];
2853     } else {
2854         return func_ref_update<std::plus<Expr>>(e, 0);
2855     }
2856 }
2857 
operator +=(const FuncRef & e)2858 Stage FuncRef::operator+=(const FuncRef &e) {
2859     if (e.size() == 1) {
2860         return (*this) += Expr(e);
2861     } else {
2862         return (*this) += Tuple(e);
2863     }
2864 }
2865 
/** Multiply the Func at this reference by `e`. Implemented through
 * func_ref_update with std::multiplies and an initial value of 1. */
Stage FuncRef::operator*=(Expr e) {
    return func_ref_update<std::multiplies<Expr>>(std::move(e), 1);
}
2869 
operator *=(const Tuple & e)2870 Stage FuncRef::operator*=(const Tuple &e) {
2871     if (e.size() == 1) {
2872         return (*this) *= e[0];
2873     } else {
2874         return func_ref_update<std::multiplies<Expr>>(e, 1);
2875     }
2876 }
2877 
operator *=(const FuncRef & e)2878 Stage FuncRef::operator*=(const FuncRef &e) {
2879     if (e.size() == 1) {
2880         return (*this) *= Expr(e);
2881     } else {
2882         return (*this) *= Tuple(e);
2883     }
2884 }
2885 
/** Subtract `e` from the Func at this reference. Implemented through
 * func_ref_update with std::minus and an initial value of 0. */
Stage FuncRef::operator-=(Expr e) {
    return func_ref_update<std::minus<Expr>>(std::move(e), 0);
}
2889 
operator -=(const Tuple & e)2890 Stage FuncRef::operator-=(const Tuple &e) {
2891     if (e.size() == 1) {
2892         return (*this) -= e[0];
2893     } else {
2894         return func_ref_update<std::minus<Expr>>(e, 0);
2895     }
2896 }
2897 
operator -=(const FuncRef & e)2898 Stage FuncRef::operator-=(const FuncRef &e) {
2899     if (e.size() == 1) {
2900         return (*this) -= Expr(e);
2901     } else {
2902         return (*this) -= Tuple(e);
2903     }
2904 }
2905 
/** Divide the Func at this reference by `e`. Implemented through
 * func_ref_update with std::divides and an initial value of 1. */
Stage FuncRef::operator/=(Expr e) {
    return func_ref_update<std::divides<Expr>>(std::move(e), 1);
}
2909 
operator /=(const Tuple & e)2910 Stage FuncRef::operator/=(const Tuple &e) {
2911     if (e.size() == 1) {
2912         return (*this) /= e[0];
2913     } else {
2914         return func_ref_update<std::divides<Expr>>(e, 1);
2915     }
2916 }
2917 
operator /=(const FuncRef & e)2918 Stage FuncRef::operator/=(const FuncRef &e) {
2919     if (e.size() == 1) {
2920         return (*this) /= Expr(e);
2921     } else {
2922         return (*this) /= Tuple(e);
2923     }
2924 }
2925 
/** Use this reference as a call to the Func, yielding a Call expression
 * on the stored args. The Func must already have a pure or extern
 * definition, and must have exactly one output. */
FuncRef::operator Expr() const {
    user_assert(func.has_pure_definition() || func.has_extern_definition())
        << "Can't call Func \"" << func.name() << "\" because it has not yet been defined.\n";

    user_assert(func.outputs() == 1)
        << "Can't convert a reference Func \"" << func.name()
        << "\" to an Expr, because " << func.name() << " returns a Tuple.\n";

    return Call::make(func, args);
}
2936 
/** Refer to element `i` of a Tuple-valued Func at this reference. The
 * Func must already have a pure or extern definition, must not have
 * exactly one output, and `i` must lie in [0, func.outputs()). */
FuncTupleElementRef FuncRef::operator[](int i) const {
    user_assert(func.has_pure_definition() || func.has_extern_definition())
        << "Can't call Func \"" << func.name() << "\" because it has not yet been defined.\n";

    user_assert(func.outputs() != 1)
        << "Can't index into a reference to Func \"" << func.name()
        << "\", because it does not return a Tuple.\n";

    user_assert(i >= 0 && i < func.outputs())
        << "Tuple index out of range in reference to Func \"" << func.name() << "\".\n";

    return FuncTupleElementRef(*this, args, i);
}
2950 
/** Number of outputs of the referenced Func, as reported by
 * func.outputs(). */
size_t FuncRef::size() const {
    return func.outputs();
}
2954 
/** Construct a reference to element `idx` of the Tuple-valued Func
 * reference `ref`, called with `args`. Internally asserts that `ref` has
 * more than one output and that `idx` is in range. */
FuncTupleElementRef::FuncTupleElementRef(
    const FuncRef &ref, const std::vector<Expr> &args, int idx)
    : func_ref(ref), args(args), idx(idx) {
    internal_assert(func_ref.size() > 1)
        << "Func " << ref.function().name() << " does not return a Tuple\n";
    internal_assert(idx >= 0 && idx < (int)func_ref.size());
}
2962 
values_with_undefs(const Expr & e) const2963 Tuple FuncTupleElementRef::values_with_undefs(const Expr &e) const {
2964     vector<Expr> values(func_ref.size());
2965     for (int i = 0; i < (int)values.size(); ++i) {
2966         if (i == idx) {
2967             values[i] = e;
2968         } else {
2969             Type t = func_ref.function().values()[i].type();
2970             values[i] = undef(t);
2971         }
2972     }
2973     return Tuple(values);
2974 }
2975 
/** Assign `e` to this tuple element only: forwards to FuncRef::operator=
 * with the Tuple produced by values_with_undefs(e). */
Stage FuncTupleElementRef::operator=(const Expr &e) {
    return func_ref = values_with_undefs(e);
}
2979 
/** Apply += to the underlying FuncRef with the Tuple produced by
 * values_with_undefs(e). */
Stage FuncTupleElementRef::operator+=(const Expr &e) {
    return func_ref += values_with_undefs(e);
}
2983 
/** Apply *= to the underlying FuncRef with the Tuple produced by
 * values_with_undefs(e). */
Stage FuncTupleElementRef::operator*=(const Expr &e) {
    return func_ref *= values_with_undefs(e);
}
2987 
/** Apply -= to the underlying FuncRef with the Tuple produced by
 * values_with_undefs(e). */
Stage FuncTupleElementRef::operator-=(const Expr &e) {
    return func_ref -= values_with_undefs(e);
}
2991 
/** Apply /= to the underlying FuncRef with the Tuple produced by
 * values_with_undefs(e). */
Stage FuncTupleElementRef::operator/=(const Expr &e) {
    return func_ref /= values_with_undefs(e);
}
2995 
/** Assign another Func reference to this tuple element. `e` is passed to
 * values_with_undefs, which takes an Expr, so it is converted to one on
 * the way in. */
Stage FuncTupleElementRef::operator=(const FuncRef &e) {
    return func_ref = values_with_undefs(e);
}
2999 
/** Use this tuple element reference as a call: yields a Call on the
 * referenced function with the stored args and element index `idx`. */
FuncTupleElementRef::operator Expr() const {
    return Internal::Call::make(func_ref.function(), args, idx);
}
3003 
/** Realize this Func over the given sizes by delegating to
 * Pipeline::realize on this Func's pipeline. It is a user error to call
 * this on an undefined Func. */
Realization Func::realize(std::vector<int32_t> sizes, const Target &target,
                          const ParamMap &param_map) {
    user_assert(defined()) << "Can't realize undefined Func.\n";
    return pipeline().realize(std::move(sizes), target, param_map);
}
3009 
/** Four-size convenience overload: packs the sizes into a braced list and
 * forwards to another realize overload. */
Realization Func::realize(int x_size, int y_size, int z_size, int w_size, const Target &target,
                          const ParamMap &param_map) {
    return realize({x_size, y_size, z_size, w_size}, target, param_map);
}
3014 
/** Three-size convenience overload: packs the sizes into a braced list and
 * forwards to another realize overload. */
Realization Func::realize(int x_size, int y_size, int z_size, const Target &target,
                          const ParamMap &param_map) {
    return realize({x_size, y_size, z_size}, target, param_map);
}
3019 
/** Two-size convenience overload: packs the sizes into a braced list and
 * forwards to another realize overload. */
Realization Func::realize(int x_size, int y_size, const Target &target,
                          const ParamMap &param_map) {
    return realize({x_size, y_size}, target, param_map);
}
3024 
/** One-size convenience overload: forwards an explicitly constructed
 * one-element std::vector<int> to another realize overload. */
Realization Func::realize(int x_size, const Target &target,
                          const ParamMap &param_map) {
    return realize(std::vector<int>{x_size}, target, param_map);
}
3029 
/** Size-less convenience overload: forwards an explicitly constructed
 * empty std::vector<int> to another realize overload. */
Realization Func::realize(const Target &target,
                          const ParamMap &param_map) {
    return realize(std::vector<int>{}, target, param_map);
}
3034 
infer_input_bounds(int x_size,int y_size,int z_size,int w_size,const Target & target,const ParamMap & param_map)3035 void Func::infer_input_bounds(int x_size, int y_size, int z_size, int w_size,
3036                               const Target &target,
3037                               const ParamMap &param_map) {
3038     vector<int32_t> sizes;
3039     if (x_size) sizes.push_back(x_size);
3040     if (y_size) sizes.push_back(y_size);
3041     if (z_size) sizes.push_back(z_size);
3042     if (w_size) sizes.push_back(w_size);
3043     infer_input_bounds(sizes, target, param_map);
3044 }
3045 
infer_input_bounds(const std::vector<int32_t> & sizes,const Target & target,const ParamMap & param_map)3046 void Func::infer_input_bounds(const std::vector<int32_t> &sizes,
3047                               const Target &target,
3048                               const ParamMap &param_map) {
3049     user_assert(defined()) << "Can't infer input bounds on an undefined Func.\n";
3050     vector<Buffer<>> outputs(func.outputs());
3051     for (size_t i = 0; i < outputs.size(); i++) {
3052         Buffer<> im(func.output_types()[i], nullptr, sizes);
3053         outputs[i] = std::move(im);
3054     }
3055     Realization r(outputs);
3056     infer_input_bounds(r, target, param_map);
3057 }
3058 
/** Get an OutputImageParam handle on the single output buffer of this
 * Func. It is a user error to call this on an undefined Func, or on one
 * that does not have exactly one output buffer. */
OutputImageParam Func::output_buffer() const {
    user_assert(defined())
        << "Can't access output buffer of undefined Func.\n";
    user_assert(func.output_buffers().size() == 1)
        << "Can't call Func::output_buffer on Func \"" << name()
        << "\" because it returns a Tuple.\n";
    return OutputImageParam(func.output_buffers()[0], Argument::OutputBuffer, *this);
}
3067 
output_buffers() const3068 vector<OutputImageParam> Func::output_buffers() const {
3069     user_assert(defined())
3070         << "Can't access output buffers of undefined Func.\n";
3071 
3072     vector<OutputImageParam> bufs(func.output_buffers().size());
3073     for (size_t i = 0; i < bufs.size(); i++) {
3074         bufs[i] = OutputImageParam(func.output_buffers()[i], Argument::OutputBuffer, *this);
3075     }
3076     return bufs;
3077 }
3078 
/** Convert this Func into an ExternFuncArgument wrapping its underlying
 * function. */
Func::operator ExternFuncArgument() const {
    return ExternFuncArgument(func);
}
3082 
/** Return the Pipeline for this Func. If pipeline_ is not yet defined it
 * is constructed from this Func and stored, so later calls reuse it. */
Pipeline Func::pipeline() {
    if (!pipeline_.defined()) {
        pipeline_ = Pipeline(*this);
    }
    internal_assert(pipeline_.defined());
    return pipeline_;
}
3090 
/** Infer the arguments of this Func. A temporary Pipeline is built from
 * this Func for the query; the stored pipeline_ is not used or modified
 * here. */
vector<Argument> Func::infer_arguments() const {
    return Pipeline(*this).infer_arguments();
}
3094 
/** Return the source location recorded on this Func's initial definition.
 * It is a user error to call this on an undefined Func. */
std::string Func::source_location() const {
    user_assert(defined()) << "A Func with no definition has no source_location\n";
    return func.definition().source_location();
}
3099 
/** Delegate to Pipeline::compile_to_module on this Func's pipeline and
 * return the resulting Module. */
Module Func::compile_to_module(const vector<Argument> &args, const std::string &fn_name, const Target &target) {
    return pipeline().compile_to_module(args, fn_name, target);
}
3103 
/** Delegate to Pipeline::compile_to on this Func's pipeline, passing the
 * map from Output kind to file name through unchanged. */
void Func::compile_to(const map<Output, string> &output_files,
                      const vector<Argument> &args,
                      const string &fn_name,
                      const Target &target) {
    pipeline().compile_to(output_files, args, fn_name, target);
}
3110 
/** Delegate to Pipeline::compile_to_bitcode on this Func's pipeline with
 * an explicit function name. */
void Func::compile_to_bitcode(const string &filename, const vector<Argument> &args, const string &fn_name,
                              const Target &target) {
    pipeline().compile_to_bitcode(filename, args, fn_name, target);
}
3115 
/** Delegate to Pipeline::compile_to_bitcode on this Func's pipeline,
 * passing an empty string as the function name. */
void Func::compile_to_bitcode(const string &filename, const vector<Argument> &args,
                              const Target &target) {
    pipeline().compile_to_bitcode(filename, args, "", target);
}
3120 
/** Delegate to Pipeline::compile_to_llvm_assembly on this Func's pipeline
 * with an explicit function name. */
void Func::compile_to_llvm_assembly(const string &filename, const vector<Argument> &args, const string &fn_name,
                                    const Target &target) {
    pipeline().compile_to_llvm_assembly(filename, args, fn_name, target);
}
3125 
/** Delegate to Pipeline::compile_to_llvm_assembly on this Func's pipeline,
 * passing an empty string as the function name. */
void Func::compile_to_llvm_assembly(const string &filename, const vector<Argument> &args,
                                    const Target &target) {
    pipeline().compile_to_llvm_assembly(filename, args, "", target);
}
3130 
/** Delegate to Pipeline::compile_to_object on this Func's pipeline with an
 * explicit function name. */
void Func::compile_to_object(const string &filename, const vector<Argument> &args,
                             const string &fn_name, const Target &target) {
    pipeline().compile_to_object(filename, args, fn_name, target);
}
3135 
/** Delegate to Pipeline::compile_to_object on this Func's pipeline,
 * passing an empty string as the function name. */
void Func::compile_to_object(const string &filename, const vector<Argument> &args,
                             const Target &target) {
    pipeline().compile_to_object(filename, args, "", target);
}
3140 
/** Delegate to Pipeline::compile_to_header on this Func's pipeline. */
void Func::compile_to_header(const string &filename, const vector<Argument> &args,
                             const string &fn_name, const Target &target) {
    pipeline().compile_to_header(filename, args, fn_name, target);
}
3145 
/** Delegate to Pipeline::compile_to_c on this Func's pipeline. */
void Func::compile_to_c(const string &filename, const vector<Argument> &args,
                        const string &fn_name, const Target &target) {
    pipeline().compile_to_c(filename, args, fn_name, target);
}
3150 
/** Delegate to Pipeline::compile_to_lowered_stmt on this Func's pipeline,
 * passing the requested statement output format through. */
void Func::compile_to_lowered_stmt(const string &filename,
                                   const vector<Argument> &args,
                                   StmtOutputFormat fmt,
                                   const Target &target) {
    pipeline().compile_to_lowered_stmt(filename, args, fmt, target);
}
3157 
/** Delegate to Pipeline::print_loop_nest on this Func's pipeline. */
void Func::print_loop_nest() {
    pipeline().print_loop_nest();
}
3161 
/** Delegate to Pipeline::compile_to_file on this Func's pipeline. */
void Func::compile_to_file(const string &filename_prefix,
                           const vector<Argument> &args,
                           const std::string &fn_name,
                           const Target &target) {
    pipeline().compile_to_file(filename_prefix, args, fn_name, target);
}
3168 
/** Delegate to Pipeline::compile_to_static_library on this Func's
 * pipeline. */
void Func::compile_to_static_library(const string &filename_prefix,
                                     const vector<Argument> &args,
                                     const std::string &fn_name,
                                     const Target &target) {
    pipeline().compile_to_static_library(filename_prefix, args, fn_name, target);
}
3175 
/** Delegate to Pipeline::compile_to_multitarget_static_library on this
 * Func's pipeline, passing the list of targets through. */
void Func::compile_to_multitarget_static_library(const std::string &filename_prefix,
                                                 const std::vector<Argument> &args,
                                                 const std::vector<Target> &targets) {
    pipeline().compile_to_multitarget_static_library(filename_prefix, args, targets);
}
3181 
/** Delegate to Pipeline::compile_to_multitarget_object_files on this
 * Func's pipeline, passing the targets and suffixes through. */
void Func::compile_to_multitarget_object_files(const std::string &filename_prefix,
                                               const std::vector<Argument> &args,
                                               const std::vector<Target> &targets,
                                               const std::vector<std::string> &suffixes) {
    pipeline().compile_to_multitarget_object_files(filename_prefix, args, targets, suffixes);
}
3188 
/** Delegate to Pipeline::compile_to_assembly on this Func's pipeline with
 * an explicit function name. */
void Func::compile_to_assembly(const string &filename, const vector<Argument> &args, const string &fn_name,
                               const Target &target) {
    pipeline().compile_to_assembly(filename, args, fn_name, target);
}
3193 
/** Delegate to Pipeline::compile_to_assembly on this Func's pipeline,
 * passing an empty string as the function name. */
void Func::compile_to_assembly(const string &filename, const vector<Argument> &args, const Target &target) {
    pipeline().compile_to_assembly(filename, args, "", target);
}
3197 
3198 // JIT-related code
3199 
/** Forward the error handler callback to this Func's pipeline. */
void Func::set_error_handler(void (*handler)(void *, const char *)) {
    pipeline().set_error_handler(handler);
}
3203 
/** Forward the custom malloc/free callback pair to this Func's pipeline. */
void Func::set_custom_allocator(void *(*cust_malloc)(void *, size_t),
                                void (*cust_free)(void *, void *)) {
    pipeline().set_custom_allocator(cust_malloc, cust_free);
}
3208 
/** Forward the custom do_par_for callback to this Func's pipeline. */
void Func::set_custom_do_par_for(int (*cust_do_par_for)(void *, int (*)(void *, int, uint8_t *), int, int, uint8_t *)) {
    pipeline().set_custom_do_par_for(cust_do_par_for);
}
3212 
/** Forward the custom do_task callback to this Func's pipeline. */
void Func::set_custom_do_task(int (*cust_do_task)(void *, int (*)(void *, int, uint8_t *), int, uint8_t *)) {
    pipeline().set_custom_do_task(cust_do_task);
}
3216 
/** Forward the custom trace callback to this Func's pipeline. */
void Func::set_custom_trace(int (*trace_fn)(void *, const halide_trace_event_t *)) {
    pipeline().set_custom_trace(trace_fn);
}
3220 
/** Forward the custom print callback to this Func's pipeline. */
void Func::set_custom_print(void (*cust_print)(void *, const char *)) {
    pipeline().set_custom_print(cust_print);
}
3224 
/** Register a custom lowering pass with this Func's pipeline. The pass
 * pointer is forwarded as given and the deleter is moved into the call. */
void Func::add_custom_lowering_pass(IRMutator *pass, std::function<void()> deleter) {
    pipeline().add_custom_lowering_pass(pass, std::move(deleter));
}
3228 
/** Delegate to Pipeline::clear_custom_lowering_passes on this Func's
 * pipeline. */
void Func::clear_custom_lowering_passes() {
    pipeline().clear_custom_lowering_passes();
}
3232 
/** Return the custom lowering passes reported by this Func's pipeline.
 * NOTE(review): pipeline() returns a Pipeline by value, so the returned
 * reference is only valid if that temporary shares state with pipeline_
 * -- confirm against Pipeline's definition. */
const vector<CustomLoweringPass> &Func::custom_lowering_passes() {
    return pipeline().custom_lowering_passes();
}
3236 
/** Return the JIT handlers reported by this Func's pipeline.
 * NOTE(review): pipeline() returns a Pipeline by value, so the returned
 * reference is only valid if that temporary shares state with pipeline_
 * -- confirm against Pipeline's definition. */
const Internal::JITHandlers &Func::jit_handlers() {
    return pipeline().jit_handlers();
}
3240 
/** Realize into caller-provided outputs by delegating to
 * Pipeline::realize on this Func's pipeline; `outputs` is moved into the
 * call. */
void Func::realize(Pipeline::RealizationArg outputs, const Target &target,
                   const ParamMap &param_map) {
    pipeline().realize(std::move(outputs), target, param_map);
}
3245 
/** Infer input bounds for caller-provided outputs by delegating to
 * Pipeline::infer_input_bounds on this Func's pipeline; `outputs` is
 * moved into the call. */
void Func::infer_input_bounds(Pipeline::RealizationArg outputs, const Target &target,
                              const ParamMap &param_map) {
    pipeline().infer_input_bounds(std::move(outputs), target, param_map);
}
3250 
/** Delegate to Pipeline::compile_jit on this Func's pipeline for the
 * given target. */
void Func::compile_jit(const Target &target) {
    pipeline().compile_jit(target);
}
3254 
3255 }  // namespace Halide
3256