#include "StorageFlattening.h"

#include "Bounds.h"
#include "Function.h"
#include "FuseGPUThreadLoops.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Parameter.h"
#include "Scope.h"

#include <sstream>

namespace Halide {
namespace Internal {

using std::map;
using std::ostringstream;
using std::pair;
using std::set;
using std::string;
using std::vector;

namespace {

class FlattenDimensions : public IRMutator {
public:
    FlattenDimensions(const map<string, pair<Function, int>> &e,
                      const vector<Function> &o,
                      const Target &t)
        : env(e), target(t) {
        for (auto &f : o) {
            outputs.insert(f.name());
        }
    }

private:
    const map<string, pair<Function, int>> &env;
    set<string> outputs;
    const Target &target;
    Scope<> realizations, shader_scope_realizations;
    bool in_shader = false;

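    // Build a symbolic reference to one component of a buffer's shape,
    // i.e. an Int(32) Variable named "name.field.dim" such as
    // "f.stride.1" or "f.min.0", tagged with the buffer or parameter it
    // belongs to.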
    Expr make_shape_var(string name, const string &field, size_t dim,
                        const Buffer<> &buf, const Parameter &param) {
        ReductionDomain rdom;
        name = name + "." + field + "." + std::to_string(dim);
        return Variable::make(Int(32), name, buf, param, rdom);
    }

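    // Compute the flat 1-D index for a multi-dimensional site
    // f(x, y, ...) in terms of the symbolic mins and strides built by
    // make_shape_var above. The index is computed in 64 bits when the
    // target enables large buffers, and in 32 bits otherwise.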
    Expr flatten_args(const string &name, vector<Expr> args,
                      const Buffer<> &buf, const Parameter &param) {
        bool internal = realizations.contains(name);
        Expr idx = target.has_large_buffers() ? make_zero(Int(64)) : 0;
        vector<Expr> mins(args.size()), strides(args.size());

        for (size_t i = 0; i < args.size(); i++) {
            strides[i] = make_shape_var(name, "stride", i, buf, param);
            mins[i] = make_shape_var(name, "min", i, buf, param);
            if (target.has_large_buffers()) {
                strides[i] = cast<int64_t>(strides[i]);
            }
        }

        Expr zero = target.has_large_buffers() ? make_zero(Int(64)) : 0;

        // We peel off constant offsets so that multiple stencil
        // taps can share the same base address.
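        // For example, the tap f(x + 1, y) contributes 1*strides[0] to
        // constant_term and args[0] is reduced to x, so f(x, y),
        // f(x + 1, y), and f(x + 2, y) all flatten to the same base
        // index plus a different constant.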
        Expr constant_term = zero;
        for (size_t i = 0; i < args.size(); i++) {
            const Add *add = args[i].as<Add>();
            if (add && is_const(add->b)) {
                constant_term += strides[i] * add->b;
                args[i] = add->a;
            }
        }

        if (internal) {
            // f(x, y) -> f[(x-xmin)*xstride + (y-ymin)*ystride]. This
            // strategy makes sense when we expect x to cancel with
            // something in xmin. We use this for internal allocations.
            for (size_t i = 0; i < args.size(); i++) {
                idx += (args[i] - mins[i]) * strides[i];
            }
        } else {
            // f(x, y) -> f[x*xstride + y*ystride - (xstride*xmin +
            // ystride*ymin)]. The idea here is that the last term
            // will be pulled outside the inner loop. We use this for
            // external buffers, where the mins and strides are likely
            // to be symbolic.
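            // For example, loading in(x, y) from an external buffer "in"
            // yields in[x*in.stride.0 + y*in.stride.1
            //           - (in.min.0*in.stride.0 + in.min.1*in.stride.1)],
            // where the subtracted base term is loop-invariant.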
            Expr base = zero;
            for (size_t i = 0; i < args.size(); i++) {
                idx += args[i] * strides[i];
                base += mins[i] * strides[i];
            }
            idx -= base;
        }

        if (!is_zero(constant_term)) {
            idx += constant_term;
        }

        return idx;
    }

    using IRMutator::visit;

    Stmt visit(const Realize *op) override {
        realizations.push(op->name);

        if (in_shader) {
            shader_scope_realizations.push(op->name);
        }

        Stmt body = mutate(op->body);

        // Compute the size
        vector<Expr> extents;
        for (size_t i = 0; i < op->bounds.size(); i++) {
            extents.push_back(op->bounds[i].extent);
            extents[i] = mutate(extents[i]);
        }
        Expr condition = mutate(op->condition);

        realizations.pop(op->name);

        if (in_shader) {
            shader_scope_realizations.pop(op->name);
        }

        // The allocation extents of the function, taking into account
        // the align_storage directives. These are only used to determine
        // the host allocation size and the strides in halide_buffer_t
        // objects (which also affect the device allocation in some
        // backends).
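        // For example, align_storage(x, 16) rounds the allocation
        // extent of x up to a multiple of 16, while the extent stored
        // in the halide_buffer_t (f.extent.0) still describes the
        // region actually realized.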
        vector<Expr> allocation_extents(extents.size());
        vector<int> storage_permutation;
        {
            auto iter = env.find(op->name);
            internal_assert(iter != env.end()) << "Realize node refers to function not in environment.\n";
            Function f = iter->second.first;
            const vector<StorageDim> &storage_dims = f.schedule().storage_dims();
            const vector<string> &args = f.args();
            for (size_t i = 0; i < storage_dims.size(); i++) {
                for (size_t j = 0; j < args.size(); j++) {
                    if (args[j] == storage_dims[i].var) {
                        storage_permutation.push_back((int)j);
                        Expr alignment = storage_dims[i].alignment;
                        if (alignment.defined()) {
                            allocation_extents[j] = ((extents[j] + alignment - 1) / alignment) * alignment;
                        } else {
                            allocation_extents[j] = extents[j];
                        }
                    }
                }
                internal_assert(storage_permutation.size() == i + 1);
            }
        }

        internal_assert(storage_permutation.size() == op->bounds.size());

        Stmt stmt = body;
        internal_assert(op->types.size() == 1);

        // Make the names for the mins, extents, and strides
        int dims = op->bounds.size();
        vector<string> min_name(dims), extent_name(dims), stride_name(dims);
        for (int i = 0; i < dims; i++) {
            string d = std::to_string(i);
            min_name[i] = op->name + ".min." + d;
            stride_name[i] = op->name + ".stride." + d;
            extent_name[i] = op->name + ".extent." + d;
        }
        vector<Expr> min_var(dims), extent_var(dims), stride_var(dims);
        for (int i = 0; i < dims; i++) {
            min_var[i] = Variable::make(Int(32), min_name[i]);
            extent_var[i] = Variable::make(Int(32), extent_name[i]);
            stride_var[i] = Variable::make(Int(32), stride_name[i]);
        }

        // Create a halide_buffer_t object for this allocation.
        BufferBuilder builder;
        builder.host = Variable::make(Handle(), op->name);
        builder.type = op->types[0];
        builder.dimensions = dims;
        for (int i = 0; i < dims; i++) {
            builder.mins.push_back(min_var[i]);
            builder.extents.push_back(extent_var[i]);
            builder.strides.push_back(stride_var[i]);
        }
        stmt = LetStmt::make(op->name + ".buffer", builder.build(), stmt);

        // Make the allocation node
        stmt = Allocate::make(op->name, op->types[0], op->memory_type, allocation_extents, condition, stmt);

        // Compute the strides
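        // Strides follow storage order: the innermost storage dimension
        // gets stride 1 (set below), and each outer dimension's stride
        // is the previous dimension's stride times its aligned
        // allocation extent.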
        for (int i = (int)op->bounds.size() - 1; i > 0; i--) {
            int prev_j = storage_permutation[i - 1];
            int j = storage_permutation[i];
            Expr stride = stride_var[prev_j] * allocation_extents[prev_j];
            stmt = LetStmt::make(stride_name[j], stride, stmt);
        }

        // Innermost stride is one
        if (dims > 0) {
            int innermost = storage_permutation.empty() ? 0 : storage_permutation[0];
            stmt = LetStmt::make(stride_name[innermost], 1, stmt);
        }

        // Assign the stored mins and extents
        for (size_t i = op->bounds.size(); i > 0; i--) {
            stmt = LetStmt::make(min_name[i - 1], op->bounds[i - 1].min, stmt);
            stmt = LetStmt::make(extent_name[i - 1], extents[i - 1], stmt);
        }
        return stmt;
    }

    Stmt visit(const Provide *op) override {
        internal_assert(op->values.size() == 1);

        Parameter output_buf;
        auto it = env.find(op->name);
        if (it != env.end()) {
            const Function &f = it->second.first;
            int idx = it->second.second;

            // We only want to do this for actual pipeline outputs,
            // even though every Function has an output buffer. Any
            // constraints you set on the output buffer of a Func that
            // isn't actually an output are ignored. This is a language
            // wart.
            if (outputs.count(f.name())) {
                output_buf = f.output_buffers()[idx];
            }
        }

        Expr value = mutate(op->values[0]);
        if (in_shader && !shader_scope_realizations.contains(op->name)) {
            user_assert(op->args.size() == 3)
                << "Image stores require three coordinates.\n";
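            // A store like f(x, y, c) = value becomes
            //   image_store("f", f.buffer, x, y, c, value)
            // when it occurs inside a GLSL shader loop.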
            Expr buffer_var =
                Variable::make(type_of<halide_buffer_t *>(), op->name + ".buffer", output_buf);
            vector<Expr> args = {
                op->name, buffer_var,
                op->args[0], op->args[1], op->args[2],
                value};
            Expr store = Call::make(value.type(), Call::image_store,
                                    args, Call::Intrinsic);
            return Evaluate::make(store);
        } else {
            Expr idx = mutate(flatten_args(op->name, op->args, Buffer<>(), output_buf));
            return Store::make(op->name, value, idx, output_buf, const_true(value.type().lanes()), ModulusRemainder());
        }
    }

    Expr visit(const Call *op) override {
        if (op->call_type == Call::Halide ||
            op->call_type == Call::Image) {

            internal_assert(op->value_index == 0);

            if (in_shader && !shader_scope_realizations.contains(op->name)) {
                ReductionDomain rdom;
                Expr buffer_var =
                    Variable::make(type_of<halide_buffer_t *>(), op->name + ".buffer",
                                   op->image, op->param, rdom);

                // Create image_load("name", name.buffer, x - x_min, x_extent,
                // y - y_min, y_extent, ...). Extents can be used by
                // successive passes. OpenGL, for example, uses them
                // for coordinate normalization.
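                // For a 2-D input "in", in(x, y) becomes
                //   image_load("in", in.buffer,
                //              x - in.min.0, in.extent.0,
                //              y - in.min.1, in.extent.1,
                //              0, 1),
                // with the trailing (0, 1) padding the unused third
                // dimension.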
                vector<Expr> args(2);
                args[0] = op->name;
                args[1] = buffer_var;
                for (size_t i = 0; i < op->args.size(); i++) {
                    Expr min = make_shape_var(op->name, "min", i, op->image, op->param);
                    Expr extent = make_shape_var(op->name, "extent", i, op->image, op->param);
                    args.push_back(mutate(op->args[i]) - min);
                    args.push_back(extent);
                }
                for (size_t i = op->args.size(); i < 3; i++) {
                    args.emplace_back(0);
                    args.emplace_back(1);
                }

                return Call::make(op->type,
                                  Call::image_load,
                                  args,
                                  Call::PureIntrinsic,
                                  FunctionPtr(),
                                  0,
                                  op->image,
                                  op->param);
            } else {
                Expr idx = mutate(flatten_args(op->name, op->args, op->image, op->param));
                return Load::make(op->type, op->name, idx, op->image, op->param,
                                  const_true(op->type.lanes()), ModulusRemainder());
            }

        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const Prefetch *op) override {
        internal_assert(op->types.size() == 1)
            << "Prefetch from a multi-dimensional Halide tuple should have been split\n";

        Expr condition = mutate(op->condition);

        vector<Expr> prefetch_min(op->bounds.size());
        vector<Expr> prefetch_extent(op->bounds.size());
        vector<Expr> prefetch_stride(op->bounds.size());
        for (size_t i = 0; i < op->bounds.size(); i++) {
            prefetch_min[i] = mutate(op->bounds[i].min);
            prefetch_extent[i] = mutate(op->bounds[i].extent);
            prefetch_stride[i] = Variable::make(Int(32), op->name + ".stride." + std::to_string(i), op->prefetch.param);
        }

        Expr base_offset = mutate(flatten_args(op->name, prefetch_min, Buffer<>(), op->prefetch.param));
        Expr base_address = Variable::make(Handle(), op->name);
        vector<Expr> args = {base_address, base_offset};

        auto iter = env.find(op->name);
        if (iter != env.end()) {
            // Order the <extent, stride> args based on the storage dims
            // (i.e. the innermost dimension should come first in args).
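            // The call built below has the form
            //   prefetch(base_address, base_offset,
            //            extent_0, stride_0, extent_1, stride_1, ...)
            // with the dimensions listed innermost-first.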
            vector<int> storage_permutation;
            {
                Function f = iter->second.first;
                const vector<StorageDim> &storage_dims = f.schedule().storage_dims();
                const vector<string> &args = f.args();
                for (size_t i = 0; i < storage_dims.size(); i++) {
                    for (size_t j = 0; j < args.size(); j++) {
                        if (args[j] == storage_dims[i].var) {
                            storage_permutation.push_back((int)j);
                        }
                    }
                    internal_assert(storage_permutation.size() == i + 1);
                }
            }
            internal_assert(storage_permutation.size() == op->bounds.size());

            for (size_t i = 0; i < op->bounds.size(); i++) {
                internal_assert(storage_permutation[i] < (int)op->bounds.size());
                args.push_back(prefetch_extent[storage_permutation[i]]);
                args.push_back(prefetch_stride[storage_permutation[i]]);
            }
        } else {
            for (size_t i = 0; i < op->bounds.size(); i++) {
                args.push_back(prefetch_extent[i]);
                args.push_back(prefetch_stride[i]);
            }
        }

        // TODO: Consider generating a prefetch call for each tuple element.
        Stmt prefetch_call = Evaluate::make(Call::make(op->types[0], Call::prefetch, args, Call::Intrinsic));
        if (!is_one(condition)) {
            prefetch_call = IfThenElse::make(condition, prefetch_call);
        }
        Stmt body = mutate(op->body);
        return Block::make(prefetch_call, body);
    }

    Stmt visit(const For *op) override {
        bool old_in_shader = in_shader;
        if ((op->for_type == ForType::GPUBlock ||
             op->for_type == ForType::GPUThread) &&
            op->device_api == DeviceAPI::GLSL) {
            in_shader = true;
        }
        Stmt stmt = IRMutator::visit(op);
        in_shader = old_in_shader;
        return stmt;
    }
};

// Realizations, stores, and loads must all be on types that are
// multiples of 8 bits. This really only affects bools.
class PromoteToMemoryType : public IRMutator {
    using IRMutator::visit;

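    // Round a type's bit-width up to the next multiple of 8. For
    // example, a 1-bit boolean becomes an 8-bit value: stores cast the
    // value up to the wider type, and loads cast it back down.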
    Type upgrade(Type t) {
        return t.with_bits(((t.bits() + 7) / 8) * 8);
    }

    Expr visit(const Load *op) override {
        Type t = upgrade(op->type);
        if (t != op->type) {
            return Cast::make(op->type,
                              Load::make(t, op->name, mutate(op->index),
                                         op->image, op->param, mutate(op->predicate), ModulusRemainder()));
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const Store *op) override {
        Type t = upgrade(op->value.type());
        if (t != op->value.type()) {
            return Store::make(op->name, Cast::make(t, mutate(op->value)), mutate(op->index),
                               op->param, mutate(op->predicate), ModulusRemainder());
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const Allocate *op) override {
        Type t = upgrade(op->type);
        if (t != op->type) {
            vector<Expr> extents;
            for (Expr e : op->extents) {
                extents.push_back(mutate(e));
            }
            return Allocate::make(op->name, t, op->memory_type, extents,
                                  mutate(op->condition), mutate(op->body),
                                  mutate(op->new_expr), op->free_function);
        } else {
            return IRMutator::visit(op);
        }
    }
};

}  // namespace

Stmt storage_flattening(Stmt s,
                        const vector<Function> &outputs,
                        const map<string, Function> &env,
                        const Target &target) {
    // The OpenGL backend requires loop mins to be zeroed at this point.
    s = zero_gpu_loop_mins(s);

    // Make an environment that makes it easier to figure out which
    // Function corresponds to a tuple component. foo.0, foo.1, foo.2,
    // etc. all point to the function foo.
    map<string, pair<Function, int>> tuple_env;
    for (auto p : env) {
        if (p.second.outputs() > 1) {
            for (int i = 0; i < p.second.outputs(); i++) {
                tuple_env[p.first + "." + std::to_string(i)] = {p.second, i};
            }
        } else {
            tuple_env[p.first] = {p.second, 0};
        }
    }

    s = FlattenDimensions(tuple_env, outputs, target).mutate(s);
    s = PromoteToMemoryType().mutate(s);
    return s;
}

}  // namespace Internal
}  // namespace Halide