1 #include <atomic>
2 #include <memory>
3 #include <set>
4 #include <stdlib.h>
5 #include <utility>
6
7 #include "CSE.h"
8 #include "Func.h"
9 #include "Function.h"
10 #include "IR.h"
11 #include "IREquality.h"
12 #include "IRMutator.h"
13 #include "IROperator.h"
14 #include "IRPrinter.h"
15 #include "ParallelRVar.h"
16 #include "Random.h"
17 #include "Scope.h"
18 #include "Var.h"
19
20 namespace Halide {
21 namespace Internal {
22
23 using std::map;
24 using std::pair;
25 using std::string;
26 using std::vector;
27
28 typedef map<FunctionPtr, FunctionPtr> DeepCopyMap;
29
30 struct FunctionContents;
31
32 namespace {
33 // Weaken all the references to a particular Function to break
34 // reference cycles. Also count the number of references found.
35 class WeakenFunctionPtrs : public IRMutator {
36 using IRMutator::visit;
37
visit(const Call * c)38 Expr visit(const Call *c) override {
39 Expr expr = IRMutator::visit(c);
40 c = expr.as<Call>();
41 internal_assert(c);
42 if (c->func.defined() &&
43 c->func.get() == func) {
44 FunctionPtr ptr = c->func;
45 ptr.weaken();
46 expr = Call::make(c->type, c->name, c->args, c->call_type,
47 ptr, c->value_index,
48 c->image, c->param);
49 count++;
50 }
51 return expr;
52 }
53 FunctionContents *func;
54
55 public:
56 int count = 0;
WeakenFunctionPtrs(FunctionContents * f)57 WeakenFunctionPtrs(FunctionContents *f)
58 : func(f) {
59 }
60 };
61 } // namespace
62
63 struct FunctionContents {
64 std::string name;
65 std::string origin_name;
66 std::vector<Type> output_types;
67
68 // The names of the dimensions of the Function. Corresponds to the
69 // LHS of the pure definition if there is one. Is also the initial
70 // stage of the dims and storage_dims. Used to identify dimensions
71 // of the Function by name.
72 std::vector<string> args;
73
74 // Function-specific schedule. This schedule is applied to all stages
75 // within the function.
76 FuncSchedule func_schedule;
77
78 Definition init_def;
79 std::vector<Definition> updates;
80
81 std::string debug_file;
82
83 std::vector<Parameter> output_buffers;
84
85 std::vector<ExternFuncArgument> extern_arguments;
86 std::string extern_function_name;
87
88 NameMangling extern_mangling = NameMangling::Default;
89 DeviceAPI extern_function_device_api = DeviceAPI::Host;
90 Expr extern_proxy_expr;
91
92 bool trace_loads = false, trace_stores = false, trace_realizations = false;
93 std::vector<string> trace_tags;
94
95 bool frozen = false;
96
acceptHalide::Internal::FunctionContents97 void accept(IRVisitor *visitor) const {
98 func_schedule.accept(visitor);
99
100 if (init_def.defined()) {
101 init_def.accept(visitor);
102 }
103 for (const Definition &def : updates) {
104 def.accept(visitor);
105 }
106
107 if (!extern_function_name.empty()) {
108 for (ExternFuncArgument i : extern_arguments) {
109 if (i.is_func()) {
110 user_assert(i.func.get() != this)
111 << "Extern Func has itself as an argument";
112 i.func->accept(visitor);
113 } else if (i.is_expr()) {
114 i.expr.accept(visitor);
115 }
116 }
117 if (extern_proxy_expr.defined()) {
118 extern_proxy_expr.accept(visitor);
119 }
120 }
121
122 for (Parameter i : output_buffers) {
123 for (size_t j = 0; j < args.size(); j++) {
124 if (i.min_constraint(j).defined()) {
125 i.min_constraint(j).accept(visitor);
126 }
127 if (i.stride_constraint(j).defined()) {
128 i.stride_constraint(j).accept(visitor);
129 }
130 if (i.extent_constraint(j).defined()) {
131 i.extent_constraint(j).accept(visitor);
132 }
133 }
134 }
135 }
136
137 // Pass an IRMutator through to all Exprs referenced in the FunctionContents
mutateHalide::Internal::FunctionContents138 void mutate(IRMutator *mutator) {
139 func_schedule.mutate(mutator);
140
141 if (init_def.defined()) {
142 init_def.mutate(mutator);
143 }
144 for (Definition &def : updates) {
145 def.mutate(mutator);
146 }
147
148 if (!extern_function_name.empty()) {
149 for (ExternFuncArgument &i : extern_arguments) {
150 if (i.is_expr()) {
151 i.expr = mutator->mutate(i.expr);
152 }
153 }
154 extern_proxy_expr = mutator->mutate(extern_proxy_expr);
155 }
156 }
157 };
158
159 struct FunctionGroup {
160 mutable RefCount ref_count;
161 vector<FunctionContents> members;
162 };
163
get() const164 FunctionContents *FunctionPtr::get() const {
165 return &(group()->members[idx]);
166 }
167
168 template<>
ref_count(const FunctionGroup * f)169 RefCount &ref_count<FunctionGroup>(const FunctionGroup *f) noexcept {
170 return f->ref_count;
171 }
172
173 template<>
destroy(const FunctionGroup * f)174 void destroy<FunctionGroup>(const FunctionGroup *f) {
175 delete f;
176 }
177
178 // All variables present in any part of a function definition must
179 // either be pure args, elements of the reduction domain, parameters
180 // (i.e. attached to some Parameter object), or part of a let node
181 // internal to the expression
182 struct CheckVars : public IRGraphVisitor {
183 vector<string> pure_args;
184 ReductionDomain reduction_domain;
185 Scope<> defined_internally;
186 const std::string name;
187 bool unbound_reduction_vars_ok = false;
188
CheckVarsHalide::Internal::CheckVars189 CheckVars(const std::string &n)
190 : name(n) {
191 }
192
193 using IRVisitor::visit;
194
visitHalide::Internal::CheckVars195 void visit(const Let *let) override {
196 let->value.accept(this);
197 ScopedBinding<> bind(defined_internally, let->name);
198 let->body.accept(this);
199 }
200
visitHalide::Internal::CheckVars201 void visit(const Call *op) override {
202 IRGraphVisitor::visit(op);
203 if (op->name == name && op->call_type == Call::Halide) {
204 for (size_t i = 0; i < op->args.size(); i++) {
205 const Variable *var = op->args[i].as<Variable>();
206 if (!pure_args[i].empty()) {
207 user_assert(var && var->name == pure_args[i])
208 << "In definition of Func \"" << name << "\":\n"
209 << "All of a function's recursive references to itself"
210 << " must contain the same pure variables in the same"
211 << " places as on the left-hand-side.\n";
212 }
213 }
214 }
215 }
216
visitHalide::Internal::CheckVars217 void visit(const Variable *var) override {
218 // Is it a parameter?
219 if (var->param.defined()) return;
220
221 // Was it defined internally by a let expression?
222 if (defined_internally.contains(var->name)) return;
223
224 // Is it a pure argument?
225 for (size_t i = 0; i < pure_args.size(); i++) {
226 if (var->name == pure_args[i]) return;
227 }
228
229 // Is it in a reduction domain?
230 if (var->reduction_domain.defined()) {
231 if (!reduction_domain.defined()) {
232 reduction_domain = var->reduction_domain;
233 return;
234 } else if (var->reduction_domain.same_as(reduction_domain)) {
235 // It's in a reduction domain we already know about
236 return;
237 } else {
238 user_error << "Multiple reduction domains found in definition of Func \"" << name << "\"\n";
239 }
240 } else if (reduction_domain.defined() && unbound_reduction_vars_ok) {
241 // Is it one of the RVars from the reduction domain we already
242 // know about (this can happen in the RDom predicate).
243 for (const ReductionVariable &rv : reduction_domain.domain()) {
244 if (rv.var == var->name) {
245 return;
246 }
247 }
248 }
249
250 user_error << "Undefined variable \"" << var->name << "\" in definition of Func \"" << name << "\"\n";
251 }
252 };
253
254 // Mark all functions found in an expr as frozen.
255 class FreezeFunctions : public IRGraphVisitor {
256 using IRGraphVisitor::visit;
257
258 const string &func;
259
visit(const Call * op)260 void visit(const Call *op) override {
261 IRGraphVisitor::visit(op);
262 if (op->call_type == Call::Halide &&
263 op->func.defined() &&
264 op->name != func) {
265 Function f(op->func);
266 f.freeze();
267 }
268 }
269
270 public:
FreezeFunctions(const string & f)271 FreezeFunctions(const string &f)
272 : func(f) {
273 }
274 };
275
// A counter to use in tagging random variables. Each definition takes the
// current value as its tag and increments the counter.
namespace {
std::atomic<int> rand_counter{0};
}  // namespace
280
Function()281 Function::Function() {
282 }
283
Function(const FunctionPtr & ptr)284 Function::Function(const FunctionPtr &ptr)
285 : contents(ptr) {
286 contents.strengthen();
287 internal_assert(ptr.defined())
288 << "Can't construct Function from undefined FunctionContents ptr\n";
289 }
290
Function(const std::string & n)291 Function::Function(const std::string &n) {
292 for (size_t i = 0; i < n.size(); i++) {
293 user_assert(n[i] != '.')
294 << "Func name \"" << n << "\" is invalid. "
295 << "Func names may not contain the character '.', "
296 << "as it is used internally by Halide as a separator\n";
297 }
298 contents.strong = new FunctionGroup;
299 contents.strong->members.resize(1);
300 contents->name = n;
301 contents->origin_name = n;
302 }
303
304 // Return deep-copy of ExternFuncArgument 'src'
deep_copy_extern_func_argument_helper(const ExternFuncArgument & src,DeepCopyMap & copied_map)305 ExternFuncArgument deep_copy_extern_func_argument_helper(
306 const ExternFuncArgument &src, DeepCopyMap &copied_map) {
307 ExternFuncArgument copy;
308 copy.arg_type = src.arg_type;
309 copy.buffer = src.buffer;
310 copy.expr = src.expr;
311 copy.image_param = src.image_param;
312
313 if (!src.func.defined()) { // No need to deep-copy the func if it's undefined
314 internal_assert(!src.is_func())
315 << "ExternFuncArgument has type FuncArg but has no function definition\n";
316 return copy;
317 }
318
319 // If the FunctionContents has already been deep-copied previously, i.e.
320 // it's in the 'copied_map', use the deep-copied version from the map instead
321 // of creating a new deep-copy
322 FunctionPtr &copied_func = copied_map[src.func];
323 internal_assert(copied_func.defined());
324 copy.func = copied_func;
325 return copy;
326 }
327
deep_copy(const FunctionPtr & copy,DeepCopyMap & copied_map) const328 void Function::deep_copy(const FunctionPtr ©, DeepCopyMap &copied_map) const {
329 internal_assert(copy.defined() && contents.defined())
330 << "Cannot deep-copy undefined Function\n";
331
332 // Add reference to this Function's deep-copy to the map in case of
333 // self-reference, e.g. self-reference in an Definition.
334 copied_map[contents] = copy;
335
336 debug(4) << "Deep-copy function contents: \"" << contents->name << "\"\n";
337
338 copy->name = contents->name;
339 copy->origin_name = contents->origin_name;
340 copy->args = contents->args;
341 copy->output_types = contents->output_types;
342 copy->debug_file = contents->debug_file;
343 copy->extern_function_name = contents->extern_function_name;
344 copy->extern_mangling = contents->extern_mangling;
345 copy->extern_function_device_api = contents->extern_function_device_api;
346 copy->extern_proxy_expr = contents->extern_proxy_expr;
347 copy->trace_loads = contents->trace_loads;
348 copy->trace_stores = contents->trace_stores;
349 copy->trace_realizations = contents->trace_realizations;
350 copy->trace_tags = contents->trace_tags;
351 copy->frozen = contents->frozen;
352 copy->output_buffers = contents->output_buffers;
353 copy->func_schedule = contents->func_schedule.deep_copy(copied_map);
354
355 // Copy the pure definition
356 if (contents->init_def.defined()) {
357 copy->init_def = contents->init_def.get_copy();
358 internal_assert(copy->init_def.is_init());
359 internal_assert(copy->init_def.schedule().rvars().empty())
360 << "Init definition shouldn't have reduction domain\n";
361 }
362
363 for (const Definition &def : contents->updates) {
364 internal_assert(!def.is_init());
365 Definition def_copy = def.get_copy();
366 internal_assert(!def_copy.is_init());
367 copy->updates.push_back(std::move(def_copy));
368 }
369
370 for (const ExternFuncArgument &e : contents->extern_arguments) {
371 ExternFuncArgument e_copy = deep_copy_extern_func_argument_helper(e, copied_map);
372 copy->extern_arguments.push_back(std::move(e_copy));
373 }
374 }
375
deep_copy(string name,const FunctionPtr & copy,DeepCopyMap & copied_map) const376 void Function::deep_copy(string name, const FunctionPtr ©, DeepCopyMap &copied_map) const {
377 deep_copy(copy, copied_map);
378 copy->name = std::move(name);
379 }
380
define(const vector<string> & args,vector<Expr> values)381 void Function::define(const vector<string> &args, vector<Expr> values) {
382 user_assert(!frozen())
383 << "Func " << name() << " cannot be given a new pure definition, "
384 << "because it has already been realized or used in the definition of another Func.\n";
385 user_assert(!has_extern_definition())
386 << "In pure definition of Func \"" << name() << "\":\n"
387 << "Func with extern definition cannot be given a pure definition.\n";
388 user_assert(!name().empty()) << "A Func may not have an empty name.\n";
389 for (size_t i = 0; i < values.size(); i++) {
390 user_assert(values[i].defined())
391 << "In pure definition of Func \"" << name() << "\":\n"
392 << "Undefined expression in right-hand-side of definition.\n";
393 }
394
395 // Make sure all the vars in the value are either args or are
396 // attached to some parameter
397 CheckVars check(name());
398 check.pure_args = args;
399 for (const auto &value : values) {
400 value.accept(&check);
401 }
402
403 // Freeze all called functions
404 FreezeFunctions freezer(name());
405 for (const auto &value : values) {
406 value.accept(&freezer);
407 }
408
409 // Make sure all the vars in the args have unique non-empty names
410 for (size_t i = 0; i < args.size(); i++) {
411 user_assert(!args[i].empty())
412 << "In pure definition of Func \"" << name() << "\":\n"
413 << "In left-hand-side of definition, argument "
414 << i << " has an empty name.\n";
415 for (size_t j = 0; j < i; j++) {
416 user_assert(args[i] != args[j])
417 << "In pure definition of Func \"" << name() << "\":\n"
418 << "In left-hand-side of definition, arguments "
419 << i << " and " << j
420 << " both have the name \"" + args[i] + "\"\n";
421 }
422 }
423
424 for (auto &value : values) {
425 value = common_subexpression_elimination(value);
426 }
427
428 // Tag calls to random() with the free vars
429 int tag = rand_counter++;
430 vector<VarOrRVar> free_vars;
431 free_vars.reserve(args.size());
432 for (const auto &arg : args) {
433 free_vars.emplace_back(Var(arg));
434 }
435 for (auto &value : values) {
436 value = lower_random(value, free_vars, tag);
437 }
438
439 user_assert(!check.reduction_domain.defined())
440 << "In pure definition of Func \"" << name() << "\":\n"
441 << "Reduction domain referenced in pure function definition.\n";
442
443 if (!contents.defined()) {
444 contents.strong = new FunctionGroup;
445 contents.strong->members.resize(1);
446 contents->name = unique_name('f');
447 contents->origin_name = contents->name;
448 }
449
450 user_assert(!contents->init_def.defined())
451 << "In pure definition of Func \"" << name() << "\":\n"
452 << "Func is already defined.\n";
453
454 contents->args = args;
455
456 std::vector<Expr> init_def_args;
457 init_def_args.resize(args.size());
458 for (size_t i = 0; i < args.size(); i++) {
459 init_def_args[i] = Var(args[i]);
460 }
461
462 ReductionDomain rdom;
463 contents->init_def = Definition(init_def_args, values, rdom, true);
464
465 for (size_t i = 0; i < args.size(); i++) {
466 Dim d = {args[i], ForType::Serial, DeviceAPI::None, DimType::PureVar};
467 contents->init_def.schedule().dims().push_back(d);
468 StorageDim sd = {args[i]};
469 contents->func_schedule.storage_dims().push_back(sd);
470 }
471
472 // Add the dummy outermost dim
473 {
474 Dim d = {Var::outermost().name(), ForType::Serial, DeviceAPI::None, DimType::PureVar};
475 contents->init_def.schedule().dims().push_back(d);
476 }
477
478 contents->output_types.resize(values.size());
479 for (size_t i = 0; i < contents->output_types.size(); i++) {
480 contents->output_types[i] = values[i].type();
481 }
482
483 for (size_t i = 0; i < values.size(); i++) {
484 string buffer_name = name();
485 if (values.size() > 1) {
486 buffer_name += '.' + std::to_string((int)i);
487 }
488 Parameter output(values[i].type(), true, args.size(), buffer_name);
489 contents->output_buffers.push_back(output);
490 }
491 }
492
define_update(const vector<Expr> & _args,vector<Expr> values)493 void Function::define_update(const vector<Expr> &_args, vector<Expr> values) {
494 int update_idx = static_cast<int>(contents->updates.size());
495
496 user_assert(!name().empty())
497 << "Func has an empty name.\n";
498 user_assert(has_pure_definition())
499 << "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
500 << "Can't add an update definition without a pure definition first.\n";
501 user_assert(!frozen())
502 << "Func " << name() << " cannot be given a new update definition, "
503 << "because it has already been realized or used in the definition of another Func.\n";
504
505 for (size_t i = 0; i < values.size(); i++) {
506 user_assert(values[i].defined())
507 << "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
508 << "Undefined expression in right-hand-side of update.\n";
509 }
510
511 // Check the dimensionality matches
512 user_assert((int)_args.size() == dimensions())
513 << "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
514 << "Dimensionality of update definition must match dimensionality of pure definition.\n";
515
516 user_assert(values.size() == contents->init_def.values().size())
517 << "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
518 << "Number of tuple elements for update definition must "
519 << "match number of tuple elements for pure definition.\n";
520
521 const auto &pure_def_vals = contents->init_def.values();
522 for (size_t i = 0; i < values.size(); i++) {
523 // Check that pure value and the update value have the same
524 // type. Without this check, allocations may be the wrong size
525 // relative to what update code expects.
526 Type pure_type = pure_def_vals[i].type();
527 if (pure_type != values[i].type()) {
528 std::ostringstream err;
529 err << "In update definition " << update_idx << " of Func \"" << name() << "\":\n";
530 if (values.size()) {
531 err << "Tuple element " << i << " of update definition has type ";
532 } else {
533 err << "Update definition has type ";
534 }
535 err << values[i].type() << ", but pure definition has type " << pure_type;
536 user_error << err.str() << "\n";
537 }
538 values[i] = common_subexpression_elimination(values[i]);
539 }
540
541 vector<Expr> args(_args.size());
542 for (size_t i = 0; i < args.size(); i++) {
543 args[i] = common_subexpression_elimination(_args[i]);
544 }
545
546 // The pure args are those naked vars in the args that are not in
547 // a reduction domain and are not parameters and line up with the
548 // pure args in the pure definition.
549 bool pure = true;
550 vector<string> pure_args(args.size());
551 for (size_t i = 0; i < args.size(); i++) {
552 pure_args[i] = ""; // Will never match a var name
553 user_assert(args[i].defined())
554 << "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
555 << "Argument " << i
556 << " in left-hand-side of update definition is undefined.\n";
557 if (const Variable *var = args[i].as<Variable>()) {
558 if (!var->param.defined() &&
559 !var->reduction_domain.defined() &&
560 var->name == contents->args[i]) {
561 pure_args[i] = var->name;
562 } else {
563 pure = false;
564 }
565 } else {
566 pure = false;
567 }
568 }
569
570 // Make sure all the vars in the args and the value are either
571 // pure args, in the reduction domain, or a parameter. Also checks
572 // that recursive references to the function contain all the pure
573 // vars in the LHS in the correct places.
574 CheckVars check(name());
575 check.pure_args = pure_args;
576 for (const auto &arg : args) {
577 arg.accept(&check);
578 }
579 for (const auto &value : values) {
580 value.accept(&check);
581 }
582 if (check.reduction_domain.defined()) {
583 check.unbound_reduction_vars_ok = true;
584 check.reduction_domain.predicate().accept(&check);
585 }
586
587 // Freeze all called functions
588 FreezeFunctions freezer(name());
589 for (const auto &arg : args) {
590 arg.accept(&freezer);
591 }
592 for (const auto &value : values) {
593 value.accept(&freezer);
594 }
595
596 // Freeze the reduction domain if defined
597 if (check.reduction_domain.defined()) {
598 check.reduction_domain.predicate().accept(&freezer);
599 check.reduction_domain.freeze();
600 }
601
602 // Tag calls to random() with the free vars
603 vector<VarOrRVar> free_vars;
604 int num_free_vars = (int)pure_args.size();
605 if (check.reduction_domain.defined()) {
606 num_free_vars += (int)check.reduction_domain.domain().size();
607 }
608 free_vars.reserve(num_free_vars);
609 for (const auto &pure_arg : pure_args) {
610 if (!pure_arg.empty()) {
611 free_vars.emplace_back(Var(pure_arg));
612 }
613 }
614 if (check.reduction_domain.defined()) {
615 for (size_t i = 0; i < check.reduction_domain.domain().size(); i++) {
616 free_vars.emplace_back(RVar(check.reduction_domain, i));
617 }
618 }
619 int tag = rand_counter++;
620 for (auto &arg : args) {
621 arg = lower_random(arg, free_vars, tag);
622 }
623 for (auto &value : values) {
624 value = lower_random(value, free_vars, tag);
625 }
626 if (check.reduction_domain.defined()) {
627 check.reduction_domain.set_predicate(lower_random(check.reduction_domain.predicate(), free_vars, tag));
628 }
629
630 // The update value and args probably refer back to the
631 // function itself, introducing circular references and hence
632 // memory leaks. We need to break these cycles.
633 WeakenFunctionPtrs weakener(contents.get());
634 for (auto &arg : args) {
635 arg = weakener.mutate(arg);
636 }
637 for (auto &value : values) {
638 value = weakener.mutate(value);
639 }
640 if (check.reduction_domain.defined()) {
641 check.reduction_domain.set_predicate(
642 weakener.mutate(check.reduction_domain.predicate()));
643 }
644
645 Definition r(args, values, check.reduction_domain, false);
646 internal_assert(!r.is_init()) << "Should have been an update definition\n";
647
648 // First add any reduction domain
649 if (check.reduction_domain.defined()) {
650 for (size_t i = 0; i < check.reduction_domain.domain().size(); i++) {
651 // Is this RVar actually pure (safe to parallelize and
652 // reorder)? It's pure if one value of the RVar can never
653 // access from the same memory that another RVar is
654 // writing to.
655 const ReductionVariable &rvar = check.reduction_domain.domain()[i];
656 const string &v = rvar.var;
657
658 bool pure = can_parallelize_rvar(v, name(), r);
659 Dim d = {v, ForType::Serial, DeviceAPI::None,
660 pure ? DimType::PureRVar : DimType::ImpureRVar};
661 r.schedule().dims().push_back(d);
662 }
663 }
664
665 // Then add the pure args outside of that
666 for (const auto &pure_arg : pure_args) {
667 if (!pure_arg.empty()) {
668 Dim d = {pure_arg, ForType::Serial, DeviceAPI::None, DimType::PureVar};
669 r.schedule().dims().push_back(d);
670 }
671 }
672
673 // Then the dummy outermost dim
674 {
675 Dim d = {Var::outermost().name(), ForType::Serial, DeviceAPI::None, DimType::PureVar};
676 r.schedule().dims().push_back(d);
677 }
678
679 // If there's no recursive reference, no reduction domain, and all
680 // the args are pure, then this definition completely hides
681 // earlier ones!
682 if (!check.reduction_domain.defined() &&
683 weakener.count == 0 &&
684 pure) {
685 user_warning
686 << "In update definition " << update_idx << " of Func \"" << name() << "\":\n"
687 << "Update definition completely hides earlier definitions, "
688 << " because all the arguments are pure, it contains no self-references, "
689 << " and no reduction domain. This may be an accidental re-definition of "
690 << " an already-defined function.\n";
691 }
692
693 contents->updates.push_back(r);
694 }
695
define_extern(const std::string & function_name,const std::vector<ExternFuncArgument> & extern_args,const std::vector<Type> & types,const std::vector<Var> & args,NameMangling mangling,DeviceAPI device_api)696 void Function::define_extern(const std::string &function_name,
697 const std::vector<ExternFuncArgument> &extern_args,
698 const std::vector<Type> &types,
699 const std::vector<Var> &args,
700 NameMangling mangling,
701 DeviceAPI device_api) {
702
703 user_assert(!has_pure_definition() && !has_update_definition())
704 << "In extern definition for Func \"" << name() << "\":\n"
705 << "Func with a pure definition cannot have an extern definition.\n";
706
707 user_assert(!has_extern_definition())
708 << "In extern definition for Func \"" << name() << "\":\n"
709 << "Func already has an extern definition.\n";
710
711 std::vector<string> arg_names;
712 std::vector<Expr> arg_exprs;
713 for (size_t i = 0; i < args.size(); i++) {
714 arg_names.push_back(args[i].name());
715 arg_exprs.push_back(args[i]);
716 }
717 contents->args = arg_names;
718 contents->extern_function_name = function_name;
719 contents->extern_arguments = extern_args;
720 contents->output_types = types;
721 contents->extern_mangling = mangling;
722 contents->extern_function_device_api = device_api;
723
724 std::vector<Expr> values;
725 contents->output_buffers.clear();
726 for (size_t i = 0; i < types.size(); i++) {
727 string buffer_name = name();
728 if (types.size() > 1) {
729 buffer_name += '.' + std::to_string((int)i);
730 }
731 Parameter output(types[i], true, (int)args.size(), buffer_name);
732 contents->output_buffers.push_back(output);
733
734 values.push_back(undef(types[i]));
735 }
736
737 contents->init_def = Definition(arg_exprs, values, ReductionDomain(), true);
738
739 // Reset the storage dims to match the pure args
740 contents->func_schedule.storage_dims().clear();
741 contents->init_def.schedule().dims().clear();
742 for (size_t i = 0; i < args.size(); i++) {
743 contents->func_schedule.storage_dims().push_back(StorageDim{arg_names[i]});
744 contents->init_def.schedule().dims().push_back(
745 Dim{arg_names[i], ForType::Extern, DeviceAPI::None, DimType::PureVar});
746 }
747 // Add the dummy outermost dim
748 contents->init_def.schedule().dims().push_back(
749 Dim{Var::outermost().name(), ForType::Serial, DeviceAPI::None, DimType::PureVar});
750 }
751
accept(IRVisitor * visitor) const752 void Function::accept(IRVisitor *visitor) const {
753 contents->accept(visitor);
754 }
755
mutate(IRMutator * mutator)756 void Function::mutate(IRMutator *mutator) {
757 contents->mutate(mutator);
758 }
759
name() const760 const std::string &Function::name() const {
761 return contents->name;
762 }
763
origin_name() const764 const std::string &Function::origin_name() const {
765 return contents->origin_name;
766 }
767
definition()768 Definition &Function::definition() {
769 internal_assert(contents->init_def.defined());
770 return contents->init_def;
771 }
772
definition() const773 const Definition &Function::definition() const {
774 internal_assert(contents->init_def.defined());
775 return contents->init_def;
776 }
777
args() const778 const std::vector<std::string> &Function::args() const {
779 return contents->args;
780 }
781
is_pure_arg(const std::string & name) const782 bool Function::is_pure_arg(const std::string &name) const {
783 return std::find(args().begin(), args().end(), name) != args().end();
784 }
785
dimensions() const786 int Function::dimensions() const {
787 return args().size();
788 }
789
output_types() const790 const std::vector<Type> &Function::output_types() const {
791 return contents->output_types;
792 }
793
values() const794 const std::vector<Expr> &Function::values() const {
795 static const std::vector<Expr> empty;
796 if (has_pure_definition()) {
797 return contents->init_def.values();
798 } else {
799 return empty;
800 }
801 }
802
schedule()803 FuncSchedule &Function::schedule() {
804 return contents->func_schedule;
805 }
806
schedule() const807 const FuncSchedule &Function::schedule() const {
808 return contents->func_schedule;
809 }
810
output_buffers() const811 const std::vector<Parameter> &Function::output_buffers() const {
812 return contents->output_buffers;
813 }
814
update_schedule(int idx)815 StageSchedule &Function::update_schedule(int idx) {
816 internal_assert(idx < (int)contents->updates.size()) << "Invalid update definition index\n";
817 return contents->updates[idx].schedule();
818 }
819
update(int idx)820 Definition &Function::update(int idx) {
821 internal_assert(idx < (int)contents->updates.size()) << "Invalid update definition index\n";
822 return contents->updates[idx];
823 }
824
update(int idx) const825 const Definition &Function::update(int idx) const {
826 internal_assert(idx < (int)contents->updates.size()) << "Invalid update definition index\n";
827 return contents->updates[idx];
828 }
829
updates() const830 const std::vector<Definition> &Function::updates() const {
831 return contents->updates;
832 }
833
has_pure_definition() const834 bool Function::has_pure_definition() const {
835 return contents->init_def.defined();
836 }
837
can_be_inlined() const838 bool Function::can_be_inlined() const {
839 return is_pure() && definition().specializations().empty();
840 }
841
has_update_definition() const842 bool Function::has_update_definition() const {
843 return !contents->updates.empty();
844 }
845
has_extern_definition() const846 bool Function::has_extern_definition() const {
847 return !contents->extern_function_name.empty();
848 }
849
extern_definition_name_mangling() const850 NameMangling Function::extern_definition_name_mangling() const {
851 return contents->extern_mangling;
852 }
853
make_call_to_extern_definition(const std::vector<Expr> & args,const Target & target) const854 Expr Function::make_call_to_extern_definition(const std::vector<Expr> &args,
855 const Target &target) const {
856 internal_assert(has_extern_definition());
857
858 Call::CallType call_type = Call::Extern;
859 switch (contents->extern_mangling) {
860 case NameMangling::Default:
861 call_type = (target.has_feature(Target::CPlusPlusMangling) ? Call::ExternCPlusPlus : Call::Extern);
862 break;
863 case NameMangling::CPlusPlus:
864 call_type = Call::ExternCPlusPlus;
865 break;
866 case NameMangling::C:
867 call_type = Call::Extern;
868 break;
869 }
870 return Call::make(Int(32), contents->extern_function_name, args, call_type, contents);
871 }
872
extern_definition_proxy_expr() const873 Expr Function::extern_definition_proxy_expr() const {
874 return contents->extern_proxy_expr;
875 }
876
extern_definition_proxy_expr()877 Expr &Function::extern_definition_proxy_expr() {
878 return contents->extern_proxy_expr;
879 }
880
extern_arguments() const881 const std::vector<ExternFuncArgument> &Function::extern_arguments() const {
882 return contents->extern_arguments;
883 }
884
extern_arguments()885 std::vector<ExternFuncArgument> &Function::extern_arguments() {
886 return contents->extern_arguments;
887 }
888
extern_function_name() const889 const std::string &Function::extern_function_name() const {
890 return contents->extern_function_name;
891 }
892
extern_function_device_api() const893 DeviceAPI Function::extern_function_device_api() const {
894 return contents->extern_function_device_api;
895 }
896
debug_file() const897 const std::string &Function::debug_file() const {
898 return contents->debug_file;
899 }
900
debug_file()901 std::string &Function::debug_file() {
902 return contents->debug_file;
903 }
904
operator ExternFuncArgument() const905 Function::operator ExternFuncArgument() const {
906 return ExternFuncArgument(contents);
907 }
908
trace_loads()909 void Function::trace_loads() {
910 contents->trace_loads = true;
911 }
trace_stores()912 void Function::trace_stores() {
913 contents->trace_stores = true;
914 }
trace_realizations()915 void Function::trace_realizations() {
916 contents->trace_realizations = true;
917 }
add_trace_tag(const std::string & trace_tag)918 void Function::add_trace_tag(const std::string &trace_tag) {
919 contents->trace_tags.push_back(trace_tag);
920 }
921
is_tracing_loads() const922 bool Function::is_tracing_loads() const {
923 return contents->trace_loads;
924 }
is_tracing_stores() const925 bool Function::is_tracing_stores() const {
926 return contents->trace_stores;
927 }
is_tracing_realizations() const928 bool Function::is_tracing_realizations() const {
929 return contents->trace_realizations;
930 }
get_trace_tags() const931 const std::vector<std::string> &Function::get_trace_tags() const {
932 return contents->trace_tags;
933 }
934
freeze()935 void Function::freeze() {
936 contents->frozen = true;
937 }
938
lock_loop_levels()939 void Function::lock_loop_levels() {
940 auto &schedule = contents->func_schedule;
941 schedule.compute_level().lock();
942 schedule.store_level().lock();
943 // If store_level is inlined, use the compute_level instead.
944 // (Note that we deliberately do *not* do the same if store_level
945 // is undefined.)
946 if (schedule.store_level().is_inlined()) {
947 schedule.store_level() = schedule.compute_level();
948 }
949 if (contents->init_def.defined()) {
950 contents->init_def.schedule().fuse_level().level.lock();
951 }
952 for (Definition &def : contents->updates) {
953 internal_assert(def.defined());
954 def.schedule().fuse_level().level.lock();
955 }
956 }
957
frozen() const958 bool Function::frozen() const {
959 return contents->frozen;
960 }
961
wrappers() const962 const map<string, FunctionPtr> &Function::wrappers() const {
963 return contents->func_schedule.wrappers();
964 }
965
new_function_in_same_group(const std::string & f)966 Function Function::new_function_in_same_group(const std::string &f) {
967 int group_size = (int)(contents.group()->members.size());
968 contents.group()->members.resize(group_size + 1);
969 contents.group()->members[group_size].name = f;
970 FunctionPtr ptr;
971 ptr.strong = contents.group();
972 ptr.idx = group_size;
973 return Function(ptr);
974 }
975
add_wrapper(const std::string & f,Function & wrapper)976 void Function::add_wrapper(const std::string &f, Function &wrapper) {
977 wrapper.freeze();
978 FunctionPtr ptr = wrapper.contents;
979
980 // Weaken the pointer from the function to its wrapper
981 ptr.weaken();
982 contents->func_schedule.add_wrapper(f, ptr);
983
984 // Weaken the pointer from the wrapper back to the function.
985 WeakenFunctionPtrs weakener(contents.get());
986 wrapper.mutate(&weakener);
987 }
988
is_wrapper() const989 const Call *Function::is_wrapper() const {
990 const vector<Expr> &rhs = values();
991 if (rhs.size() != 1) {
992 return nullptr;
993 }
994 const Call *call = rhs[0].as<Call>();
995 if (!call) {
996 return nullptr;
997 }
998 vector<Expr> expected_args;
999 for (const string &v : args()) {
1000 expected_args.push_back(Variable::make(Int(32), v));
1001 }
1002 Expr expected_rhs =
1003 Call::make(call->type, call->name, expected_args, call->call_type,
1004 call->func, call->value_index, call->image, call->param);
1005 if (equal(rhs[0], expected_rhs)) {
1006 return call;
1007 } else {
1008 return nullptr;
1009 }
1010 }
1011
1012 namespace {
1013
1014 // Replace all calls to functions listed in 'substitutions' with their wrappers.
1015 class SubstituteCalls : public IRMutator {
1016 using IRMutator::visit;
1017
1018 const map<FunctionPtr, FunctionPtr> &substitutions;
1019
visit(const Call * c)1020 Expr visit(const Call *c) override {
1021 Expr expr = IRMutator::visit(c);
1022 c = expr.as<Call>();
1023 internal_assert(c);
1024
1025 if ((c->call_type == Call::Halide) &&
1026 c->func.defined() &&
1027 substitutions.count(c->func)) {
1028 auto it = substitutions.find(c->func);
1029 internal_assert(it != substitutions.end())
1030 << "Function not in environment: " << c->func->name << "\n";
1031 FunctionPtr subs = it->second;
1032 debug(4) << "...Replace call to Func \"" << c->name << "\" with "
1033 << "\"" << subs->name << "\"\n";
1034 expr = Call::make(c->type, subs->name, c->args, c->call_type,
1035 subs, c->value_index,
1036 c->image, c->param);
1037 }
1038 return expr;
1039 }
1040
1041 public:
SubstituteCalls(const map<FunctionPtr,FunctionPtr> & substitutions)1042 SubstituteCalls(const map<FunctionPtr, FunctionPtr> &substitutions)
1043 : substitutions(substitutions) {
1044 }
1045 };
1046
1047 } // anonymous namespace
1048
substitute_calls(const map<FunctionPtr,FunctionPtr> & substitutions)1049 Function &Function::substitute_calls(const map<FunctionPtr, FunctionPtr> &substitutions) {
1050 debug(4) << "Substituting calls in " << name() << "\n";
1051 if (substitutions.empty()) {
1052 return *this;
1053 }
1054 SubstituteCalls subs_calls(substitutions);
1055 contents->mutate(&subs_calls);
1056 return *this;
1057 }
1058
// Convenience overload: substitutes calls to the single Function 'orig'
// with calls to 'substitute' by delegating to the map-based overload.
Function &Function::substitute_calls(const Function &orig, const Function &substitute) {
    const map<FunctionPtr, FunctionPtr> single_substitution = {
        {orig.get_contents(), substitute.get_contents()},
    };
    return substitute_calls(single_substitution);
}
1064
1065 // Deep copy an entire Function DAG.
deep_copy(const vector<Function> & outputs,const map<string,Function> & env)1066 pair<vector<Function>, map<string, Function>> deep_copy(
1067 const vector<Function> &outputs, const map<string, Function> &env) {
1068 vector<Function> copy_outputs;
1069 map<string, Function> copy_env;
1070
1071 // Create empty deep-copies of all Functions in 'env'
1072 DeepCopyMap copied_map; // Original Function -> Deep-copy
1073 IntrusivePtr<FunctionGroup> group(new FunctionGroup);
1074 group->members.resize(env.size());
1075 int i = 0;
1076 for (const auto &iter : env) {
1077 // Make a weak pointer to the function to use for within-group references.
1078 FunctionPtr ptr;
1079 ptr.weak = group.get();
1080 ptr.idx = i;
1081 ptr->name = iter.second.name();
1082 copied_map[iter.second.get_contents()] = ptr;
1083 i++;
1084 }
1085
1086 // Deep copy all Functions in 'env' into their corresponding empty copies
1087 for (const auto &iter : env) {
1088 iter.second.deep_copy(copied_map[iter.second.get_contents()], copied_map);
1089 }
1090
1091 // Need to substitute-in all old Function references in all Exprs referenced
1092 // within the Function with the deep-copy versions
1093 for (auto &iter : copied_map) {
1094 Function(iter.second).substitute_calls(copied_map);
1095 }
1096
1097 // Populate the env with the deep-copy version
1098 for (const auto &iter : copied_map) {
1099 FunctionPtr ptr = iter.second;
1100 copy_env.emplace(iter.first->name, Function(ptr));
1101 }
1102
1103 for (const auto &func : outputs) {
1104 const auto &iter = copied_map.find(func.get_contents());
1105 if (iter != copied_map.end()) {
1106 FunctionPtr ptr = iter->second;
1107 debug(4) << "Adding deep-copied version to outputs: " << func.name() << "\n";
1108 copy_outputs.emplace_back(ptr);
1109 } else {
1110 debug(4) << "Adding original version to outputs: " << func.name() << "\n";
1111 copy_outputs.push_back(func);
1112 }
1113 }
1114
1115 return {copy_outputs, copy_env};
1116 }
1117
1118 } // namespace Internal
1119 } // namespace Halide
1120