#include "StorageFlattening.h"

#include "Bounds.h"
#include "Function.h"
#include "FuseGPUThreadLoops.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Parameter.h"
#include "Scope.h"

#include <sstream>

namespace Halide {
namespace Internal {

using std::map;
using std::ostringstream;
using std::pair;
using std::set;
using std::string;
using std::vector;

namespace {

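// Lowers multi-dimensional Realize, Provide, and Call nodes into
// one-dimensional Allocate, Store, and Load nodes, introducing the
// mins, extents, and strides needed to compute flat indices.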
class FlattenDimensions : public IRMutator {
public:
    FlattenDimensions(const map<string, pair<Function, int>> &e,
                      const vector<Function> &o,
                      const Target &t)
        : env(e), target(t) {
        for (auto &f : o) {
            outputs.insert(f.name());
        }
    }

private:
    const map<string, pair<Function, int>> &env;
    set<string> outputs;
    const Target &target;
    Scope<> realizations, shader_scope_realizations;
    bool in_shader = false;

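    // Make a variable that refers to one field of a buffer's shape,
    // e.g. "f.min.0" or "f.stride.1".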
    Expr make_shape_var(string name, const string &field, size_t dim,
                        const Buffer<> &buf, const Parameter &param) {
        ReductionDomain rdom;
        name = name + "." + field + "." + std::to_string(dim);
        return Variable::make(Int(32), name, buf, param, rdom);
    }

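    // Compute the flat one-dimensional index for a multi-dimensional
    // access to 'name'. The index is a dot product of the coordinates
    // with the buffer's strides, offset by the mins.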
    Expr flatten_args(const string &name, vector<Expr> args,
                      const Buffer<> &buf, const Parameter &param) {
        bool internal = realizations.contains(name);
        Expr idx = target.has_large_buffers() ? make_zero(Int(64)) : 0;
        vector<Expr> mins(args.size()), strides(args.size());

        for (size_t i = 0; i < args.size(); i++) {
            strides[i] = make_shape_var(name, "stride", i, buf, param);
            mins[i] = make_shape_var(name, "min", i, buf, param);
            if (target.has_large_buffers()) {
                strides[i] = cast<int64_t>(strides[i]);
            }
        }

        Expr zero = target.has_large_buffers() ? make_zero(Int(64)) : 0;

        // We peel off constant offsets so that multiple stencil
        // taps can share the same base address.
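        // For example, in f(x) + f(x + 1) + f(x + 2) the +1 and +2 are
        // split off into a constant term that is added back after the
        // shared base index is computed.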
        Expr constant_term = zero;
        for (size_t i = 0; i < args.size(); i++) {
            const Add *add = args[i].as<Add>();
            if (add && is_const(add->b)) {
                constant_term += strides[i] * add->b;
                args[i] = add->a;
            }
        }

        if (internal) {
            // f(x, y) -> f[(x-xmin)*xstride + (y-ymin)*ystride]. This
            // strategy makes sense when we expect x to cancel with
            // something in xmin. We use this for internal allocations.
            for (size_t i = 0; i < args.size(); i++) {
                idx += (args[i] - mins[i]) * strides[i];
            }
        } else {
            // f(x, y) -> f[x*xstride + y*ystride - (xstride*xmin +
            // ystride*ymin)]. The idea here is that the last term
            // will be pulled outside the inner loop. We use this for
            // external buffers, where the mins and strides are likely
            // to be symbolic.
            Expr base = zero;
            for (size_t i = 0; i < args.size(); i++) {
                idx += args[i] * strides[i];
                base += mins[i] * strides[i];
            }
            idx -= base;
        }

        if (!is_zero(constant_term)) {
            idx += constant_term;
        }

        return idx;
    }

    using IRMutator::visit;

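    // Lower a Realize node into an Allocate wrapped in LetStmts that
    // define the mins, extents, strides, and a halide_buffer_t for the
    // allocation. Roughly:
    //   let f.min.d = ..., f.extent.d = ...   (per dimension d)
    //   let f.stride.d = ...                  (innermost storage dim gets stride 1)
    //   allocate f[...] if <condition>
    //   let f.buffer = <buffer describing the allocation>
    //   <flattened body>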
    Stmt visit(const Realize *op) override {
        realizations.push(op->name);

        if (in_shader) {
            shader_scope_realizations.push(op->name);
        }

        Stmt body = mutate(op->body);

        // Compute the size
        vector<Expr> extents;
        for (size_t i = 0; i < op->bounds.size(); i++) {
            extents.push_back(op->bounds[i].extent);
            extents[i] = mutate(extents[i]);
        }
        Expr condition = mutate(op->condition);

        realizations.pop(op->name);

        if (in_shader) {
            shader_scope_realizations.pop(op->name);
        }

        // The allocation extents of the function, with the align_storage
        // directives taken into account. They are only used to determine
        // the host allocation size and the strides in halide_buffer_t
        // objects (which also affect the device allocation in some backends).
        vector<Expr> allocation_extents(extents.size());
        vector<int> storage_permutation;
        {
            auto iter = env.find(op->name);
            internal_assert(iter != env.end()) << "Realize node refers to function not in environment.\n";
            Function f = iter->second.first;
            const vector<StorageDim> &storage_dims = f.schedule().storage_dims();
            const vector<string> &args = f.args();
            for (size_t i = 0; i < storage_dims.size(); i++) {
                for (size_t j = 0; j < args.size(); j++) {
                    if (args[j] == storage_dims[i].var) {
                        storage_permutation.push_back((int)j);
                        Expr alignment = storage_dims[i].alignment;
                        if (alignment.defined()) {
                            allocation_extents[j] = ((extents[j] + alignment - 1) / alignment) * alignment;
                        } else {
                            allocation_extents[j] = extents[j];
                        }
                    }
                }
                internal_assert(storage_permutation.size() == i + 1);
            }
        }

        internal_assert(storage_permutation.size() == op->bounds.size());

        Stmt stmt = body;
        internal_assert(op->types.size() == 1);

        // Make the names for the mins, extents, and strides
        int dims = op->bounds.size();
        vector<string> min_name(dims), extent_name(dims), stride_name(dims);
        for (int i = 0; i < dims; i++) {
            string d = std::to_string(i);
            min_name[i] = op->name + ".min." + d;
            stride_name[i] = op->name + ".stride." + d;
            extent_name[i] = op->name + ".extent." + d;
        }
        vector<Expr> min_var(dims), extent_var(dims), stride_var(dims);
        for (int i = 0; i < dims; i++) {
            min_var[i] = Variable::make(Int(32), min_name[i]);
            extent_var[i] = Variable::make(Int(32), extent_name[i]);
            stride_var[i] = Variable::make(Int(32), stride_name[i]);
        }

        // Create a halide_buffer_t object for this allocation.
        BufferBuilder builder;
        builder.host = Variable::make(Handle(), op->name);
        builder.type = op->types[0];
        builder.dimensions = dims;
        for (int i = 0; i < dims; i++) {
            builder.mins.push_back(min_var[i]);
            builder.extents.push_back(extent_var[i]);
            builder.strides.push_back(stride_var[i]);
        }
        stmt = LetStmt::make(op->name + ".buffer", builder.build(), stmt);

        // Make the allocation node
        stmt = Allocate::make(op->name, op->types[0], op->memory_type, allocation_extents, condition, stmt);

        // Compute the strides
        for (int i = (int)op->bounds.size() - 1; i > 0; i--) {
            int prev_j = storage_permutation[i - 1];
            int j = storage_permutation[i];
            Expr stride = stride_var[prev_j] * allocation_extents[prev_j];
            stmt = LetStmt::make(stride_name[j], stride, stmt);
        }

        // Innermost stride is one
        if (dims > 0) {
            int innermost = storage_permutation.empty() ? 0 : storage_permutation[0];
            stmt = LetStmt::make(stride_name[innermost], 1, stmt);
        }

        // Assign the mins and extents
        for (size_t i = op->bounds.size(); i > 0; i--) {
            stmt = LetStmt::make(min_name[i - 1], op->bounds[i - 1].min, stmt);
            stmt = LetStmt::make(extent_name[i - 1], extents[i - 1], stmt);
        }
        return stmt;
    }

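    // Lower a Provide node into a Store at the flattened index. Inside a
    // GLSL shader, stores to buffers not realized within the shader
    // become image_store intrinsics instead.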
    Stmt visit(const Provide *op) override {
        internal_assert(op->values.size() == 1);

        Parameter output_buf;
        auto it = env.find(op->name);
        if (it != env.end()) {
            const Function &f = it->second.first;
            int idx = it->second.second;

            // We only want to do this for actual pipeline outputs,
            // even though every Function has an output buffer. Any
            // constraints you set on the output buffer of a Func that
            // isn't actually an output are ignored. This is a language
            // wart.
            if (outputs.count(f.name())) {
                output_buf = f.output_buffers()[idx];
            }
        }

        Expr value = mutate(op->values[0]);
        if (in_shader && !shader_scope_realizations.contains(op->name)) {
            user_assert(op->args.size() == 3)
                << "Image stores require three coordinates.\n";
            Expr buffer_var =
                Variable::make(type_of<halide_buffer_t *>(), op->name + ".buffer", output_buf);
            vector<Expr> args = {
                op->name, buffer_var,
                op->args[0], op->args[1], op->args[2],
                value};
            Expr store = Call::make(value.type(), Call::image_store,
                                    args, Call::Intrinsic);
            return Evaluate::make(store);
        } else {
            Expr idx = mutate(flatten_args(op->name, op->args, Buffer<>(), output_buf));
            return Store::make(op->name, value, idx, output_buf, const_true(value.type().lanes()), ModulusRemainder());
        }
    }

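    // Lower a call to a Halide function or input image into a Load at
    // the flattened index. Inside a GLSL shader, reads of buffers not
    // realized within the shader become image_load intrinsics instead.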
    Expr visit(const Call *op) override {
        if (op->call_type == Call::Halide ||
            op->call_type == Call::Image) {

            internal_assert(op->value_index == 0);

            if (in_shader && !shader_scope_realizations.contains(op->name)) {
                ReductionDomain rdom;
                Expr buffer_var =
                    Variable::make(type_of<halide_buffer_t *>(), op->name + ".buffer",
                                   op->image, op->param, rdom);

                // Create image_load("name", name.buffer, x - x_min, x_extent,
                // y - y_min, y_extent, ...).  Extents can be used by
                // successive passes. OpenGL, for example, uses them
                // for coordinate normalization.
                vector<Expr> args(2);
                args[0] = op->name;
                args[1] = buffer_var;
                for (size_t i = 0; i < op->args.size(); i++) {
                    Expr min = make_shape_var(op->name, "min", i, op->image, op->param);
                    Expr extent = make_shape_var(op->name, "extent", i, op->image, op->param);
                    args.push_back(mutate(op->args[i]) - min);
                    args.push_back(extent);
                }
                for (size_t i = op->args.size(); i < 3; i++) {
                    args.emplace_back(0);
                    args.emplace_back(1);
                }

                return Call::make(op->type,
                                  Call::image_load,
                                  args,
                                  Call::PureIntrinsic,
                                  FunctionPtr(),
                                  0,
                                  op->image,
                                  op->param);
            } else {
                Expr idx = mutate(flatten_args(op->name, op->args, op->image, op->param));
                return Load::make(op->type, op->name, idx, op->image, op->param,
                                  const_true(op->type.lanes()), ModulusRemainder());
            }

        } else {
            return IRMutator::visit(op);
        }
    }

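    // Lower a multi-dimensional Prefetch into a call to the prefetch
    // intrinsic, taking the base address, the flattened base offset, and
    // an <extent, stride> pair per dimension (ordered by storage dims
    // when the function is in the environment).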
    Stmt visit(const Prefetch *op) override {
        internal_assert(op->types.size() == 1)
            << "Prefetch from multi-dimensional halide tuple should have been split\n";

        Expr condition = mutate(op->condition);

        vector<Expr> prefetch_min(op->bounds.size());
        vector<Expr> prefetch_extent(op->bounds.size());
        vector<Expr> prefetch_stride(op->bounds.size());
        for (size_t i = 0; i < op->bounds.size(); i++) {
            prefetch_min[i] = mutate(op->bounds[i].min);
            prefetch_extent[i] = mutate(op->bounds[i].extent);
            prefetch_stride[i] = Variable::make(Int(32), op->name + ".stride." + std::to_string(i), op->prefetch.param);
        }

        Expr base_offset = mutate(flatten_args(op->name, prefetch_min, Buffer<>(), op->prefetch.param));
        Expr base_address = Variable::make(Handle(), op->name);
        vector<Expr> args = {base_address, base_offset};

        auto iter = env.find(op->name);
        if (iter != env.end()) {
            // Order the <min, extent> args based on the storage dims (i.e. innermost
            // dimension should be first in args)
            vector<int> storage_permutation;
            {
                Function f = iter->second.first;
                const vector<StorageDim> &storage_dims = f.schedule().storage_dims();
                const vector<string> &args = f.args();
                for (size_t i = 0; i < storage_dims.size(); i++) {
                    for (size_t j = 0; j < args.size(); j++) {
                        if (args[j] == storage_dims[i].var) {
                            storage_permutation.push_back((int)j);
                        }
                    }
                    internal_assert(storage_permutation.size() == i + 1);
                }
            }
            internal_assert(storage_permutation.size() == op->bounds.size());

            for (size_t i = 0; i < op->bounds.size(); i++) {
                internal_assert(storage_permutation[i] < (int)op->bounds.size());
                args.push_back(prefetch_extent[storage_permutation[i]]);
                args.push_back(prefetch_stride[storage_permutation[i]]);
            }
        } else {
            for (size_t i = 0; i < op->bounds.size(); i++) {
                args.push_back(prefetch_extent[i]);
                args.push_back(prefetch_stride[i]);
            }
        }

        // TODO: Consider generating a prefetch call for each tuple element.
        Stmt prefetch_call = Evaluate::make(Call::make(op->types[0], Call::prefetch, args, Call::Intrinsic));
        if (!is_one(condition)) {
            prefetch_call = IfThenElse::make(condition, prefetch_call);
        }
        Stmt body = mutate(op->body);
        return Block::make(prefetch_call, body);
    }

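    // Track whether we are inside a GLSL shader loop, so that loads and
    // stores there can be turned into image_load/image_store intrinsics.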
    Stmt visit(const For *op) override {
        bool old_in_shader = in_shader;
        if ((op->for_type == ForType::GPUBlock ||
             op->for_type == ForType::GPUThread) &&
            op->device_api == DeviceAPI::GLSL) {
            in_shader = true;
        }
        Stmt stmt = IRMutator::visit(op);
        in_shader = old_in_shader;
        return stmt;
    }
};

// Realizations, stores, and loads must all be on types that are
// multiples of 8 bits. This really only affects bools.
class PromoteToMemoryType : public IRMutator {
    using IRMutator::visit;

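    // Round a type's bit width up to the next multiple of 8.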
    Type upgrade(Type t) {
        return t.with_bits(((t.bits() + 7) / 8) * 8);
    }

    Expr visit(const Load *op) override {
        Type t = upgrade(op->type);
        if (t != op->type) {
            return Cast::make(op->type,
                              Load::make(t, op->name, mutate(op->index),
                                         op->image, op->param, mutate(op->predicate), ModulusRemainder()));
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const Store *op) override {
        Type t = upgrade(op->value.type());
        if (t != op->value.type()) {
            return Store::make(op->name, Cast::make(t, mutate(op->value)), mutate(op->index),
                               op->param, mutate(op->predicate), ModulusRemainder());
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const Allocate *op) override {
        Type t = upgrade(op->type);
        if (t != op->type) {
            vector<Expr> extents;
            for (Expr e : op->extents) {
                extents.push_back(mutate(e));
            }
            return Allocate::make(op->name, t, op->memory_type, extents,
                                  mutate(op->condition), mutate(op->body),
                                  mutate(op->new_expr), op->free_function);
        } else {
            return IRMutator::visit(op);
        }
    }
};

}  // namespace

Stmt storage_flattening(Stmt s,
                        const vector<Function> &outputs,
                        const map<string, Function> &env,
                        const Target &target) {
    // The OpenGL backend requires loop mins to be zero'd at this point.
    s = zero_gpu_loop_mins(s);

    // Make an environment that makes it easier to figure out which
    // Function corresponds to a tuple component. foo.0, foo.1, foo.2,
    // all point to the function foo.
    map<string, pair<Function, int>> tuple_env;
    for (auto p : env) {
        if (p.second.outputs() > 1) {
            for (int i = 0; i < p.second.outputs(); i++) {
                tuple_env[p.first + "." + std::to_string(i)] = {p.second, i};
            }
        } else {
            tuple_env[p.first] = {p.second, 0};
        }
    }

    s = FlattenDimensions(tuple_env, outputs, target).mutate(s);
    s = PromoteToMemoryType().mutate(s);
    return s;
}

}  // namespace Internal
}  // namespace Halide