1 #include <set>
2 #include <string>
3 #include <vector>
4 
5 #include "ExternFuncArgument.h"
6 #include "Function.h"
7 #include "IRVisitor.h"
8 #include "InferArguments.h"
9 
10 namespace Halide {
11 namespace Internal {
12 
13 using std::set;
14 using std::string;
15 using std::vector;
16 
17 namespace {
18 
19 class InferArguments : public IRGraphVisitor {
20 public:
21     vector<InferredArgument> &args;
22 
InferArguments(vector<InferredArgument> & a,const vector<Function> & o,const Stmt & body)23     InferArguments(vector<InferredArgument> &a, const vector<Function> &o, const Stmt &body)
24         : args(a), outputs(o) {
25         args.clear();
26         for (const Function &f : outputs) {
27             visit_function(f);
28         }
29         if (body.defined()) {
30             body.accept(this);
31         }
32     }
33 
34 private:
35     vector<Function> outputs;
36     set<string> visited_functions;
37 
38     using IRGraphVisitor::visit;
39 
already_have(const string & name)40     bool already_have(const string &name) {
41         // Ignore dependencies on the output buffers
42         for (const Function &output : outputs) {
43             if (name == output.name() || starts_with(name, output.name() + ".")) {
44                 return true;
45             }
46         }
47         for (const InferredArgument &arg : args) {
48             if (arg.arg.name == name) {
49                 return true;
50             }
51         }
52         return false;
53     }
54 
visit_exprs(const vector<Expr> & v)55     void visit_exprs(const vector<Expr> &v) {
56         for (Expr i : v) {
57             visit_expr(i);
58         }
59     }
60 
visit_expr(const Expr & e)61     void visit_expr(const Expr &e) {
62         if (!e.defined()) return;
63         e.accept(this);
64     }
65 
visit_function(const Function & func)66     void visit_function(const Function &func) {
67         if (visited_functions.count(func.name())) return;
68         visited_functions.insert(func.name());
69 
70         func.accept(this);
71 
72         // Function::accept hits all the Expr children of the
73         // Function, but misses the buffers and images that might be
74         // extern arguments.
75         if (func.has_extern_definition()) {
76             for (const ExternFuncArgument &extern_arg : func.extern_arguments()) {
77                 if (extern_arg.is_func()) {
78                     visit_function(Function(extern_arg.func));
79                 } else if (extern_arg.is_buffer()) {
80                     include_buffer(extern_arg.buffer);
81                 } else if (extern_arg.is_image_param()) {
82                     include_parameter(extern_arg.image_param);
83                 }
84             }
85         }
86     }
87 
include_parameter(const Parameter & p)88     void include_parameter(const Parameter &p) {
89         if (!p.defined()) return;
90         if (already_have(p.name())) return;
91 
92         ArgumentEstimates argument_estimates = p.get_argument_estimates();
93         if (!p.is_buffer()) {
94             argument_estimates.scalar_def = p.scalar_expr();
95             argument_estimates.scalar_min = p.min_value();
96             argument_estimates.scalar_max = p.max_value();
97             argument_estimates.scalar_estimate = p.estimate();
98         }
99 
100         InferredArgument a = {
101             Argument(p.name(),
102                      p.is_buffer() ? Argument::InputBuffer : Argument::InputScalar,
103                      p.type(), p.dimensions(), argument_estimates),
104             p,
105             Buffer<>()};
106         args.push_back(a);
107 
108         // Visit child expressions
109         visit_expr(argument_estimates.scalar_def);
110         visit_expr(argument_estimates.scalar_min);
111         visit_expr(argument_estimates.scalar_max);
112         visit_expr(argument_estimates.scalar_estimate);
113         for (const auto &be : argument_estimates.buffer_estimates) {
114             visit_expr(be.min);
115             visit_expr(be.extent);
116         }
117 
118         if (p.is_buffer()) {
119             for (int i = 0; i < p.dimensions(); i++) {
120                 visit_expr(p.min_constraint(i));
121                 visit_expr(p.extent_constraint(i));
122                 visit_expr(p.stride_constraint(i));
123             }
124         }
125     }
126 
include_buffer(const Buffer<> & b)127     void include_buffer(const Buffer<> &b) {
128         if (!b.defined()) return;
129         if (already_have(b.name())) return;
130 
131         InferredArgument a = {
132             Argument(b.name(), Argument::InputBuffer, b.type(), b.dimensions(), ArgumentEstimates{}),
133             Parameter(),
134             b};
135         args.push_back(a);
136     }
137 
visit(const Load * op)138     void visit(const Load *op) override {
139         IRGraphVisitor::visit(op);
140         include_parameter(op->param);
141         include_buffer(op->image);
142     }
143 
visit(const Variable * op)144     void visit(const Variable *op) override {
145         IRGraphVisitor::visit(op);
146         include_parameter(op->param);
147         include_buffer(op->image);
148     }
149 
visit(const Call * op)150     void visit(const Call *op) override {
151         IRGraphVisitor::visit(op);
152         if (op->func.defined()) {
153             Function fn(op->func);
154             visit_function(fn);
155         }
156         include_buffer(op->image);
157         include_parameter(op->param);
158     }
159 };
160 
161 }  // namespace
162 
infer_arguments(const Stmt & body,const vector<Function> & outputs)163 vector<InferredArgument> infer_arguments(const Stmt &body, const vector<Function> &outputs) {
164     vector<InferredArgument> inferred_args;
165     // Infer an arguments vector by walking the IR
166     InferArguments infer_args(inferred_args,
167                               outputs,
168                               body);
169 
170     // Sort the Arguments with all buffers first (alphabetical by name),
171     // followed by all non-buffers (alphabetical by name).
172     std::sort(inferred_args.begin(), inferred_args.end());
173 
174     return inferred_args;
175 }
176 
177 }  // namespace Internal
178 }  // namespace Halide
179