1 #include <atomic>
2 #include <memory>
3 #include <set>
4 #include <stdlib.h>
5 #include <utility>
6 
7 #include "CSE.h"
8 #include "Func.h"
9 #include "Function.h"
10 #include "IR.h"
11 #include "IREquality.h"
12 #include "IRMutator.h"
13 #include "IROperator.h"
14 #include "IRPrinter.h"
15 #include "ParallelRVar.h"
16 #include "Random.h"
17 #include "Scope.h"
18 #include "Var.h"
19 
20 namespace Halide {
21 namespace Internal {
22 
23 using std::map;
24 using std::pair;
25 using std::string;
26 using std::vector;
27 
28 typedef map<FunctionPtr, FunctionPtr> DeepCopyMap;
29 
30 struct FunctionContents;
31 
32 namespace {
33 // Weaken all the references to a particular Function to break
34 // reference cycles. Also count the number of references found.
35 class WeakenFunctionPtrs : public IRMutator {
36     using IRMutator::visit;
37 
visit(const Call * c)38     Expr visit(const Call *c) override {
39         Expr expr = IRMutator::visit(c);
40         c = expr.as<Call>();
41         internal_assert(c);
42         if (c->func.defined() &&
43             c->func.get() == func) {
44             FunctionPtr ptr = c->func;
45             ptr.weaken();
46             expr = Call::make(c->type, c->name, c->args, c->call_type,
47                               ptr, c->value_index,
48                               c->image, c->param);
49             count++;
50         }
51         return expr;
52     }
53     FunctionContents *func;
54 
55 public:
56     int count = 0;
WeakenFunctionPtrs(FunctionContents * f)57     WeakenFunctionPtrs(FunctionContents *f)
58         : func(f) {
59     }
60 };
61 }  // namespace
62 
63 struct FunctionContents {
64     std::string name;
65     std::string origin_name;
66     std::vector<Type> output_types;
67 
68     // The names of the dimensions of the Function. Corresponds to the
69     // LHS of the pure definition if there is one. Is also the initial
70     // stage of the dims and storage_dims. Used to identify dimensions
71     // of the Function by name.
72     std::vector<string> args;
73 
74     // Function-specific schedule. This schedule is applied to all stages
75     // within the function.
76     FuncSchedule func_schedule;
77 
78     Definition init_def;
79     std::vector<Definition> updates;
80 
81     std::string debug_file;
82 
83     std::vector<Parameter> output_buffers;
84 
85     std::vector<ExternFuncArgument> extern_arguments;
86     std::string extern_function_name;
87 
88     NameMangling extern_mangling = NameMangling::Default;
89     DeviceAPI extern_function_device_api = DeviceAPI::Host;
90     Expr extern_proxy_expr;
91 
92     bool trace_loads = false, trace_stores = false, trace_realizations = false;
93     std::vector<string> trace_tags;
94 
95     bool frozen = false;
96 
acceptHalide::Internal::FunctionContents97     void accept(IRVisitor *visitor) const {
98         func_schedule.accept(visitor);
99 
100         if (init_def.defined()) {
101             init_def.accept(visitor);
102         }
103         for (const Definition &def : updates) {
104             def.accept(visitor);
105         }
106 
107         if (!extern_function_name.empty()) {
108             for (ExternFuncArgument i : extern_arguments) {
109                 if (i.is_func()) {
110                     user_assert(i.func.get() != this)
111                         << "Extern Func has itself as an argument";
112                     i.func->accept(visitor);
113                 } else if (i.is_expr()) {
114                     i.expr.accept(visitor);
115                 }
116             }
117             if (extern_proxy_expr.defined()) {
118                 extern_proxy_expr.accept(visitor);
119             }
120         }
121 
122         for (Parameter i : output_buffers) {
123             for (size_t j = 0; j < args.size(); j++) {
124                 if (i.min_constraint(j).defined()) {
125                     i.min_constraint(j).accept(visitor);
126                 }
127                 if (i.stride_constraint(j).defined()) {
128                     i.stride_constraint(j).accept(visitor);
129                 }
130                 if (i.extent_constraint(j).defined()) {
131                     i.extent_constraint(j).accept(visitor);
132                 }
133             }
134         }
135     }
136 
137     // Pass an IRMutator through to all Exprs referenced in the FunctionContents
mutateHalide::Internal::FunctionContents138     void mutate(IRMutator *mutator) {
139         func_schedule.mutate(mutator);
140 
141         if (init_def.defined()) {
142             init_def.mutate(mutator);
143         }
144         for (Definition &def : updates) {
145             def.mutate(mutator);
146         }
147 
148         if (!extern_function_name.empty()) {
149             for (ExternFuncArgument &i : extern_arguments) {
150                 if (i.is_expr()) {
151                     i.expr = mutator->mutate(i.expr);
152                 }
153             }
154             extern_proxy_expr = mutator->mutate(extern_proxy_expr);
155         }
156     }
157 };
158 
159 struct FunctionGroup {
160     mutable RefCount ref_count;
161     vector<FunctionContents> members;
162 };
163 
get() const164 FunctionContents *FunctionPtr::get() const {
165     return &(group()->members[idx]);
166 }
167 
168 template<>
ref_count(const FunctionGroup * f)169 RefCount &ref_count<FunctionGroup>(const FunctionGroup *f) noexcept {
170     return f->ref_count;
171 }
172 
173 template<>
destroy(const FunctionGroup * f)174 void destroy<FunctionGroup>(const FunctionGroup *f) {
175     delete f;
176 }
177 
178 // All variables present in any part of a function definition must
179 // either be pure args, elements of the reduction domain, parameters
180 // (i.e. attached to some Parameter object), or part of a let node
181 // internal to the expression
182 struct CheckVars : public IRGraphVisitor {
183     vector<string> pure_args;
184     ReductionDomain reduction_domain;
185     Scope<> defined_internally;
186     const std::string name;
187     bool unbound_reduction_vars_ok = false;
188 
CheckVarsHalide::Internal::CheckVars189     CheckVars(const std::string &n)
190         : name(n) {
191     }
192 
193     using IRVisitor::visit;
194 
visitHalide::Internal::CheckVars195     void visit(const Let *let) override {
196         let->value.accept(this);
197         ScopedBinding<> bind(defined_internally, let->name);
198         let->body.accept(this);
199     }
200 
visitHalide::Internal::CheckVars201     void visit(const Call *op) override {
202         IRGraphVisitor::visit(op);
203         if (op->name == name && op->call_type == Call::Halide) {
204             for (size_t i = 0; i < op->args.size(); i++) {
205                 const Variable *var = op->args[i].as<Variable>();
206                 if (!pure_args[i].empty()) {
207                     user_assert(var && var->name == pure_args[i])
208                         << "In definition of Func \"" << name << "\":\n"
209                         << "All of a function's recursive references to itself"
210                         << " must contain the same pure variables in the same"
211                         << " places as on the left-hand-side.\n";
212                 }
213             }
214         }
215     }
216 
visitHalide::Internal::CheckVars217     void visit(const Variable *var) override {
218         // Is it a parameter?
219         if (var->param.defined()) return;
220 
221         // Was it defined internally by a let expression?
222         if (defined_internally.contains(var->name)) return;
223 
224         // Is it a pure argument?
225         for (size_t i = 0; i < pure_args.size(); i++) {
226             if (var->name == pure_args[i]) return;
227         }
228 
229         // Is it in a reduction domain?
230         if (var->reduction_domain.defined()) {
231             if (!reduction_domain.defined()) {
232                 reduction_domain = var->reduction_domain;
233                 return;
234             } else if (var->reduction_domain.same_as(reduction_domain)) {
235                 // It's in a reduction domain we already know about
236                 return;
237             } else {
238                 user_error << "Multiple reduction domains found in definition of Func \"" << name << "\"\n";
239             }
240         } else if (reduction_domain.defined() && unbound_reduction_vars_ok) {
241             // Is it one of the RVars from the reduction domain we already
242             // know about (this can happen in the RDom predicate).
243             for (const ReductionVariable &rv : reduction_domain.domain()) {
244                 if (rv.var == var->name) {
245                     return;
246                 }
247             }
248         }
249 
250         user_error << "Undefined variable \"" << var->name << "\" in definition of Func \"" << name << "\"\n";
251     }
252 };
253 
254 // Mark all functions found in an expr as frozen.
255 class FreezeFunctions : public IRGraphVisitor {
256     using IRGraphVisitor::visit;
257 
258     const string &func;
259 
visit(const Call * op)260     void visit(const Call *op) override {
261         IRGraphVisitor::visit(op);
262         if (op->call_type == Call::Halide &&
263             op->func.defined() &&
264             op->name != func) {
265             Function f(op->func);
266             f.freeze();
267         }
268     }
269 
270 public:
FreezeFunctions(const string & f)271     FreezeFunctions(const string &f)
272         : func(f) {
273     }
274 };
275 
// Process-wide counter that supplies the tag used when lowering random
// variables. The anonymous namespace already gives it internal linkage.
namespace {
std::atomic<int> rand_counter{0};
}
280 
Function()281 Function::Function() {
282 }
283 
Function(const FunctionPtr & ptr)284 Function::Function(const FunctionPtr &ptr)
285     : contents(ptr) {
286     contents.strengthen();
287     internal_assert(ptr.defined())
288         << "Can't construct Function from undefined FunctionContents ptr\n";
289 }
290 
Function(const std::string & n)291 Function::Function(const std::string &n) {
292     for (size_t i = 0; i < n.size(); i++) {
293         user_assert(n[i] != '.')
294             << "Func name \"" << n << "\" is invalid. "
295             << "Func names may not contain the character '.', "
296             << "as it is used internally by Halide as a separator\n";
297     }
298     contents.strong = new FunctionGroup;
299     contents.strong->members.resize(1);
300     contents->name = n;
301     contents->origin_name = n;
302 }
303 
304 // Return deep-copy of ExternFuncArgument 'src'
deep_copy_extern_func_argument_helper(const ExternFuncArgument & src,DeepCopyMap & copied_map)305 ExternFuncArgument deep_copy_extern_func_argument_helper(
306     const ExternFuncArgument &src, DeepCopyMap &copied_map) {
307     ExternFuncArgument copy;
308     copy.arg_type = src.arg_type;
309     copy.buffer = src.buffer;
310     copy.expr = src.expr;
311     copy.image_param = src.image_param;
312 
313     if (!src.func.defined()) {  // No need to deep-copy the func if it's undefined
314         internal_assert(!src.is_func())
315             << "ExternFuncArgument has type FuncArg but has no function definition\n";
316         return copy;
317     }
318 
319     // If the FunctionContents has already been deep-copied previously, i.e.
320     // it's in the 'copied_map', use the deep-copied version from the map instead
321     // of creating a new deep-copy
322     FunctionPtr &copied_func = copied_map[src.func];
323     internal_assert(copied_func.defined());
324     copy.func = copied_func;
325     return copy;
326 }
327 
deep_copy(const FunctionPtr & copy,DeepCopyMap & copied_map) const328 void Function::deep_copy(const FunctionPtr &copy, DeepCopyMap &copied_map) const {
329     internal_assert(copy.defined() && contents.defined())
330         << "Cannot deep-copy undefined Function\n";
331 
332     // Add reference to this Function's deep-copy to the map in case of
333     // self-reference, e.g. self-reference in an Definition.
334     copied_map[contents] = copy;
335 
336     debug(4) << "Deep-copy function contents: \"" << contents->name << "\"\n";
337 
338     copy->name = contents->name;
339     copy->origin_name = contents->origin_name;
340     copy->args = contents->args;
341     copy->output_types = contents->output_types;
342     copy->debug_file = contents->debug_file;
343     copy->extern_function_name = contents->extern_function_name;
344     copy->extern_mangling = contents->extern_mangling;
345     copy->extern_function_device_api = contents->extern_function_device_api;
346     copy->extern_proxy_expr = contents->extern_proxy_expr;
347     copy->trace_loads = contents->trace_loads;
348     copy->trace_stores = contents->trace_stores;
349     copy->trace_realizations = contents->trace_realizations;
350     copy->trace_tags = contents->trace_tags;
351     copy->frozen = contents->frozen;
352     copy->output_buffers = contents->output_buffers;
353     copy->func_schedule = contents->func_schedule.deep_copy(copied_map);
354 
355     // Copy the pure definition
356     if (contents->init_def.defined()) {
357         copy->init_def = contents->init_def.get_copy();
358         internal_assert(copy->init_def.is_init());
359         internal_assert(copy->init_def.schedule().rvars().empty())
360             << "Init definition shouldn't have reduction domain\n";
361     }
362 
363     for (const Definition &def : contents->updates) {
364         internal_assert(!def.is_init());
365         Definition def_copy = def.get_copy();
366         internal_assert(!def_copy.is_init());
367         copy->updates.push_back(std::move(def_copy));
368     }
369 
370     for (const ExternFuncArgument &e : contents->extern_arguments) {
371         ExternFuncArgument e_copy = deep_copy_extern_func_argument_helper(e, copied_map);
372         copy->extern_arguments.push_back(std::move(e_copy));
373     }
374 }
375 
deep_copy(string name,const FunctionPtr & copy,DeepCopyMap & copied_map) const376 void Function::deep_copy(string name, const FunctionPtr &copy, DeepCopyMap &copied_map) const {
377     deep_copy(copy, copied_map);
378     copy->name = std::move(name);
379 }
380 
define(const vector<string> & args,vector<Expr> values)381 void Function::define(const vector<string> &args, vector<Expr> values) {
382     user_assert(!frozen())
383         << "Func " << name() << " cannot be given a new pure definition, "
384         << "because it has already been realized or used in the definition of another Func.\n";
385     user_assert(!has_extern_definition())
386         << "In pure definition of Func \"" << name() << "\":\n"
387         << "Func with extern definition cannot be given a pure definition.\n";
388     user_assert(!name().empty()) << "A Func may not have an empty name.\n";
389     for (size_t i = 0; i < values.size(); i++) {
390         user_assert(values[i].defined())
391             << "In pure definition of Func \"" << name() << "\":\n"
392             << "Undefined expression in right-hand-side of definition.\n";
393     }
394 
395     // Make sure all the vars in the value are either args or are
396     // attached to some parameter
397     CheckVars check(name());
398     check.pure_args = args;
399     for (const auto &value : values) {
400         value.accept(&check);
401     }
402 
403     // Freeze all called functions
404     FreezeFunctions freezer(name());
405     for (const auto &value : values) {
406         value.accept(&freezer);
407     }
408 
409     // Make sure all the vars in the args have unique non-empty names
410     for (size_t i = 0; i < args.size(); i++) {
411         user_assert(!args[i].empty())
412             << "In pure definition of Func \"" << name() << "\":\n"
413             << "In left-hand-side of definition, argument "
414             << i << " has an empty name.\n";
415         for (size_t j = 0; j < i; j++) {
416             user_assert(args[i] != args[j])
417                 << "In pure definition of Func \"" << name() << "\":\n"
418                 << "In left-hand-side of definition, arguments "
419                 << i << " and " << j
420                 << " both have the name \"" + args[i] + "\"\n";
421         }
422     }
423 
424     for (auto &value : values) {
425         value = common_subexpression_elimination(value);
426     }
427 
428     // Tag calls to random() with the free vars
429     int tag = rand_counter++;
430     vector<VarOrRVar> free_vars;
431     free_vars.reserve(args.size());
432     for (const auto &arg : args) {
433         free_vars.emplace_back(Var(arg));
434     }
435     for (auto &value : values) {
436         value = lower_random(value, free_vars, tag);
437     }
438 
439     user_assert(!check.reduction_domain.defined())
440         << "In pure definition of Func \"" << name() << "\":\n"
441         << "Reduction domain referenced in pure function definition.\n";
442 
443     if (!contents.defined()) {
444         contents.strong = new FunctionGroup;
445         contents.strong->members.resize(1);
446         contents->name = unique_name('f');
447         contents->origin_name = contents->name;
448     }
449 
450     user_assert(!contents->init_def.defined())
451         << "In pure definition of Func \"" << name() << "\":\n"
452         << "Func is already defined.\n";
453 
454     contents->args = args;
455 
456     std::vector<Expr> init_def_args;
457     init_def_args.resize(args.size());
458     for (size_t i = 0; i < args.size(); i++) {
459         init_def_args[i] = Var(args[i]);
460     }
461 
462     ReductionDomain rdom;
463     contents->init_def = Definition(init_def_args, values, rdom, true);
464 
465     for (size_t i = 0; i < args.size(); i++) {
466         Dim d = {args[i], ForType::Serial, DeviceAPI::None, DimType::PureVar};
467         contents->init_def.schedule().dims().push_back(d);
468         StorageDim sd = {args[i]};
469         contents->func_schedule.storage_dims().push_back(sd);
470     }
471 
472     // Add the dummy outermost dim
473     {
474         Dim d = {Var::outermost().name(), ForType::Serial, DeviceAPI::None, DimType::PureVar};
475         contents->init_def.schedule().dims().push_back(d);
476     }
477 
478     contents->output_types.resize(values.size());
479     for (size_t i = 0; i < contents->output_types.size(); i++) {
480         contents->output_types[i] = values[i].type();
481     }
482 
483     for (size_t i = 0; i < values.size(); i++) {
484         string buffer_name = name();
485         if (values.size() > 1) {
486             buffer_name += '.' + std::to_string((int)i);
487         }
488         Parameter output(values[i].type(), true, args.size(), buffer_name);
489         contents->output_buffers.push_back(output);
490     }
491 }
492 
define_update(const vector<Expr> & _args,vector<Expr> values)493 void Function::define_update(const vector<Expr> &_args, vector<Expr> values) {
494     int update_idx = static_cast<int>(contents->updates.size());
495 
496     user_assert(!name().empty())
497         << "Func has an empty name.\n";
498     user_assert(has_pure_definition())
499         << "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
500         << "Can't add an update definition without a pure definition first.\n";
501     user_assert(!frozen())
502         << "Func " << name() << " cannot be given a new update definition, "
503         << "because it has already been realized or used in the definition of another Func.\n";
504 
505     for (size_t i = 0; i < values.size(); i++) {
506         user_assert(values[i].defined())
507             << "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
508             << "Undefined expression in right-hand-side of update.\n";
509     }
510 
511     // Check the dimensionality matches
512     user_assert((int)_args.size() == dimensions())
513         << "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
514         << "Dimensionality of update definition must match dimensionality of pure definition.\n";
515 
516     user_assert(values.size() == contents->init_def.values().size())
517         << "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
518         << "Number of tuple elements for update definition must "
519         << "match number of tuple elements for pure definition.\n";
520 
521     const auto &pure_def_vals = contents->init_def.values();
522     for (size_t i = 0; i < values.size(); i++) {
523         // Check that pure value and the update value have the same
524         // type.  Without this check, allocations may be the wrong size
525         // relative to what update code expects.
526         Type pure_type = pure_def_vals[i].type();
527         if (pure_type != values[i].type()) {
528             std::ostringstream err;
529             err << "In update definition " << update_idx << " of Func \"" << name() << "\":\n";
530             if (values.size()) {
531                 err << "Tuple element " << i << " of update definition has type ";
532             } else {
533                 err << "Update definition has type ";
534             }
535             err << values[i].type() << ", but pure definition has type " << pure_type;
536             user_error << err.str() << "\n";
537         }
538         values[i] = common_subexpression_elimination(values[i]);
539     }
540 
541     vector<Expr> args(_args.size());
542     for (size_t i = 0; i < args.size(); i++) {
543         args[i] = common_subexpression_elimination(_args[i]);
544     }
545 
546     // The pure args are those naked vars in the args that are not in
547     // a reduction domain and are not parameters and line up with the
548     // pure args in the pure definition.
549     bool pure = true;
550     vector<string> pure_args(args.size());
551     for (size_t i = 0; i < args.size(); i++) {
552         pure_args[i] = "";  // Will never match a var name
553         user_assert(args[i].defined())
554             << "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
555             << "Argument " << i
556             << " in left-hand-side of update definition is undefined.\n";
557         if (const Variable *var = args[i].as<Variable>()) {
558             if (!var->param.defined() &&
559                 !var->reduction_domain.defined() &&
560                 var->name == contents->args[i]) {
561                 pure_args[i] = var->name;
562             } else {
563                 pure = false;
564             }
565         } else {
566             pure = false;
567         }
568     }
569 
570     // Make sure all the vars in the args and the value are either
571     // pure args, in the reduction domain, or a parameter. Also checks
572     // that recursive references to the function contain all the pure
573     // vars in the LHS in the correct places.
574     CheckVars check(name());
575     check.pure_args = pure_args;
576     for (const auto &arg : args) {
577         arg.accept(&check);
578     }
579     for (const auto &value : values) {
580         value.accept(&check);
581     }
582     if (check.reduction_domain.defined()) {
583         check.unbound_reduction_vars_ok = true;
584         check.reduction_domain.predicate().accept(&check);
585     }
586 
587     // Freeze all called functions
588     FreezeFunctions freezer(name());
589     for (const auto &arg : args) {
590         arg.accept(&freezer);
591     }
592     for (const auto &value : values) {
593         value.accept(&freezer);
594     }
595 
596     // Freeze the reduction domain if defined
597     if (check.reduction_domain.defined()) {
598         check.reduction_domain.predicate().accept(&freezer);
599         check.reduction_domain.freeze();
600     }
601 
602     // Tag calls to random() with the free vars
603     vector<VarOrRVar> free_vars;
604     int num_free_vars = (int)pure_args.size();
605     if (check.reduction_domain.defined()) {
606         num_free_vars += (int)check.reduction_domain.domain().size();
607     }
608     free_vars.reserve(num_free_vars);
609     for (const auto &pure_arg : pure_args) {
610         if (!pure_arg.empty()) {
611             free_vars.emplace_back(Var(pure_arg));
612         }
613     }
614     if (check.reduction_domain.defined()) {
615         for (size_t i = 0; i < check.reduction_domain.domain().size(); i++) {
616             free_vars.emplace_back(RVar(check.reduction_domain, i));
617         }
618     }
619     int tag = rand_counter++;
620     for (auto &arg : args) {
621         arg = lower_random(arg, free_vars, tag);
622     }
623     for (auto &value : values) {
624         value = lower_random(value, free_vars, tag);
625     }
626     if (check.reduction_domain.defined()) {
627         check.reduction_domain.set_predicate(lower_random(check.reduction_domain.predicate(), free_vars, tag));
628     }
629 
630     // The update value and args probably refer back to the
631     // function itself, introducing circular references and hence
632     // memory leaks. We need to break these cycles.
633     WeakenFunctionPtrs weakener(contents.get());
634     for (auto &arg : args) {
635         arg = weakener.mutate(arg);
636     }
637     for (auto &value : values) {
638         value = weakener.mutate(value);
639     }
640     if (check.reduction_domain.defined()) {
641         check.reduction_domain.set_predicate(
642             weakener.mutate(check.reduction_domain.predicate()));
643     }
644 
645     Definition r(args, values, check.reduction_domain, false);
646     internal_assert(!r.is_init()) << "Should have been an update definition\n";
647 
648     // First add any reduction domain
649     if (check.reduction_domain.defined()) {
650         for (size_t i = 0; i < check.reduction_domain.domain().size(); i++) {
651             // Is this RVar actually pure (safe to parallelize and
652             // reorder)? It's pure if one value of the RVar can never
653             // access from the same memory that another RVar is
654             // writing to.
655             const ReductionVariable &rvar = check.reduction_domain.domain()[i];
656             const string &v = rvar.var;
657 
658             bool pure = can_parallelize_rvar(v, name(), r);
659             Dim d = {v, ForType::Serial, DeviceAPI::None,
660                      pure ? DimType::PureRVar : DimType::ImpureRVar};
661             r.schedule().dims().push_back(d);
662         }
663     }
664 
665     // Then add the pure args outside of that
666     for (const auto &pure_arg : pure_args) {
667         if (!pure_arg.empty()) {
668             Dim d = {pure_arg, ForType::Serial, DeviceAPI::None, DimType::PureVar};
669             r.schedule().dims().push_back(d);
670         }
671     }
672 
673     // Then the dummy outermost dim
674     {
675         Dim d = {Var::outermost().name(), ForType::Serial, DeviceAPI::None, DimType::PureVar};
676         r.schedule().dims().push_back(d);
677     }
678 
679     // If there's no recursive reference, no reduction domain, and all
680     // the args are pure, then this definition completely hides
681     // earlier ones!
682     if (!check.reduction_domain.defined() &&
683         weakener.count == 0 &&
684         pure) {
685         user_warning
686             << "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
687             << "Update definition completely hides earlier definitions, "
688             << " because all the arguments are pure, it contains no self-references, "
689             << " and no reduction domain. This may be an accidental re-definition of "
690             << " an already-defined function.\n";
691     }
692 
693     contents->updates.push_back(r);
694 }
695 
define_extern(const std::string & function_name,const std::vector<ExternFuncArgument> & extern_args,const std::vector<Type> & types,const std::vector<Var> & args,NameMangling mangling,DeviceAPI device_api)696 void Function::define_extern(const std::string &function_name,
697                              const std::vector<ExternFuncArgument> &extern_args,
698                              const std::vector<Type> &types,
699                              const std::vector<Var> &args,
700                              NameMangling mangling,
701                              DeviceAPI device_api) {
702 
703     user_assert(!has_pure_definition() && !has_update_definition())
704         << "In extern definition for Func \"" << name() << "\":\n"
705         << "Func with a pure definition cannot have an extern definition.\n";
706 
707     user_assert(!has_extern_definition())
708         << "In extern definition for Func \"" << name() << "\":\n"
709         << "Func already has an extern definition.\n";
710 
711     std::vector<string> arg_names;
712     std::vector<Expr> arg_exprs;
713     for (size_t i = 0; i < args.size(); i++) {
714         arg_names.push_back(args[i].name());
715         arg_exprs.push_back(args[i]);
716     }
717     contents->args = arg_names;
718     contents->extern_function_name = function_name;
719     contents->extern_arguments = extern_args;
720     contents->output_types = types;
721     contents->extern_mangling = mangling;
722     contents->extern_function_device_api = device_api;
723 
724     std::vector<Expr> values;
725     contents->output_buffers.clear();
726     for (size_t i = 0; i < types.size(); i++) {
727         string buffer_name = name();
728         if (types.size() > 1) {
729             buffer_name += '.' + std::to_string((int)i);
730         }
731         Parameter output(types[i], true, (int)args.size(), buffer_name);
732         contents->output_buffers.push_back(output);
733 
734         values.push_back(undef(types[i]));
735     }
736 
737     contents->init_def = Definition(arg_exprs, values, ReductionDomain(), true);
738 
739     // Reset the storage dims to match the pure args
740     contents->func_schedule.storage_dims().clear();
741     contents->init_def.schedule().dims().clear();
742     for (size_t i = 0; i < args.size(); i++) {
743         contents->func_schedule.storage_dims().push_back(StorageDim{arg_names[i]});
744         contents->init_def.schedule().dims().push_back(
745             Dim{arg_names[i], ForType::Extern, DeviceAPI::None, DimType::PureVar});
746     }
747     // Add the dummy outermost dim
748     contents->init_def.schedule().dims().push_back(
749         Dim{Var::outermost().name(), ForType::Serial, DeviceAPI::None, DimType::PureVar});
750 }
751 
accept(IRVisitor * visitor) const752 void Function::accept(IRVisitor *visitor) const {
753     contents->accept(visitor);
754 }
755 
mutate(IRMutator * mutator)756 void Function::mutate(IRMutator *mutator) {
757     contents->mutate(mutator);
758 }
759 
name() const760 const std::string &Function::name() const {
761     return contents->name;
762 }
763 
origin_name() const764 const std::string &Function::origin_name() const {
765     return contents->origin_name;
766 }
767 
definition()768 Definition &Function::definition() {
769     internal_assert(contents->init_def.defined());
770     return contents->init_def;
771 }
772 
definition() const773 const Definition &Function::definition() const {
774     internal_assert(contents->init_def.defined());
775     return contents->init_def;
776 }
777 
args() const778 const std::vector<std::string> &Function::args() const {
779     return contents->args;
780 }
781 
is_pure_arg(const std::string & name) const782 bool Function::is_pure_arg(const std::string &name) const {
783     return std::find(args().begin(), args().end(), name) != args().end();
784 }
785 
dimensions() const786 int Function::dimensions() const {
787     return args().size();
788 }
789 
output_types() const790 const std::vector<Type> &Function::output_types() const {
791     return contents->output_types;
792 }
793 
values() const794 const std::vector<Expr> &Function::values() const {
795     static const std::vector<Expr> empty;
796     if (has_pure_definition()) {
797         return contents->init_def.values();
798     } else {
799         return empty;
800     }
801 }
802 
// Mutable access to the function-level schedule (func_schedule).
FuncSchedule &Function::schedule() {
    return contents->func_schedule;
}
806 
// Read-only access to the function-level schedule (func_schedule).
const FuncSchedule &Function::schedule() const {
    return contents->func_schedule;
}
810 
// Returns the output buffer Parameters stored in this function's
// contents.
const std::vector<Parameter> &Function::output_buffers() const {
    return contents->output_buffers;
}
814 
update_schedule(int idx)815 StageSchedule &Function::update_schedule(int idx) {
816     internal_assert(idx < (int)contents->updates.size()) << "Invalid update definition index\n";
817     return contents->updates[idx].schedule();
818 }
819 
update(int idx)820 Definition &Function::update(int idx) {
821     internal_assert(idx < (int)contents->updates.size()) << "Invalid update definition index\n";
822     return contents->updates[idx];
823 }
824 
update(int idx) const825 const Definition &Function::update(int idx) const {
826     internal_assert(idx < (int)contents->updates.size()) << "Invalid update definition index\n";
827     return contents->updates[idx];
828 }
829 
// Returns all update definitions stored in this function's contents.
const std::vector<Definition> &Function::updates() const {
    return contents->updates;
}
833 
// True when the pure definition (init_def) is defined.
bool Function::has_pure_definition() const {
    return contents->init_def.defined();
}
837 
can_be_inlined() const838 bool Function::can_be_inlined() const {
839     return is_pure() && definition().specializations().empty();
840 }
841 
// True when at least one update definition has been recorded.
bool Function::has_update_definition() const {
    return !contents->updates.empty();
}
845 
// True when an extern function name has been recorded, i.e. the stored
// extern_function_name is non-empty.
bool Function::has_extern_definition() const {
    return !contents->extern_function_name.empty();
}
849 
// Returns the name mangling setting recorded for the extern definition.
NameMangling Function::extern_definition_name_mangling() const {
    return contents->extern_mangling;
}
853 
make_call_to_extern_definition(const std::vector<Expr> & args,const Target & target) const854 Expr Function::make_call_to_extern_definition(const std::vector<Expr> &args,
855                                               const Target &target) const {
856     internal_assert(has_extern_definition());
857 
858     Call::CallType call_type = Call::Extern;
859     switch (contents->extern_mangling) {
860     case NameMangling::Default:
861         call_type = (target.has_feature(Target::CPlusPlusMangling) ? Call::ExternCPlusPlus : Call::Extern);
862         break;
863     case NameMangling::CPlusPlus:
864         call_type = Call::ExternCPlusPlus;
865         break;
866     case NameMangling::C:
867         call_type = Call::Extern;
868         break;
869     }
870     return Call::make(Int(32), contents->extern_function_name, args, call_type, contents);
871 }
872 
// Returns a copy of the proxy Expr recorded for the extern definition.
Expr Function::extern_definition_proxy_expr() const {
    return contents->extern_proxy_expr;
}
876 
// Mutable access to the proxy Expr recorded for the extern definition.
Expr &Function::extern_definition_proxy_expr() {
    return contents->extern_proxy_expr;
}
880 
// Read-only access to the arguments recorded for the extern definition.
const std::vector<ExternFuncArgument> &Function::extern_arguments() const {
    return contents->extern_arguments;
}
884 
// Mutable access to the arguments recorded for the extern definition.
std::vector<ExternFuncArgument> &Function::extern_arguments() {
    return contents->extern_arguments;
}
888 
// Returns the recorded extern function name. An empty name means there
// is no extern definition (see has_extern_definition()).
const std::string &Function::extern_function_name() const {
    return contents->extern_function_name;
}
892 
// Returns the DeviceAPI recorded for the extern function.
DeviceAPI Function::extern_function_device_api() const {
    return contents->extern_function_device_api;
}
896 
// Read-only access to the debug_file string stored in the contents.
const std::string &Function::debug_file() const {
    return contents->debug_file;
}
900 
// Mutable access to the debug_file string stored in the contents, so
// that callers can set it.
std::string &Function::debug_file() {
    return contents->debug_file;
}
904 
// Conversion to ExternFuncArgument: wraps this function's contents
// pointer in an ExternFuncArgument.
Function::operator ExternFuncArgument() const {
    return ExternFuncArgument(contents);
}
908 
// Set the trace_loads flag on this function.
void Function::trace_loads() {
    contents->trace_loads = true;
}
// Set the trace_stores flag on this function.
void Function::trace_stores() {
    contents->trace_stores = true;
}
// Set the trace_realizations flag on this function.
void Function::trace_realizations() {
    contents->trace_realizations = true;
}
// Append 'trace_tag' to this function's list of trace tags. No
// de-duplication is performed.
void Function::add_trace_tag(const std::string &trace_tag) {
    contents->trace_tags.push_back(trace_tag);
}
921 
// Returns the current value of the trace_loads flag.
bool Function::is_tracing_loads() const {
    return contents->trace_loads;
}
// Returns the current value of the trace_stores flag.
bool Function::is_tracing_stores() const {
    return contents->trace_stores;
}
// Returns the current value of the trace_realizations flag.
bool Function::is_tracing_realizations() const {
    return contents->trace_realizations;
}
// Returns the trace tags accumulated via add_trace_tag(), in insertion
// order.
const std::vector<std::string> &Function::get_trace_tags() const {
    return contents->trace_tags;
}
934 
// Mark this function as frozen by setting the frozen flag. Nothing in
// this method clears the flag again.
void Function::freeze() {
    contents->frozen = true;
}
938 
lock_loop_levels()939 void Function::lock_loop_levels() {
940     auto &schedule = contents->func_schedule;
941     schedule.compute_level().lock();
942     schedule.store_level().lock();
943     // If store_level is inlined, use the compute_level instead.
944     // (Note that we deliberately do *not* do the same if store_level
945     // is undefined.)
946     if (schedule.store_level().is_inlined()) {
947         schedule.store_level() = schedule.compute_level();
948     }
949     if (contents->init_def.defined()) {
950         contents->init_def.schedule().fuse_level().level.lock();
951     }
952     for (Definition &def : contents->updates) {
953         internal_assert(def.defined());
954         def.schedule().fuse_level().level.lock();
955     }
956 }
957 
// Returns the current value of the frozen flag (set by freeze()).
bool Function::frozen() const {
    return contents->frozen;
}
961 
// Returns the name -> wrapper map kept by the function schedule
// (populated through add_wrapper()).
const map<string, FunctionPtr> &Function::wrappers() const {
    return contents->func_schedule.wrappers();
}
965 
new_function_in_same_group(const std::string & f)966 Function Function::new_function_in_same_group(const std::string &f) {
967     int group_size = (int)(contents.group()->members.size());
968     contents.group()->members.resize(group_size + 1);
969     contents.group()->members[group_size].name = f;
970     FunctionPtr ptr;
971     ptr.strong = contents.group();
972     ptr.idx = group_size;
973     return Function(ptr);
974 }
975 
add_wrapper(const std::string & f,Function & wrapper)976 void Function::add_wrapper(const std::string &f, Function &wrapper) {
977     wrapper.freeze();
978     FunctionPtr ptr = wrapper.contents;
979 
980     // Weaken the pointer from the function to its wrapper
981     ptr.weaken();
982     contents->func_schedule.add_wrapper(f, ptr);
983 
984     // Weaken the pointer from the wrapper back to the function.
985     WeakenFunctionPtrs weakener(contents.get());
986     wrapper.mutate(&weakener);
987 }
988 
is_wrapper() const989 const Call *Function::is_wrapper() const {
990     const vector<Expr> &rhs = values();
991     if (rhs.size() != 1) {
992         return nullptr;
993     }
994     const Call *call = rhs[0].as<Call>();
995     if (!call) {
996         return nullptr;
997     }
998     vector<Expr> expected_args;
999     for (const string &v : args()) {
1000         expected_args.push_back(Variable::make(Int(32), v));
1001     }
1002     Expr expected_rhs =
1003         Call::make(call->type, call->name, expected_args, call->call_type,
1004                    call->func, call->value_index, call->image, call->param);
1005     if (equal(rhs[0], expected_rhs)) {
1006         return call;
1007     } else {
1008         return nullptr;
1009     }
1010 }
1011 
1012 namespace {
1013 
1014 // Replace all calls to functions listed in 'substitutions' with their wrappers.
1015 class SubstituteCalls : public IRMutator {
1016     using IRMutator::visit;
1017 
1018     const map<FunctionPtr, FunctionPtr> &substitutions;
1019 
visit(const Call * c)1020     Expr visit(const Call *c) override {
1021         Expr expr = IRMutator::visit(c);
1022         c = expr.as<Call>();
1023         internal_assert(c);
1024 
1025         if ((c->call_type == Call::Halide) &&
1026             c->func.defined() &&
1027             substitutions.count(c->func)) {
1028             auto it = substitutions.find(c->func);
1029             internal_assert(it != substitutions.end())
1030                 << "Function not in environment: " << c->func->name << "\n";
1031             FunctionPtr subs = it->second;
1032             debug(4) << "...Replace call to Func \"" << c->name << "\" with "
1033                      << "\"" << subs->name << "\"\n";
1034             expr = Call::make(c->type, subs->name, c->args, c->call_type,
1035                               subs, c->value_index,
1036                               c->image, c->param);
1037         }
1038         return expr;
1039     }
1040 
1041 public:
SubstituteCalls(const map<FunctionPtr,FunctionPtr> & substitutions)1042     SubstituteCalls(const map<FunctionPtr, FunctionPtr> &substitutions)
1043         : substitutions(substitutions) {
1044     }
1045 };
1046 
1047 }  // anonymous namespace
1048 
substitute_calls(const map<FunctionPtr,FunctionPtr> & substitutions)1049 Function &Function::substitute_calls(const map<FunctionPtr, FunctionPtr> &substitutions) {
1050     debug(4) << "Substituting calls in " << name() << "\n";
1051     if (substitutions.empty()) {
1052         return *this;
1053     }
1054     SubstituteCalls subs_calls(substitutions);
1055     contents->mutate(&subs_calls);
1056     return *this;
1057 }
1058 
// Convenience overload: replace calls to the single function 'orig' with
// calls to 'substitute'. Returns *this to allow chaining.
Function &Function::substitute_calls(const Function &orig, const Function &substitute) {
    const map<FunctionPtr, FunctionPtr> single_substitution = {
        {orig.get_contents(), substitute.get_contents()}};
    return substitute_calls(single_substitution);
}
1064 
1065 // Deep copy an entire Function DAG.
deep_copy(const vector<Function> & outputs,const map<string,Function> & env)1066 pair<vector<Function>, map<string, Function>> deep_copy(
1067     const vector<Function> &outputs, const map<string, Function> &env) {
1068     vector<Function> copy_outputs;
1069     map<string, Function> copy_env;
1070 
1071     // Create empty deep-copies of all Functions in 'env'
1072     DeepCopyMap copied_map;  // Original Function -> Deep-copy
1073     IntrusivePtr<FunctionGroup> group(new FunctionGroup);
1074     group->members.resize(env.size());
1075     int i = 0;
1076     for (const auto &iter : env) {
1077         // Make a weak pointer to the function to use for within-group references.
1078         FunctionPtr ptr;
1079         ptr.weak = group.get();
1080         ptr.idx = i;
1081         ptr->name = iter.second.name();
1082         copied_map[iter.second.get_contents()] = ptr;
1083         i++;
1084     }
1085 
1086     // Deep copy all Functions in 'env' into their corresponding empty copies
1087     for (const auto &iter : env) {
1088         iter.second.deep_copy(copied_map[iter.second.get_contents()], copied_map);
1089     }
1090 
1091     // Need to substitute-in all old Function references in all Exprs referenced
1092     // within the Function with the deep-copy versions
1093     for (auto &iter : copied_map) {
1094         Function(iter.second).substitute_calls(copied_map);
1095     }
1096 
1097     // Populate the env with the deep-copy version
1098     for (const auto &iter : copied_map) {
1099         FunctionPtr ptr = iter.second;
1100         copy_env.emplace(iter.first->name, Function(ptr));
1101     }
1102 
1103     for (const auto &func : outputs) {
1104         const auto &iter = copied_map.find(func.get_contents());
1105         if (iter != copied_map.end()) {
1106             FunctionPtr ptr = iter->second;
1107             debug(4) << "Adding deep-copied version to outputs: " << func.name() << "\n";
1108             copy_outputs.emplace_back(ptr);
1109         } else {
1110             debug(4) << "Adding original version to outputs: " << func.name() << "\n";
1111             copy_outputs.push_back(func);
1112         }
1113     }
1114 
1115     return {copy_outputs, copy_env};
1116 }
1117 
1118 }  // namespace Internal
1119 }  // namespace Halide
1120