#include <algorithm>
#include <set>
#include <string>
#include <vector>

#include "ExternFuncArgument.h"
#include "Function.h"
#include "IRVisitor.h"
#include "InferArguments.h"
9
10 namespace Halide {
11 namespace Internal {
12
13 using std::set;
14 using std::string;
15 using std::vector;
16
17 namespace {
18
19 class InferArguments : public IRGraphVisitor {
20 public:
21 vector<InferredArgument> &args;
22
InferArguments(vector<InferredArgument> & a,const vector<Function> & o,const Stmt & body)23 InferArguments(vector<InferredArgument> &a, const vector<Function> &o, const Stmt &body)
24 : args(a), outputs(o) {
25 args.clear();
26 for (const Function &f : outputs) {
27 visit_function(f);
28 }
29 if (body.defined()) {
30 body.accept(this);
31 }
32 }
33
34 private:
35 vector<Function> outputs;
36 set<string> visited_functions;
37
38 using IRGraphVisitor::visit;
39
already_have(const string & name)40 bool already_have(const string &name) {
41 // Ignore dependencies on the output buffers
42 for (const Function &output : outputs) {
43 if (name == output.name() || starts_with(name, output.name() + ".")) {
44 return true;
45 }
46 }
47 for (const InferredArgument &arg : args) {
48 if (arg.arg.name == name) {
49 return true;
50 }
51 }
52 return false;
53 }
54
visit_exprs(const vector<Expr> & v)55 void visit_exprs(const vector<Expr> &v) {
56 for (Expr i : v) {
57 visit_expr(i);
58 }
59 }
60
visit_expr(const Expr & e)61 void visit_expr(const Expr &e) {
62 if (!e.defined()) return;
63 e.accept(this);
64 }
65
visit_function(const Function & func)66 void visit_function(const Function &func) {
67 if (visited_functions.count(func.name())) return;
68 visited_functions.insert(func.name());
69
70 func.accept(this);
71
72 // Function::accept hits all the Expr children of the
73 // Function, but misses the buffers and images that might be
74 // extern arguments.
75 if (func.has_extern_definition()) {
76 for (const ExternFuncArgument &extern_arg : func.extern_arguments()) {
77 if (extern_arg.is_func()) {
78 visit_function(Function(extern_arg.func));
79 } else if (extern_arg.is_buffer()) {
80 include_buffer(extern_arg.buffer);
81 } else if (extern_arg.is_image_param()) {
82 include_parameter(extern_arg.image_param);
83 }
84 }
85 }
86 }
87
include_parameter(const Parameter & p)88 void include_parameter(const Parameter &p) {
89 if (!p.defined()) return;
90 if (already_have(p.name())) return;
91
92 ArgumentEstimates argument_estimates = p.get_argument_estimates();
93 if (!p.is_buffer()) {
94 argument_estimates.scalar_def = p.scalar_expr();
95 argument_estimates.scalar_min = p.min_value();
96 argument_estimates.scalar_max = p.max_value();
97 argument_estimates.scalar_estimate = p.estimate();
98 }
99
100 InferredArgument a = {
101 Argument(p.name(),
102 p.is_buffer() ? Argument::InputBuffer : Argument::InputScalar,
103 p.type(), p.dimensions(), argument_estimates),
104 p,
105 Buffer<>()};
106 args.push_back(a);
107
108 // Visit child expressions
109 visit_expr(argument_estimates.scalar_def);
110 visit_expr(argument_estimates.scalar_min);
111 visit_expr(argument_estimates.scalar_max);
112 visit_expr(argument_estimates.scalar_estimate);
113 for (const auto &be : argument_estimates.buffer_estimates) {
114 visit_expr(be.min);
115 visit_expr(be.extent);
116 }
117
118 if (p.is_buffer()) {
119 for (int i = 0; i < p.dimensions(); i++) {
120 visit_expr(p.min_constraint(i));
121 visit_expr(p.extent_constraint(i));
122 visit_expr(p.stride_constraint(i));
123 }
124 }
125 }
126
include_buffer(const Buffer<> & b)127 void include_buffer(const Buffer<> &b) {
128 if (!b.defined()) return;
129 if (already_have(b.name())) return;
130
131 InferredArgument a = {
132 Argument(b.name(), Argument::InputBuffer, b.type(), b.dimensions(), ArgumentEstimates{}),
133 Parameter(),
134 b};
135 args.push_back(a);
136 }
137
visit(const Load * op)138 void visit(const Load *op) override {
139 IRGraphVisitor::visit(op);
140 include_parameter(op->param);
141 include_buffer(op->image);
142 }
143
visit(const Variable * op)144 void visit(const Variable *op) override {
145 IRGraphVisitor::visit(op);
146 include_parameter(op->param);
147 include_buffer(op->image);
148 }
149
visit(const Call * op)150 void visit(const Call *op) override {
151 IRGraphVisitor::visit(op);
152 if (op->func.defined()) {
153 Function fn(op->func);
154 visit_function(fn);
155 }
156 include_buffer(op->image);
157 include_parameter(op->param);
158 }
159 };
160
161 } // namespace
162
/** Collect the external inputs of a pipeline.
 *
 * Walks the output Functions (and every Function they reach) plus `body`,
 * returning one InferredArgument per distinct Parameter or Buffer that is
 * referenced and is not one of the outputs themselves.
 *
 * @param body    Statement to scan; may be undefined.
 * @param outputs Pipeline output Functions.
 * @return The inferred arguments, ordered by std::sort with
 *         InferredArgument's operator<.
 */
vector<InferredArgument> infer_arguments(const Stmt &body, const vector<Function> &outputs) {
    vector<InferredArgument> inferred_args;
    // The visitor populates `inferred_args` entirely within its constructor.
    InferArguments infer_args(inferred_args,
                              outputs,
                              body);

    // Produce a canonical order. The ordering is defined by
    // InferredArgument::operator< (declared outside this file); the intent is
    // all buffers first, then non-buffers, each alphabetical by name.
    // NOTE(review): confirm against the operator< definition.
    std::sort(inferred_args.begin(), inferred_args.end());

    return inferred_args;
}
176
177 } // namespace Internal
178 } // namespace Halide
179