1 #include <array>
2 #include <utility>
3 
4 #include "Generator.h"
5 #include "IREquality.h"
6 #include "IROperator.h"
7 #include "IRPrinter.h"
8 #include "ImageParam.h"
9 #include "RDom.h"
10 #include "Simplify.h"
11 #include "Util.h"
12 
13 namespace Halide {
14 
15 using namespace Internal;
16 
17 using std::string;
18 using std::vector;
19 
20 namespace {
21 
// Suffixes for the first four reduction variables of an RDom ($x, $y, $z, $w).
constexpr const char *dom_var_names[] = {"$x", "$y", "$z", "$w"};
23 
24 // T is an ImageParam, Buffer<>, Input<Buffer<>>
25 template<typename T>
make_dom_from_dimensions(const T & t,const std::string & name)26 Internal::ReductionDomain make_dom_from_dimensions(const T &t, const std::string &name) {
27     std::vector<Internal::ReductionVariable> vars;
28     for (int i = 0; i < t.dimensions(); i++) {
29         vars.push_back({name + dom_var_names[i],
30                         t.dim(i).min(),
31                         t.dim(i).extent()});
32     }
33 
34     return Internal::ReductionDomain(vars);
35 }
36 
37 }  // namespace
38 
operator Expr() const39 RVar::operator Expr() const {
40     if (!min().defined() || !extent().defined()) {
41         user_error << "Use of undefined RDom dimension: " << (name().empty() ? "<unknown>" : name()) << "\n";
42     }
43     return Variable::make(Int(32), name(), domain());
44 }
45 
min() const46 Expr RVar::min() const {
47     if (_domain.defined()) {
48         return _var().min;
49     } else {
50         return Expr();
51     }
52 }
53 
extent() const54 Expr RVar::extent() const {
55     if (_domain.defined()) {
56         return _var().extent;
57     } else {
58         return Expr();
59     }
60 }
61 
name() const62 const std::string &RVar::name() const {
63     if (_domain.defined()) {
64         return _var().var;
65     } else {
66         return _name;
67     }
68 }
69 
70 template<int N>
build_domain(ReductionVariable (& vars)[N])71 ReductionDomain build_domain(ReductionVariable (&vars)[N]) {
72     vector<ReductionVariable> d(&vars[0], &vars[N]);
73     ReductionDomain dom(d);
74     return dom;
75 }
76 
77 // This just initializes the predefined x, y, z, w members of RDom.
init_vars(const string & name)78 void RDom::init_vars(const string &name) {
79     const std::vector<ReductionVariable> &dom_vars = dom.domain();
80     std::array<RVar *, 4> vars = {{&x, &y, &z, &w}};
81 
82     for (size_t i = 0; i < vars.size(); i++) {
83         if (i < dom_vars.size()) {
84             *(vars[i]) = RVar(dom, i);
85         } else {
86             *(vars[i]) = RVar(name + dom_var_names[i]);
87         }
88     }
89 }
90 
RDom(const ReductionDomain & d)91 RDom::RDom(const ReductionDomain &d)
92     : dom(d) {
93     if (d.defined()) {
94         init_vars("");
95     }
96 }
97 
98 namespace {
99 class CheckRDomBounds : public IRGraphVisitor {
100 
101     using IRGraphVisitor::visit;
102 
visit(const Call * op)103     void visit(const Call *op) override {
104         IRGraphVisitor::visit(op);
105         if (op->call_type == Call::Halide) {
106             offending_func = op->name;
107         }
108     }
109 
visit(const Variable * op)110     void visit(const Variable *op) override {
111         if (!op->param.defined() &&
112             !op->image.defined() &&
113             !internal_vars.contains(op->name)) {
114             offending_free_var = op->name;
115         }
116     }
117 
visit(const Let * op)118     void visit(const Let *op) override {
119         ScopedBinding<int> bind(internal_vars, op->name, 0);
120         IRGraphVisitor::visit(op);
121     }
122     Scope<int> internal_vars;
123 
124 public:
125     string offending_func;
126     string offending_free_var;
127 };
128 }  // namespace
129 
initialize_from_region(const Region & region,string name)130 void RDom::initialize_from_region(const Region &region, string name) {
131     if (name.empty()) {
132         name = make_entity_name(this, "Halide:.*:RDom", 'r');
133     }
134 
135     std::vector<ReductionVariable> vars;
136     for (size_t i = 0; i < region.size(); i++) {
137         CheckRDomBounds checker;
138         user_assert(region[i].min.defined() && region[i].extent.defined())
139             << "The RDom " << name << " may not be constructed with undefined Exprs.\n";
140         region[i].min.accept(&checker);
141         region[i].extent.accept(&checker);
142         user_assert(checker.offending_func.empty())
143             << "The bounds of the RDom " << name
144             << " in dimension " << i
145             << " are:\n"
146             << "  " << region[i].min << " ... " << region[i].extent << "\n"
147             << "These depend on a call to the Func " << checker.offending_func << ".\n"
148             << "The bounds of an RDom may not depend on a call to a Func.\n";
149         user_assert(checker.offending_free_var.empty())
150             << "The bounds of the RDom " << name
151             << " in dimension " << i
152             << " are:\n"
153             << "  " << region[i].min << " ... " << region[i].extent << "\n"
154             << "These depend on the variable " << checker.offending_free_var << ".\n"
155             << "The bounds of an RDom may not depend on a free variable.\n";
156 
157         std::string rvar_uniquifier;
158         switch (i) {
159         case 0:
160             rvar_uniquifier = "x";
161             break;
162         case 1:
163             rvar_uniquifier = "y";
164             break;
165         case 2:
166             rvar_uniquifier = "z";
167             break;
168         case 3:
169             rvar_uniquifier = "w";
170             break;
171         default:
172             rvar_uniquifier = std::to_string(i);
173             break;
174         }
175         ReductionVariable rv;
176         rv.var = name + "$" + rvar_uniquifier;
177         rv.min = cast<int32_t>(region[i].min);
178         rv.extent = cast<int32_t>(region[i].extent);
179         vars.push_back(rv);
180     }
181     dom = ReductionDomain(vars);
182     init_vars(name);
183 }
184 
RDom(const Buffer<> & b)185 RDom::RDom(const Buffer<> &b) {
186     std::string name = unique_name('r');
187     dom = make_dom_from_dimensions(b, name);
188     init_vars(name);
189 }
190 
RDom(const OutputImageParam & p)191 RDom::RDom(const OutputImageParam &p) {
192     const std::string &name = p.name();
193     dom = make_dom_from_dimensions(p, name);
194     init_vars(name);
195 }
196 
dimensions() const197 int RDom::dimensions() const {
198     return (int)dom.domain().size();
199 }
200 
operator [](int i) const201 RVar RDom::operator[](int i) const {
202     if (i == 0) return x;
203     if (i == 1) return y;
204     if (i == 2) return z;
205     if (i == 3) return w;
206     if (i < dimensions()) {
207         return RVar(dom, i);
208     }
209     user_error << "Reduction domain index out of bounds: " << i << "\n";
210     return x;  // Keep the compiler happy
211 }
212 
operator Expr() const213 RDom::operator Expr() const {
214     if (dimensions() != 1) {
215         user_error << "Error: Can't treat this multidimensional RDom as an Expr:\n"
216                    << (*this) << "\n"
217                    << "Only single-dimensional RDoms can be cast to Expr.\n";
218     }
219     return Expr(x);
220 }
221 
operator RVar() const222 RDom::operator RVar() const {
223     if (dimensions() != 1) {
224         user_error << "Error: Can't treat this multidimensional RDom as an RVar:\n"
225                    << (*this) << "\n"
226                    << "Only single-dimensional RDoms can be cast to RVar.\n";
227     }
228     return x;
229 }
230 
where(Expr predicate)231 void RDom::where(Expr predicate) {
232     user_assert(!dom.frozen())
233         << (*this) << " cannot be given a new predicate, because it has already"
234         << " been used in the update definition of some function.\n";
235     user_assert(dom.defined()) << "Error: Can't add predicate to undefined RDom.\n";
236     dom.where(std::move(predicate));
237 }
238 
239 /** Emit an RVar in a human-readable form */
operator <<(std::ostream & stream,const RVar & v)240 std::ostream &operator<<(std::ostream &stream, const RVar &v) {
241     stream << v.name() << "(" << v.min() << ", " << v.extent() << ")";
242     return stream;
243 }
244 
245 /** Emit an RDom in a human-readable form. */
operator <<(std::ostream & stream,const RDom & dom)246 std::ostream &operator<<(std::ostream &stream, const RDom &dom) {
247     stream << "RDom(\n";
248     for (int i = 0; i < dom.dimensions(); i++) {
249         stream << "  " << dom[i] << "\n";
250     }
251     stream << ")";
252     Expr pred = simplify(dom.domain().predicate());
253     if (!equal(const_true(), pred)) {
254         stream << " where (\n  " << pred << ")";
255     }
256     stream << "\n";
257     return stream;
258 }
259 
260 }  // namespace Halide
261